Initial commit
src/__init__.py (new empty file, 0 lines)
src/__pycache__/__init__.cpython-312.pyc (new binary file; content not shown)
src/__pycache__/config.cpython-312.pyc (new binary file; content not shown)
src/__pycache__/main.cpython-312.pyc (new binary file; content not shown)
src/adapters/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
from .base import ModelAdapter, ModelResponse, StreamChunk
from .claude_adapter import ClaudeAdapter
from .openai_adapter import OpenAIAdapter

__all__ = ["ModelAdapter", "ModelResponse", "StreamChunk", "ClaudeAdapter", "OpenAIAdapter"]
src/adapters/__pycache__/__init__.cpython-312.pyc (new binary file; content not shown)
src/adapters/__pycache__/base.cpython-312.pyc (new binary file; content not shown)
src/adapters/__pycache__/claude_adapter.cpython-312.pyc (new binary file; content not shown)
src/adapters/__pycache__/openai_adapter.cpython-312.pyc (new binary file; content not shown)
src/adapters/base.py (new file, 73 lines)
@@ -0,0 +1,73 @@
"""Model adapter interface — extensible for any LLM provider."""

from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, AsyncIterator


@dataclass
class StreamChunk:
    """A single chunk from a streaming model response."""

    delta: str = ""
    tool_call_id: str = ""
    tool_name: str = ""
    tool_arguments: str = ""
    finish_reason: str = ""
    usage: dict[str, int] = field(default_factory=dict)


@dataclass
class ModelResponse:
    """Complete (non-streaming) model response."""

    content: str = ""
    tool_calls: list[dict[str, Any]] = field(default_factory=list)
    finish_reason: str = ""
    usage: dict[str, int] = field(default_factory=dict)
    raw: Any = None


@dataclass
class ModelConfig:
    """Per-call configuration."""

    model_id: str = ""
    max_tokens: int = 4096
    temperature: float = 0.3
    stop_sequences: list[str] = field(default_factory=list)
    extra: dict[str, Any] = field(default_factory=dict)


class ModelAdapter(ABC):
    """Abstract interface for LLM providers.

    Implementors must provide both streaming and non-streaming methods.
    """

    @abstractmethod
    async def stream(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        config: ModelConfig | None = None,
    ) -> AsyncIterator[StreamChunk]:
        """Stream model response chunks."""
        ...

    @abstractmethod
    async def complete(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        config: ModelConfig | None = None,
    ) -> ModelResponse:
        """Get a complete model response (non-streaming)."""
        ...

    @abstractmethod
    async def count_tokens(self, text: str) -> int:
        """Estimate token count for the given text."""
        ...
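
For reference, a minimal sketch of a custom provider built on this interface. EchoAdapter is hypothetical (not part of this commit); it assumes the repository root is on sys.path so `src.adapters.base` is importable:

    from src.adapters.base import ModelAdapter, ModelResponse, StreamChunk


    class EchoAdapter(ModelAdapter):
        """Hypothetical adapter that echoes the last user message; handy in tests."""

        async def stream(self, messages, tools=None, config=None):
            text = messages[-1]["content"] if messages else ""
            # Emit the reply one word at a time, then a terminal finish chunk.
            for word in text.split():
                yield StreamChunk(delta=word + " ")
            yield StreamChunk(finish_reason="end_turn")

        async def complete(self, messages, tools=None, config=None):
            text = messages[-1]["content"] if messages else ""
            return ModelResponse(content=text, finish_reason="end_turn")

        async def count_tokens(self, text: str) -> int:
            return max(1, len(text) // 4)  # same heuristic as the compactor fallback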
src/adapters/claude_adapter.py (new file, 201 lines)
@@ -0,0 +1,201 @@
"""Claude/Anthropic model adapter with full streaming support."""

from __future__ import annotations

import logging
from typing import Any, AsyncIterator

import anthropic

from ..config import settings
from .base import ModelAdapter, ModelConfig, ModelResponse, StreamChunk

logger = logging.getLogger(__name__)


class ClaudeAdapter(ModelAdapter):
    """Adapter for the Anthropic Claude API."""

    def __init__(self, api_key: str | None = None) -> None:
        self._client = anthropic.AsyncAnthropic(
            api_key=api_key or settings.anthropic_api_key,
        )

    # ------------------------------------------------------------------
    # Streaming
    # ------------------------------------------------------------------

    async def stream(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        config: ModelConfig | None = None,
    ) -> AsyncIterator[StreamChunk]:
        config = config or ModelConfig(
            model_id=settings.default_model_id,
            max_tokens=settings.max_tokens,
            temperature=settings.temperature,
        )

        # Separate the system message from the conversation
        system_content = ""
        api_messages: list[dict[str, Any]] = []
        for m in messages:
            if m["role"] == "system":
                system_content = m["content"]
            else:
                api_messages.append(m)

        kwargs: dict[str, Any] = {
            "model": config.model_id or settings.default_model_id,
            "max_tokens": config.max_tokens,
            "temperature": config.temperature,
            "messages": api_messages,
        }
        if system_content:
            kwargs["system"] = system_content
        if tools:
            kwargs["tools"] = self._format_tools(tools)

        async with self._client.messages.stream(**kwargs) as stream:
            current_tool_id = ""
            current_tool_name = ""
            accumulated_args = ""

            async for event in stream:
                if event.type == "content_block_start":
                    block = event.content_block
                    if block.type == "tool_use":
                        current_tool_id = block.id
                        current_tool_name = block.name
                        accumulated_args = ""
                        yield StreamChunk(
                            tool_call_id=current_tool_id,
                            tool_name=current_tool_name,
                        )
                    continue

                if event.type == "content_block_delta":
                    delta = event.delta
                    if delta.type == "text_delta":
                        yield StreamChunk(delta=delta.text)
                    elif delta.type == "input_json_delta":
                        accumulated_args += delta.partial_json
                        yield StreamChunk(
                            tool_call_id=current_tool_id,
                            tool_name=current_tool_name,
                            tool_arguments=delta.partial_json,
                        )
                    continue

                if event.type == "content_block_stop":
                    if current_tool_id and accumulated_args:
                        yield StreamChunk(
                            tool_call_id=current_tool_id,
                            tool_name=current_tool_name,
                            tool_arguments=accumulated_args,
                            finish_reason="tool_use",
                        )
                    current_tool_id = ""
                    current_tool_name = ""
                    accumulated_args = ""
                    continue

                if event.type == "message_delta":
                    yield StreamChunk(
                        finish_reason=event.delta.stop_reason or "",
                        usage={
                            "output_tokens": getattr(
                                event.usage, "output_tokens", 0
                            )
                        },
                    )

    # ------------------------------------------------------------------
    # Non-streaming
    # ------------------------------------------------------------------

    async def complete(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        config: ModelConfig | None = None,
    ) -> ModelResponse:
        config = config or ModelConfig(
            model_id=settings.default_model_id,
            max_tokens=settings.max_tokens,
            temperature=settings.temperature,
        )

        system_content = ""
        api_messages: list[dict[str, Any]] = []
        for m in messages:
            if m["role"] == "system":
                system_content = m["content"]
            else:
                api_messages.append(m)

        kwargs: dict[str, Any] = {
            "model": config.model_id or settings.default_model_id,
            "max_tokens": config.max_tokens,
            "temperature": config.temperature,
            "messages": api_messages,
        }
        if system_content:
            kwargs["system"] = system_content
        if tools:
            kwargs["tools"] = self._format_tools(tools)

        response = await self._client.messages.create(**kwargs)

        content = ""
        tool_calls: list[dict[str, Any]] = []
        for block in response.content:
            if block.type == "text":
                content += block.text
            elif block.type == "tool_use":
                tool_calls.append(
                    {
                        "id": block.id,
                        "name": block.name,
                        "arguments": block.input,
                    }
                )

        return ModelResponse(
            content=content,
            tool_calls=tool_calls,
            finish_reason=response.stop_reason or "",
            usage={
                "input_tokens": response.usage.input_tokens,
                "output_tokens": response.usage.output_tokens,
            },
            raw=response,
        )

    # ------------------------------------------------------------------
    # Token counting
    # ------------------------------------------------------------------

    async def count_tokens(self, text: str) -> int:
        from ..context.compactor import estimate_tokens

        return estimate_tokens(text)

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _format_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Convert internal tool definitions to Anthropic tool format."""
        formatted: list[dict[str, Any]] = []
        for tool in tools:
            formatted.append(
                {
                    "name": tool["name"],
                    "description": tool.get("description", ""),
                    "input_schema": tool.get("input_schema", tool.get("parameters", {"type": "object"})),
                }
            )
        return formatted
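
A consumer sees the same StreamChunk shapes from every adapter. A minimal sketch of draining the stream into text plus assembled tool calls (assumes a configured API key; the message content is illustrative):

    import asyncio

    from src.adapters import ClaudeAdapter


    async def run() -> None:
        adapter = ClaudeAdapter()
        text_parts: list[str] = []
        tool_calls: dict[str, str] = {}
        async for chunk in adapter.stream([{"role": "user", "content": "Hola"}]):
            if chunk.delta:
                text_parts.append(chunk.delta)
            # The adapter re-emits the full accumulated arguments in the final
            # "tool_use" chunk, so a consumer can ignore the incremental ones.
            if chunk.finish_reason == "tool_use" and chunk.tool_call_id:
                tool_calls[chunk.tool_name] = chunk.tool_arguments
        print("".join(text_parts), tool_calls)


    asyncio.run(run())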
src/adapters/openai_adapter.py (new file, 197 lines)
@@ -0,0 +1,197 @@
"""OpenAI model adapter with full streaming support."""

from __future__ import annotations

import json
import logging
from typing import Any, AsyncIterator

from openai import AsyncOpenAI

from ..config import settings
from .base import ModelAdapter, ModelConfig, ModelResponse, StreamChunk

logger = logging.getLogger(__name__)


class OpenAIAdapter(ModelAdapter):
    """Adapter for the OpenAI API (GPT-4o, o1, etc.)."""

    def __init__(self, api_key: str | None = None) -> None:
        self._client = AsyncOpenAI(
            api_key=api_key or settings.openai_api_key,
        )

    # ------------------------------------------------------------------
    # Streaming
    # ------------------------------------------------------------------

    async def stream(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        config: ModelConfig | None = None,
    ) -> AsyncIterator[StreamChunk]:
        # Leave model_id empty so the provider-appropriate fallback below applies;
        # settings.default_model_id may name a Claude model.
        config = config or ModelConfig(
            max_tokens=settings.max_tokens,
            temperature=settings.temperature,
        )

        kwargs: dict[str, Any] = {
            "model": config.model_id or "gpt-4o",
            "max_tokens": config.max_tokens,
            "temperature": config.temperature,
            "messages": messages,
            "stream": True,
        }
        if tools:
            kwargs["tools"] = self._format_tools(tools)

        stream = await self._client.chat.completions.create(**kwargs)

        tool_calls_acc: dict[int, dict[str, str]] = {}

        async for chunk in stream:
            choice = chunk.choices[0] if chunk.choices else None
            if not choice:
                continue

            delta = choice.delta

            # Text content
            if delta and delta.content:
                yield StreamChunk(delta=delta.content)

            # Tool calls
            if delta and delta.tool_calls:
                for tc in delta.tool_calls:
                    idx = tc.index
                    if idx not in tool_calls_acc:
                        tool_calls_acc[idx] = {
                            "id": tc.id or "",
                            "name": "",
                            "arguments": "",
                        }
                    if tc.id:
                        tool_calls_acc[idx]["id"] = tc.id
                    if tc.function and tc.function.name:
                        tool_calls_acc[idx]["name"] = tc.function.name
                        yield StreamChunk(
                            tool_call_id=tc.id or tool_calls_acc[idx]["id"],
                            tool_name=tc.function.name,
                        )
                    if tc.function and tc.function.arguments:
                        tool_calls_acc[idx]["arguments"] += tc.function.arguments
                        yield StreamChunk(
                            tool_call_id=tool_calls_acc[idx]["id"],
                            tool_name=tool_calls_acc[idx]["name"],
                            tool_arguments=tc.function.arguments,
                        )

            # Finish
            if choice.finish_reason:
                if choice.finish_reason == "tool_calls":
                    for acc in tool_calls_acc.values():
                        yield StreamChunk(
                            tool_call_id=acc["id"],
                            tool_name=acc["name"],
                            tool_arguments=acc["arguments"],
                            finish_reason="tool_use",
                        )
                else:
                    yield StreamChunk(
                        finish_reason="end_turn"
                        if choice.finish_reason == "stop"
                        else choice.finish_reason,
                        # Note: streamed chat completions only carry usage when the
                        # request sets stream_options={"include_usage": True};
                        # otherwise this falls back to 0.
                        usage={
                            "output_tokens": chunk.usage.completion_tokens
                            if chunk.usage
                            else 0
                        },
                    )

    # ------------------------------------------------------------------
    # Non-streaming
    # ------------------------------------------------------------------

    async def complete(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        config: ModelConfig | None = None,
    ) -> ModelResponse:
        # As in stream(): keep model_id empty so the "gpt-4o" fallback is used.
        config = config or ModelConfig(
            max_tokens=settings.max_tokens,
            temperature=settings.temperature,
        )

        kwargs: dict[str, Any] = {
            "model": config.model_id or "gpt-4o",
            "max_tokens": config.max_tokens,
            "temperature": config.temperature,
            "messages": messages,
        }
        if tools:
            kwargs["tools"] = self._format_tools(tools)

        response = await self._client.chat.completions.create(**kwargs)
        choice = response.choices[0]

        content = choice.message.content or ""
        tool_calls: list[dict[str, Any]] = []

        if choice.message.tool_calls:
            for tc in choice.message.tool_calls:
                tool_calls.append(
                    {
                        "id": tc.id,
                        "name": tc.function.name,
                        "arguments": json.loads(tc.function.arguments)
                        if tc.function.arguments
                        else {},
                    }
                )

        return ModelResponse(
            content=content,
            tool_calls=tool_calls,
            finish_reason=choice.finish_reason or "",
            usage={
                "input_tokens": response.usage.prompt_tokens if response.usage else 0,
                "output_tokens": response.usage.completion_tokens if response.usage else 0,
            },
            raw=response,
        )

    # ------------------------------------------------------------------
    # Token counting
    # ------------------------------------------------------------------

    async def count_tokens(self, text: str) -> int:
        from ..context.compactor import estimate_tokens

        return estimate_tokens(text)

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _format_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Convert internal tool definitions to OpenAI function calling format."""
        formatted: list[dict[str, Any]] = []
        for tool in tools:
            formatted.append(
                {
                    "type": "function",
                    "function": {
                        "name": tool["name"],
                        "description": tool.get("description", ""),
                        "parameters": tool.get(
                            "input_schema", tool.get("parameters", {"type": "object"})
                        ),
                    },
                }
            )
        return formatted
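
Both adapters satisfy the ModelAdapter contract, so provider selection reduces to a dictionary lookup. A sketch of such a factory, assuming settings.default_model_provider holds "claude" or "openai" as configured in src/config.py:

    from src.adapters import ClaudeAdapter, OpenAIAdapter
    from src.adapters.base import ModelAdapter
    from src.config import settings


    def make_adapter(provider: str | None = None) -> ModelAdapter:
        """Return the adapter for the requested provider; default comes from settings."""
        registry = {"claude": ClaudeAdapter, "openai": OpenAIAdapter}
        name = (provider or settings.default_model_provider).lower()
        try:
            return registry[name]()
        except KeyError:
            raise ValueError(f"Unknown model provider: {name!r}") from None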
src/api/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
from .routes import router

__all__ = ["router"]
src/api/__pycache__/__init__.cpython-312.pyc (new binary file; content not shown)
src/api/__pycache__/routes.cpython-312.pyc (new binary file; content not shown)
src/api/routes.py (new file, 375 lines)
@@ -0,0 +1,375 @@
"""REST API endpoints for the agentic microservice."""

from __future__ import annotations

import asyncio
import logging
import pathlib
from typing import Any

from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field

from ..models.context import MemoryDocument, MemoryType
from ..models.session import SessionState, SessionStatus
from ..streaming.sse import EventType

logger = logging.getLogger(__name__)

router = APIRouter()


# ------------------------------------------------------------------
# Request / Response schemas
# ------------------------------------------------------------------

class CreateSessionRequest(BaseModel):
    project_profile: dict[str, Any] = Field(default_factory=dict)
    immutable_rules: list[str] = Field(default_factory=list)
    metadata: dict[str, Any] = Field(default_factory=dict)


class CreateSessionResponse(BaseModel):
    session_id: str
    status: str


class SendMessageRequest(BaseModel):
    message: str
    stream: bool = False


class SessionResponse(BaseModel):
    session_id: str
    status: str
    turn_count: int
    current_task: dict[str, Any] | None = None
    completed_tasks: list[str] = Field(default_factory=list)
    created_at: str
    updated_at: str


# ------------------------------------------------------------------
# Dependency helpers (set by main.py at startup)
# ------------------------------------------------------------------

_deps: dict[str, Any] = {}

# Strong references to in-flight background tasks. asyncio keeps only a weak
# reference to tasks, so an unreferenced task can be garbage-collected mid-run.
_background_tasks: set[asyncio.Task] = set()


def set_dependencies(
    storage: Any,
    orchestrator: Any,
    sse_emitter: Any,
    context_engine: Any = None,
    memory_store: Any = None,
) -> None:
    _deps["storage"] = storage
    _deps["orchestrator"] = orchestrator
    _deps["sse"] = sse_emitter
    if context_engine:
        _deps["context_engine"] = context_engine
    if memory_store:
        _deps["memory_store"] = memory_store


def _get_storage():
    return _deps["storage"]


def _get_orchestrator():
    return _deps["orchestrator"]


def _get_sse():
    return _deps["sse"]


# ------------------------------------------------------------------
# POST /sessions
# ------------------------------------------------------------------

@router.post("/sessions", response_model=CreateSessionResponse, status_code=201)
async def create_session(body: CreateSessionRequest) -> CreateSessionResponse:
    storage = _get_storage()
    session = SessionState(
        project_profile=body.project_profile,
        immutable_rules=body.immutable_rules,
        metadata=body.metadata,
    )
    await storage.create_session(session)

    sse = _get_sse()
    await sse.emit(
        EventType.SESSION_CREATED,
        {"session_id": session.session_id},
        session_id=session.session_id,
    )

    logger.info("Session created: %s", session.session_id)
    return CreateSessionResponse(
        session_id=session.session_id,
        status=session.status.value,
    )


# ------------------------------------------------------------------
# POST /sessions/{id}/messages
# ------------------------------------------------------------------

@router.post("/sessions/{session_id}/messages")
async def send_message(
    session_id: str, body: SendMessageRequest
) -> dict[str, Any]:
    storage = _get_storage()
    session = await storage.get_session(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    orchestrator = _get_orchestrator()

    if body.stream:
        # Retain a reference so the background task cannot be garbage-collected
        # before it completes.
        task = asyncio.create_task(
            _execute_and_persist(orchestrator, storage, session, body.message)
        )
        _background_tasks.add(task)
        task.add_done_callback(_background_tasks.discard)
        return {
            "session_id": session_id,
            "status": "executing",
            "stream_url": f"/sessions/{session_id}/stream",
        }

    result = await _execute_and_persist(orchestrator, storage, session, body.message)
    return result


async def _execute_and_persist(orchestrator, storage, session, message) -> dict[str, Any]:
    # Acquire exclusive lock — prevents concurrent execution on same session
    async with storage.session_lock(session.session_id) as acquired:
        if not acquired:
            return {
                "session_id": session.session_id,
                "content": "Error: session is busy — another request is executing",
                "status": "busy",
            }

        try:
            result = await orchestrator.process_message(session, message)
            return result
        except Exception as e:
            session.status = SessionStatus.ERROR
            logger.exception("Execution failed for session %s", session.session_id)
            return {
                "session_id": session.session_id,
                "content": f"Error: {e}",
                "status": "error",
            }
        finally:
            try:
                await storage.update_session(session)
            except Exception as e:
                logger.error("Failed to persist session state: %s", e)


# ------------------------------------------------------------------
# GET /sessions/{id}/stream
# ------------------------------------------------------------------

@router.get("/sessions/{session_id}/stream")
async def stream_session(session_id: str) -> StreamingResponse:
    storage = _get_storage()
    session = await storage.get_session(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    sse = _get_sse()

    return StreamingResponse(
        sse.subscribe(session_id),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )


# ------------------------------------------------------------------
# GET /sessions/{id}
# ------------------------------------------------------------------

@router.get("/sessions/{session_id}", response_model=SessionResponse)
async def get_session(session_id: str) -> SessionResponse:
    storage = _get_storage()
    session = await storage.get_session(session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    return SessionResponse(
        session_id=session.session_id,
        status=session.status.value,
        turn_count=session.turn_count,
        current_task=session.current_task.model_dump() if session.current_task else None,
        completed_tasks=session.completed_tasks,
        created_at=session.created_at.isoformat(),
        updated_at=session.updated_at.isoformat(),
    )


# ------------------------------------------------------------------
# DELETE /sessions/{id}
# ------------------------------------------------------------------

@router.delete("/sessions/{session_id}")
async def delete_session(session_id: str) -> dict[str, str]:
    storage = _get_storage()
    deleted = await storage.delete_session(session_id)
    if not deleted:
        raise HTTPException(status_code=404, detail="Session not found")

    sse = _get_sse()
    sse.cleanup_session(session_id)

    return {"status": "deleted", "session_id": session_id}


# ------------------------------------------------------------------
# GET /sessions/{id}/events
# ------------------------------------------------------------------

@router.get("/sessions/{session_id}/events")
async def get_session_events(session_id: str) -> list[dict[str, Any]]:
    sse = _get_sse()
    return await sse.get_history(session_id)


# ------------------------------------------------------------------
# GET /sessions/{id}/context-debug
# ------------------------------------------------------------------

@router.get("/sessions/{session_id}/context-debug")
async def get_context_debug(session_id: str) -> dict[str, Any]:
    """Return the full context engine debug history for a session.

    Shows exactly what each agent received: sections, token counts,
    priorities, compaction status, and content previews.
    """
    ctx_engine = _deps.get("context_engine")
    if not ctx_engine:
        raise HTTPException(status_code=501, detail="Context engine not available")

    history = ctx_engine.get_debug_history(session_id)
    last = ctx_engine.get_last_context_debug(session_id)

    return {
        "session_id": session_id,
        "total_builds": len(history),
        "last_build": last,
        "history": history,
    }


# ------------------------------------------------------------------
# Knowledge Base
# ------------------------------------------------------------------

class LoadKnowledgeRequest(BaseModel):
    docs_path: str = "docs"


@router.post("/knowledge/load")
async def load_knowledge(body: LoadKnowledgeRequest) -> dict[str, Any]:
    """Load markdown docs from a directory into the knowledge base."""
    memory = _deps.get("memory_store")
    if not memory:
        raise HTTPException(status_code=501, detail="Memory store not available")

    docs_dir = pathlib.Path(body.docs_path)
    if not docs_dir.is_absolute():
        # Resolve relative to project root
        docs_dir = pathlib.Path(__file__).resolve().parent.parent.parent / body.docs_path

    if not docs_dir.is_dir():
        raise HTTPException(status_code=400, detail=f"Directory not found: {docs_dir}")

    loaded = []
    for md_file in sorted(docs_dir.glob("*.md")):
        content = md_file.read_text(encoding="utf-8")
        doc_id = md_file.stem

        # Build a summary from the first ~500 chars
        lines = content.strip().splitlines()
        title = lines[0].lstrip("#").strip() if lines else doc_id
        summary_lines = []
        for line in lines[:30]:
            line = line.strip()
            if line and not line.startswith("#"):
                summary_lines.append(line)
            if len(" ".join(summary_lines)) > 500:
                break
        summary = " ".join(summary_lines)[:500]

        # Extract tags from headings
        tags = []
        for line in lines:
            if line.startswith("## "):
                tags.append(line.lstrip("#").strip().lower()[:30])

        doc = MemoryDocument(
            memory_id=doc_id,
            memory_type=MemoryType.DOCUMENT,
            namespace="knowledge",
            title=title,
            content=content,
            summary=summary,
            tags=tags[:10],
        )
        await memory.store_document(doc)
        loaded.append({
            "id": doc_id,
            "title": title,
            "chars": len(content),
            "tags": tags[:5],
        })

    logger.info("Loaded %d knowledge documents from %s", len(loaded), docs_dir)
    return {
        "status": "loaded",
        "count": len(loaded),
        "documents": loaded,
    }


@router.get("/knowledge")
async def list_knowledge() -> dict[str, Any]:
    """List all documents in the knowledge base."""
    memory = _deps.get("memory_store")
    if not memory:
        raise HTTPException(status_code=501, detail="Memory store not available")

    docs = await memory.list_documents(namespace="knowledge")
    return {
        "count": len(docs),
        "documents": [
            {
                "id": d.memory_id,
                "title": d.title,
                "chars": len(d.content),
                "summary": d.summary[:200],
                "tags": d.tags,
                "updated_at": d.updated_at.isoformat(),
            }
            for d in docs
        ],
    }


@router.delete("/knowledge/{doc_id}")
async def delete_knowledge(doc_id: str) -> dict[str, str]:
    """Remove a document from the knowledge base."""
    memory = _deps.get("memory_store")
    if not memory:
        raise HTTPException(status_code=501, detail="Memory store not available")

    deleted = await memory.delete_document(doc_id, namespace="knowledge")
    if not deleted:
        raise HTTPException(status_code=404, detail="Document not found")
    return {"status": "deleted", "id": doc_id}
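
A typical client flow against these endpoints, sketched with httpx (an assumption; any async HTTP client works). The base URL and payload values are illustrative:

    import asyncio

    import httpx


    async def main() -> None:
        async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
            # 1. Create a session
            r = await client.post("/sessions", json={"immutable_rules": ["No destructive ops"]})
            session_id = r.json()["session_id"]

            # 2. Send a message in streaming mode; execution runs in the background
            r = await client.post(
                f"/sessions/{session_id}/messages",
                json={"message": "Analiza el proyecto", "stream": True},
            )
            stream_url = r.json()["stream_url"]

            # 3. Consume the SSE stream until the server closes the connection
            async with client.stream("GET", stream_url, timeout=None) as resp:
                async for line in resp.aiter_lines():
                    if line.startswith("data:"):
                        print(line[5:].strip())


    asyncio.run(main())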
src/config.py (new file, 61 lines)
@@ -0,0 +1,61 @@
"""Application configuration via environment variables."""

from __future__ import annotations

from pydantic import Field
from pydantic_settings import BaseSettings


class Settings(BaseSettings):
    # --- Service ---
    service_name: str = "agentic-microservice"
    service_version: str = "1.0.0"
    host: str = "0.0.0.0"
    port: int = 8000
    debug: bool = False

    # --- Redis ---
    redis_host: str = "localhost"
    redis_port: int = 6379
    redis_db: int = 0
    redis_password: str = ""
    redis_key_prefix: str = "agentic"
    session_ttl_seconds: int = 86400  # 24h

    @property
    def redis_url(self) -> str:
        auth = f":{self.redis_password}@" if self.redis_password else ""
        return f"redis://{auth}{self.redis_host}:{self.redis_port}/{self.redis_db}"

    # --- Model providers ---
    anthropic_api_key: str = ""
    openai_api_key: str = ""
    default_model_provider: str = "claude"
    default_model_id: str = "claude-sonnet-4-20250514"
    max_tokens: int = 4096
    temperature: float = 0.3

    # --- Context engine ---
    context_max_tokens: int = 120_000
    compaction_threshold_tokens: int = 80_000
    artifact_summary_max_chars: int = 2000
    working_context_max_items: int = 20

    # --- MCP ---
    mcp_server_command: str = ""
    mcp_server_args: list[str] = Field(default_factory=list)
    mcp_timeout_seconds: float = 30.0
    mcp_startup_timeout_seconds: float = 10.0

    # --- Orchestrator ---
    max_execution_steps: int = 25
    subagent_max_steps: int = 10
    max_execution_timeout_seconds: float = 300.0  # 5 min global timeout

    # --- SSE ---
    sse_keepalive_seconds: float = 15.0

    model_config = {"env_prefix": "AGENTIC_", "env_file": ".env", "extra": "ignore"}


settings = Settings()
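
Every field can be overridden through the environment (or a .env file) using the AGENTIC_ prefix declared in model_config. A sketch; the values are placeholders, and the variables must be set before src.config is imported because settings is instantiated at import time:

    import os

    os.environ["AGENTIC_ANTHROPIC_API_KEY"] = "sk-ant-placeholder"
    os.environ["AGENTIC_REDIS_HOST"] = "redis.internal"
    os.environ["AGENTIC_PORT"] = "9000"

    from src.config import settings

    print(settings.redis_url)  # redis://redis.internal:6379/0
    print(settings.port)       # 9000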
src/context/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
from .engine import ContextEngine
from .compactor import ContextCompactor

__all__ = ["ContextEngine", "ContextCompactor"]
src/context/__pycache__/__init__.cpython-312.pyc (new binary file; content not shown)
src/context/__pycache__/compactor.cpython-312.pyc (new binary file; content not shown)
src/context/__pycache__/engine.cpython-312.pyc (new binary file; content not shown)
src/context/compactor.py (new file, 229 lines)
@@ -0,0 +1,229 @@
"""Context compaction: extract facts, remove redundancy, maintain constraints.

The compactor is responsible for keeping the context within token budget
while preserving the most important information.
"""

from __future__ import annotations

import hashlib
import logging
import re

from ..models.artifacts import ArtifactSummary
from ..models.context import ContextSection, ContextSectionType

logger = logging.getLogger(__name__)

# --- Token counting with tiktoken ---
try:
    import tiktoken

    # cl100k_base is an approximation: GPT-4o actually uses o200k_base and
    # Claude has its own tokenizer, but it is close enough for budgeting.
    _encoding = tiktoken.get_encoding("cl100k_base")

    def estimate_tokens(text: str) -> int:
        """Token estimate using tiktoken (cl100k_base encoding)."""
        if not text:
            return 0
        return len(_encoding.encode(text, disallowed_special=()))

    logger.info("Using tiktoken for token counting")
except ImportError:
    def estimate_tokens(text: str) -> int:
        """Fallback: ~4 chars per token."""
        return max(1, len(text) // 4)

    logger.warning("tiktoken not installed — using approximate token counting")


class ContextCompactor:
    """Compacts context sections to fit within token budgets."""

    def __init__(self, max_tokens: int = 120_000) -> None:
        self.max_tokens = max_tokens

    # ------------------------------------------------------------------
    # Public
    # ------------------------------------------------------------------

    def compact_sections(
        self, sections: list[ContextSection]
    ) -> list[ContextSection]:
        """Remove redundancy and trim low-priority sections to fit budget."""
        # 1. Deduplicate identical content across sections
        sections = self._deduplicate(sections)

        # 2. Estimate tokens per section
        for s in sections:
            s.token_estimate = estimate_tokens(s.content)

        total = sum(s.token_estimate for s in sections)
        if total <= self.max_tokens:
            return sections

        # 3. Sort by priority (highest first) — immutable_rules never trimmed
        sections.sort(key=lambda s: s.priority, reverse=True)

        # 4. Progressively trim lowest-priority sections
        while total > self.max_tokens and sections:
            lowest = sections[-1]
            if lowest.section_type == ContextSectionType.IMMUTABLE_RULES:
                break  # Never trim rules
            # Try to compact the section first
            compacted = self._compact_text(lowest.content)
            new_estimate = estimate_tokens(compacted)
            saved = lowest.token_estimate - new_estimate
            if saved > 0:
                lowest.content = compacted
                lowest.token_estimate = new_estimate
                total -= saved
            else:
                # Remove the section entirely
                total -= lowest.token_estimate
                sections.pop()

        return sections

    def summarize_tool_output(
        self,
        tool_name: str,
        raw_output: str,
        session_id: str,
        task_id: str,
    ) -> ArtifactSummary:
        """Summarise raw tool output into an ArtifactSummary.

        The raw output is NEVER passed through to the model context.
        """
        facts = self._extract_facts(raw_output)
        summary = self._build_summary(tool_name, raw_output, facts)

        artifact_type = self._infer_artifact_type(tool_name)
        artifact_id = hashlib.sha256(
            f"{session_id}:{task_id}:{tool_name}:{raw_output[:200]}".encode()
        ).hexdigest()[:16]

        return ArtifactSummary(
            artifact_id=artifact_id,
            session_id=session_id,
            task_id=task_id,
            artifact_type=artifact_type,
            title=f"Output of {tool_name}",
            summary=summary,
            facts=facts,
            source_tool=tool_name,
            char_count=len(raw_output),
        )

    def compact_artifact_summaries(
        self, summaries: list[ArtifactSummary], max_chars: int = 2000
    ) -> str:
        """Merge multiple artifact summaries into a single compact block."""
        if not summaries:
            return ""

        lines: list[str] = ["## Artifacts"]
        budget = max_chars - 20
        for art in summaries:
            entry = f"- [{art.artifact_type}] {art.title}: {art.summary}"
            if art.facts:
                entry += " | Facts: " + "; ".join(art.facts[:3])
            if len(entry) > budget:
                entry = entry[:budget] + "…"
            lines.append(entry)
            budget -= len(entry)
            if budget <= 0:
                lines.append(f"  … and {len(summaries) - len(lines) + 1} more artifacts")
                break
        return "\n".join(lines)

    # ------------------------------------------------------------------
    # Internals
    # ------------------------------------------------------------------

    def _deduplicate(
        self, sections: list[ContextSection]
    ) -> list[ContextSection]:
        seen: set[str] = set()
        unique: list[ContextSection] = []
        for s in sections:
            h = hashlib.md5(s.content.encode()).hexdigest()
            if h not in seen:
                seen.add(h)
                unique.append(s)
        return unique

    def _compact_text(self, text: str) -> str:
        """Aggressively compact text: remove blank lines, collapse whitespace."""
        lines = text.splitlines()
        # Remove empty or whitespace-only lines
        lines = [l.rstrip() for l in lines if l.strip()]
        # Collapse consecutive duplicate lines
        compacted: list[str] = []
        for line in lines:
            if not compacted or line != compacted[-1]:
                compacted.append(line)
        return "\n".join(compacted)

    def _extract_facts(self, raw_output: str) -> list[str]:
        """Extract short factual claims from tool output."""
        facts: list[str] = []
        lines = raw_output.strip().splitlines()

        for line in lines[:100]:  # Limit scan depth
            line = line.strip()
            if not line or len(line) < 10:
                continue
            # Lines that look like key-value facts
            if re.match(r"^[\w\s]+:\s+.+", line) and len(line) < 200:
                facts.append(line)
            # Lines starting with status indicators
            elif re.match(r"^(✓|✗|PASS|FAIL|ERROR|OK|INFO|WARNING)", line):
                facts.append(line)
            # Lines that contain file paths with results
            elif re.match(r"^[\w/\\.]+\s*[:\-]\s*.+", line) and len(line) < 200:
                facts.append(line)

        # Deduplicate and limit
        seen: set[str] = set()
        unique: list[str] = []
        for f in facts:
            if f not in seen:
                seen.add(f)
                unique.append(f)
        return unique[:15]

    def _build_summary(
        self, tool_name: str, raw_output: str, facts: list[str]
    ) -> str:
        """Build a concise summary from tool output."""
        lines = raw_output.strip().splitlines()
        total_lines = len(lines)
        char_count = len(raw_output)

        parts = [f"Tool '{tool_name}' returned {total_lines} lines ({char_count} chars)."]

        if facts:
            parts.append(f"Key findings: {'; '.join(facts[:5])}")

        # Include first and last meaningful lines as bookends
        meaningful = [l.strip() for l in lines if l.strip()]
        if meaningful:
            parts.append(f"First: {meaningful[0][:120]}")
            if len(meaningful) > 1:
                parts.append(f"Last: {meaningful[-1][:120]}")

        return " ".join(parts)

    def _infer_artifact_type(self, tool_name: str) -> str:
        tool_lower = tool_name.lower()
        if any(k in tool_lower for k in ("read", "file", "code", "write", "edit")):
            return "code"
        if any(k in tool_lower for k in ("test", "check", "lint", "validate")):
            return "test_result"
        if any(k in tool_lower for k in ("search", "find", "grep", "glob")):
            return "analysis"
        if any(k in tool_lower for k in ("plan", "design", "architect")):
            return "plan"
        return "general"
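
A small sketch of the summarization path with made-up tool output; the session and task IDs are illustrative:

    from src.context.compactor import ContextCompactor

    compactor = ContextCompactor(max_tokens=120_000)
    raw = (
        "PASS tests/test_api.py::test_create_session\n"
        "FAIL tests/test_api.py::test_stream\n"
        "total: 2 tests"
    )
    artifact = compactor.summarize_tool_output(
        tool_name="run_tests",
        raw_output=raw,
        session_id="sess-demo",
        task_id="task-1",
    )
    print(artifact.artifact_type)  # "test_result", inferred from the tool name
    print(artifact.facts)          # PASS/FAIL/key-value lines caught by the extractor
    print(artifact.summary)        # line/char counts plus key findings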
551
src/context/engine.py
Normal file
551
src/context/engine.py
Normal file
@@ -0,0 +1,551 @@
|
||||
"""Context Engine — the central intelligence of the system.
|
||||
|
||||
Builds structured prompts from session state. Never includes raw tool
|
||||
outputs. Handles compaction, artifact summarization, and selective
|
||||
rehydration.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
|
||||
from ..config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from ..models.agent import AgentProfile
|
||||
from ..models.artifacts import ArtifactSummary
|
||||
from ..memory.store import MemoryStore
|
||||
from ..models.context import (
|
||||
ContextPackage,
|
||||
ContextSection,
|
||||
ContextSectionType,
|
||||
MemoryDocument,
|
||||
MemoryType,
|
||||
)
|
||||
from ..models.session import SessionState, TaskState
|
||||
from .compactor import ContextCompactor, estimate_tokens
|
||||
|
||||
|
||||
class ContextEngine:
|
||||
"""Assembles the context package that gets sent to the model.
|
||||
|
||||
The engine enforces a strict contract:
|
||||
- Raw tool outputs NEVER appear in the context
|
||||
- Each section has a priority for compaction
|
||||
- Immutable rules are always included in full
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
compactor: ContextCompactor | None = None,
|
||||
memory_store: MemoryStore | None = None,
|
||||
) -> None:
|
||||
self.compactor = compactor or ContextCompactor(
|
||||
max_tokens=settings.context_max_tokens
|
||||
)
|
||||
self.memory = memory_store
|
||||
# Debug history: last N context builds per session
|
||||
self._history: dict[str, list[dict[str, Any]]] = defaultdict(list)
|
||||
self._max_history = 20
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public — build context for a model call
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def build_context(
|
||||
self,
|
||||
session: SessionState,
|
||||
agent: AgentProfile,
|
||||
artifacts: list[ArtifactSummary] | None = None,
|
||||
working_items: list[dict[str, Any]] | None = None,
|
||||
extra_instructions: str = "",
|
||||
) -> ContextPackage:
|
||||
"""Build a full ContextPackage for the given agent and session."""
|
||||
|
||||
sections: list[ContextSection] = []
|
||||
allowed = set(agent.context_sections)
|
||||
|
||||
# 1. Immutable rules — highest priority, never trimmed
|
||||
if "immutable_rules" in allowed:
|
||||
sections.append(self._build_immutable_rules(session, agent))
|
||||
|
||||
# 2. Project profile
|
||||
if "project_profile" in allowed:
|
||||
sections.append(self._build_project_profile(session))
|
||||
|
||||
# 3. Knowledge base — loaded from memory store
|
||||
if "knowledge_base" in allowed and self.memory:
|
||||
kb_section = await self._build_knowledge_base(session)
|
||||
if kb_section:
|
||||
sections.append(kb_section)
|
||||
|
||||
# 4. Task state
|
||||
if "task_state" in allowed and session.current_task:
|
||||
sections.append(self._build_task_state(session.current_task))
|
||||
|
||||
# 5. Artifact memory — summarised, never raw
|
||||
if "artifact_memory" in allowed and artifacts:
|
||||
sections.append(self._build_artifact_memory(artifacts))
|
||||
|
||||
# 6. Working context — recent relevant items
|
||||
if "working_context" in allowed:
|
||||
sections.append(
|
||||
self._build_working_context(working_items or [], extra_instructions)
|
||||
)
|
||||
|
||||
# Compact to fit budget
|
||||
sections = self.compactor.compact_sections(sections)
|
||||
|
||||
# Assemble system prompt from sections
|
||||
system_prompt = self._assemble_system_prompt(sections)
|
||||
|
||||
# Build messages (just user message — no chat history)
|
||||
messages = self._build_messages(session)
|
||||
|
||||
total_tokens = estimate_tokens(system_prompt) + sum(
|
||||
estimate_tokens(m.get("content", "")) for m in messages
|
||||
)
|
||||
|
||||
package = ContextPackage(
|
||||
sections=sections,
|
||||
system_prompt=system_prompt,
|
||||
messages=messages,
|
||||
total_token_estimate=total_tokens,
|
||||
)
|
||||
|
||||
# --- Debug: log and store context build ---
|
||||
section_summary = []
|
||||
for s in sections:
|
||||
section_summary.append({
|
||||
"type": s.section_type.value,
|
||||
"priority": s.priority,
|
||||
"tokens": s.token_estimate,
|
||||
"chars": len(s.content),
|
||||
"preview": s.content[:150].replace("\n", " "),
|
||||
})
|
||||
|
||||
debug_entry = {
|
||||
"timestamp": time.time(),
|
||||
"agent": agent.role.value,
|
||||
"agent_name": agent.name,
|
||||
"total_tokens": total_tokens,
|
||||
"sections": section_summary,
|
||||
"sections_count": len(sections),
|
||||
"compacted": len(sections) < len(allowed),
|
||||
"system_prompt_tokens": estimate_tokens(system_prompt),
|
||||
"user_message_preview": messages[0]["content"][:200] if messages else "",
|
||||
"artifacts_count": len(artifacts) if artifacts else 0,
|
||||
"working_items_count": len(working_items) if working_items else 0,
|
||||
}
|
||||
|
||||
history = self._history[session.session_id]
|
||||
history.append(debug_entry)
|
||||
if len(history) > self._max_history:
|
||||
self._history[session.session_id] = history[-self._max_history:]
|
||||
|
||||
logger.info(
|
||||
"Context built for [%s/%s] — %d sections, ~%d tokens, artifacts=%d, working_items=%d",
|
||||
session.session_id[:8],
|
||||
agent.role.value,
|
||||
len(sections),
|
||||
total_tokens,
|
||||
len(artifacts) if artifacts else 0,
|
||||
len(working_items) if working_items else 0,
|
||||
)
|
||||
for s in section_summary:
|
||||
logger.debug(
|
||||
" Section [%s] prio=%d tokens=%d chars=%d",
|
||||
s["type"], s["priority"], s["tokens"], s["chars"],
|
||||
)
|
||||
|
||||
return package
|
||||
|
||||
def get_debug_history(self, session_id: str) -> list[dict[str, Any]]:
|
||||
"""Return the context build history for a session."""
|
||||
return list(self._history.get(session_id, []))
|
||||
|
||||
def get_last_context_debug(self, session_id: str) -> dict[str, Any] | None:
|
||||
"""Return the most recent context build for a session."""
|
||||
history = self._history.get(session_id, [])
|
||||
return history[-1] if history else None
|
||||
|
||||
def rehydrate_artifact(
|
||||
self,
|
||||
artifact: ArtifactSummary,
|
||||
full_content: str,
|
||||
) -> ContextSection:
|
||||
"""Selectively rehydrate an artifact into a full context section.
|
||||
|
||||
Used when the agent explicitly needs the full content of a
|
||||
specific artifact (e.g. reviewing generated code).
|
||||
"""
|
||||
content = (
|
||||
f"## Rehydrated Artifact: {artifact.title}\n"
|
||||
f"Type: {artifact.artifact_type} | Source: {artifact.source_tool}\n"
|
||||
f"---\n{full_content}\n---"
|
||||
)
|
||||
return ContextSection(
|
||||
section_type=ContextSectionType.WORKING_CONTEXT,
|
||||
content=content,
|
||||
priority=5,
|
||||
token_estimate=estimate_tokens(content),
|
||||
)
|
||||
|
||||
def summarize_tool_output(
|
||||
self,
|
||||
tool_name: str,
|
||||
raw_output: str,
|
||||
session_id: str,
|
||||
task_id: str,
|
||||
) -> ArtifactSummary:
|
||||
"""Delegate to compactor — raw output never enters context."""
|
||||
return self.compactor.summarize_tool_output(
|
||||
tool_name=tool_name,
|
||||
raw_output=raw_output,
|
||||
session_id=session_id,
|
||||
task_id=task_id,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Section builders
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _build_immutable_rules(
|
||||
self, session: SessionState, agent: AgentProfile
|
||||
) -> ContextSection:
|
||||
parts = [
|
||||
"# System Rules (Immutable)",
|
||||
"",
|
||||
agent.system_prompt,
|
||||
"",
|
||||
]
|
||||
if session.immutable_rules:
|
||||
parts.append("## Session Rules")
|
||||
for rule in session.immutable_rules:
|
||||
parts.append(f"- {rule}")
|
||||
parts.extend(
|
||||
[
|
||||
"",
|
||||
"## Contrato de Contexto",
|
||||
"- NUNCA recibirás salidas crudas de herramientas en tu contexto.",
|
||||
"- Los resultados de herramientas se resumen como artefactos.",
|
||||
"- Solicita rehidratación si necesitas el contenido completo.",
|
||||
"- Mantén las respuestas enfocadas en el paso actual.",
|
||||
"- Responde SIEMPRE en español.",
|
||||
]
|
||||
)
|
||||
content = "\n".join(parts)
|
||||
return ContextSection(
|
||||
section_type=ContextSectionType.IMMUTABLE_RULES,
|
||||
content=content,
|
||||
priority=100,
|
||||
token_estimate=estimate_tokens(content),
|
||||
)
|
||||
|
||||
def _build_project_profile(self, session: SessionState) -> ContextSection:
|
||||
if not session.project_profile:
|
||||
content = "# Project Profile\nNo project profile configured."
|
||||
else:
|
||||
lines = ["# Project Profile"]
|
||||
for key, value in session.project_profile.items():
|
||||
lines.append(f"- **{key}**: {value}")
|
||||
content = "\n".join(lines)
|
||||
return ContextSection(
|
||||
section_type=ContextSectionType.PROJECT_PROFILE,
|
||||
content=content,
|
||||
priority=80,
|
||||
token_estimate=estimate_tokens(content),
|
||||
)
|
||||
|
||||
async def _build_knowledge_base(
|
||||
self, session: SessionState
|
||||
) -> ContextSection | None:
|
||||
"""Load relevant knowledge documents from the memory store.
|
||||
|
||||
Uses keyword matching against the task objective and step
|
||||
description to select only the most relevant docs.
|
||||
Max budget: ~15k tokens for knowledge.
|
||||
"""
|
||||
if not self.memory:
|
||||
return None
|
||||
|
||||
# Build search terms from current context
|
||||
search_terms = self._extract_search_terms(session)
|
||||
if not search_terms:
|
||||
# No task → load summaries of all docs (lightweight)
|
||||
return await self._build_knowledge_summaries_only()
|
||||
|
||||
all_docs: list[MemoryDocument] = []
|
||||
all_docs.extend(await self.memory.list_documents(
|
||||
namespace="knowledge",
|
||||
memory_type=MemoryType.DOCUMENT,
|
||||
))
|
||||
all_docs.extend(await self.memory.list_documents(
|
||||
namespace=f"knowledge:{session.session_id}",
|
||||
memory_type=MemoryType.DOCUMENT,
|
||||
))
|
||||
|
||||
if not all_docs:
|
||||
return None
|
||||
|
||||
# Score each doc by relevance
|
||||
scored = self._score_docs(all_docs, search_terms)
|
||||
|
||||
# Select top docs within token budget
|
||||
max_kb_tokens = 15_000
|
||||
selected: list[tuple[MemoryDocument, int]] = []
|
||||
token_budget = max_kb_tokens
|
||||
|
||||
for doc, score in scored:
|
||||
if score == 0:
|
||||
continue
|
||||
doc_tokens = estimate_tokens(doc.content)
|
||||
if doc_tokens > token_budget:
|
||||
# Include summary instead of full content
|
||||
summary_tokens = estimate_tokens(doc.summary or doc.title)
|
||||
if summary_tokens < token_budget:
|
||||
selected.append((doc, -1)) # -1 = summary only
|
||||
token_budget -= summary_tokens
|
||||
continue
|
||||
selected.append((doc, score))
|
||||
token_budget -= doc_tokens
|
||||
|
||||
if not selected:
|
||||
return await self._build_knowledge_summaries_only()
|
||||
|
||||
# Build section
|
||||
full_docs = [(d, s) for d, s in selected if s > 0]
|
||||
summary_docs = [(d, s) for d, s in selected if s == -1]
|
||||
|
||||
lines = [
|
||||
"# Knowledge Base",
|
||||
f"_{len(full_docs)} relevant doc(s) loaded, "
|
||||
f"{len(summary_docs)} summarized, "
|
||||
f"{len(all_docs) - len(selected)} filtered out_",
|
||||
"",
|
||||
]
|
||||
|
||||
for doc, _ in full_docs:
|
||||
lines.append(f"## {doc.title}")
|
||||
lines.append(doc.content)
|
||||
lines.append("")
|
||||
|
||||
if summary_docs:
|
||||
lines.append("## Other Available Docs (summaries)")
|
||||
for doc, _ in summary_docs:
|
||||
lines.append(f"- **{doc.title}**: {doc.summary[:200]}")
|
||||
lines.append("")
|
||||
|
||||
content = "\n".join(lines)
|
||||
return ContextSection(
|
||||
section_type=ContextSectionType.KNOWLEDGE_BASE,
|
||||
content=content,
|
||||
priority=60,
|
||||
token_estimate=estimate_tokens(content),
|
||||
)
|
||||
|
||||
async def _build_knowledge_summaries_only(self) -> ContextSection | None:
|
||||
"""Lightweight: only doc titles and summaries (no full content)."""
|
||||
if not self.memory:
|
||||
return None
|
||||
docs = await self.memory.list_documents(
|
||||
namespace="knowledge", memory_type=MemoryType.DOCUMENT
|
||||
)
|
||||
if not docs:
|
||||
return None
|
||||
lines = ["# Knowledge Base (summaries)", ""]
|
||||
for doc in docs:
|
||||
lines.append(f"- **{doc.title}**: {doc.summary[:150]}")
|
||||
content = "\n".join(lines)
|
||||
return ContextSection(
|
||||
section_type=ContextSectionType.KNOWLEDGE_BASE,
|
||||
content=content,
|
||||
priority=60,
|
||||
token_estimate=estimate_tokens(content),
|
||||
)
|
||||
|
||||
def _extract_search_terms(self, session: SessionState) -> set[str]:
|
||||
"""Extract keywords from the current task for doc matching."""
|
||||
terms: set[str] = set()
|
||||
if not session.current_task:
|
||||
return terms
|
||||
|
||||
text = session.current_task.objective.lower()
|
||||
step = session.current_task.current_step()
|
||||
if step:
|
||||
text += " " + step.description.lower()
|
||||
|
||||
# Split into words, filter short/common ones
|
||||
stopwords = {
|
||||
"de", "la", "el", "en", "un", "una", "los", "las", "del", "al",
|
||||
"por", "para", "con", "que", "como", "cómo", "qué", "es", "son",
|
||||
"se", "su", "más", "ya", "si", "no", "este", "esta", "esto",
|
||||
"the", "a", "an", "is", "are", "and", "or", "to", "in", "of",
|
||||
"for", "on", "with", "how", "what", "do", "does", "can",
|
||||
}
|
||||
for word in text.split():
|
||||
word = word.strip(".,;:!?¿¡()[]{}\"'`")
|
||||
if len(word) >= 3 and word not in stopwords:
|
||||
terms.add(word)
|
||||
|
||||
return terms
|
||||
|
||||
@staticmethod
|
||||
def _score_docs(
|
||||
docs: list[MemoryDocument], terms: set[str]
|
||||
) -> list[tuple[MemoryDocument, int]]:
|
||||
"""Score docs by keyword match against title, tags, and content."""
|
||||
scored: list[tuple[MemoryDocument, int]] = []
|
||||
|
||||
for doc in docs:
|
||||
score = 0
|
||||
title_lower = doc.title.lower()
|
||||
tags_lower = " ".join(doc.tags).lower()
|
||||
content_lower = doc.content[:2000].lower()
|
||||
|
||||
for term in terms:
|
||||
# Title match = high weight
|
||||
if term in title_lower:
|
||||
score += 10
|
||||
# Tag match = medium weight
|
||||
if term in tags_lower:
|
||||
score += 5
|
||||
# Content match = low weight
|
||||
if term in content_lower:
|
||||
score += 1
|
||||
|
||||
scored.append((doc, score))
|
||||
|
||||
scored.sort(key=lambda x: x[1], reverse=True)
|
||||
return scored

    def _build_task_state(self, task: TaskState) -> ContextSection:
        lines = [
            "# Current Task",
            f"**Objective**: {task.objective}",
            f"**Status**: {task.status}",
            f"**Step**: {task.current_step_index + 1}/{len(task.plan)}",
        ]

        current = task.current_step()
        if current:
            lines.extend(
                [
                    "",
                    "## Current Step",
                    f"- Description: {current.description}",
                    f"- Agent: {current.agent_role}",
                    f"- Status: {current.status}",
                ]
            )

        if task.facts_extracted:
            lines.append("")
            lines.append("## Established Facts")
            for fact in task.facts_extracted[-10:]:
                lines.append(f"- {fact}")

        if task.constraints:
            lines.append("")
            lines.append("## Constraints")
            for c in task.constraints:
                lines.append(f"- {c}")

        # Show plan overview (compact)
        if task.plan:
            lines.append("")
            lines.append("## Plan Overview")
            for i, step in enumerate(task.plan):
                marker = "→" if i == task.current_step_index else "·"
                status_label = step.status.value
                lines.append(
                    f" {marker} Step {i + 1} [{status_label}]: {step.description}"
                )

        content = "\n".join(lines)
        return ContextSection(
            section_type=ContextSectionType.TASK_STATE,
            content=content,
            priority=70,
            token_estimate=estimate_tokens(content),
        )
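
    # Rendered sketch (abbreviated, hypothetical task):
    #   # Current Task
    #   **Objective**: add redis caching
    #   **Status**: executing
    #   **Step**: 2/3
    #   ## Plan Overview
    #    · Step 1 [completed]: collect context
    #    → Step 2 [executing]: implement the cache layer
    #    · Step 3 [pending]: review changes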

    def _build_artifact_memory(
        self, artifacts: list[ArtifactSummary]
    ) -> ContextSection:
        content = self.compactor.compact_artifact_summaries(
            artifacts, max_chars=settings.artifact_summary_max_chars
        )
        return ContextSection(
            section_type=ContextSectionType.ARTIFACT_MEMORY,
            content=content,
            priority=50,
            token_estimate=estimate_tokens(content),
        )

    def _build_working_context(
        self,
        items: list[dict[str, Any]],
        extra_instructions: str,
    ) -> ContextSection:
        lines = ["# Working Context"]
        if extra_instructions:
            lines.append(f"\n{extra_instructions}")
        for item in items[: settings.working_context_max_items]:
            role = item.get("role", "info")
            content_val = item.get("content", "")
            lines.append(f"[{role}] {content_val}")
        content = "\n".join(lines)
        return ContextSection(
            section_type=ContextSectionType.WORKING_CONTEXT,
            content=content,
            priority=30,
            token_estimate=estimate_tokens(content),
        )

    # ------------------------------------------------------------------
    # Assembly
    # ------------------------------------------------------------------

    def _assemble_system_prompt(self, sections: list[ContextSection]) -> str:
        """Combine sections into a single system prompt string."""
        parts: list[str] = []
        # Order: rules → profile → knowledge → task → artifacts → working
        order = [
            ContextSectionType.IMMUTABLE_RULES,
            ContextSectionType.PROJECT_PROFILE,
            ContextSectionType.KNOWLEDGE_BASE,
            ContextSectionType.TASK_STATE,
            ContextSectionType.ARTIFACT_MEMORY,
            ContextSectionType.WORKING_CONTEXT,
        ]
        section_map: dict[ContextSectionType, ContextSection] = {
            s.section_type: s for s in sections
        }
        for st in order:
            if st in section_map:
                parts.append(section_map[st].content)
        return "\n\n---\n\n".join(parts)

    def _build_messages(self, session: SessionState) -> list[dict[str, Any]]:
        """Build the messages array. We do NOT include chat history.

        The user message is the current task objective (or a sentinel
        if no task is active).
        """
        if session.current_task:
            step = session.current_task.current_step()
            if step:
                user_content = (
                    f"Execute this step: {step.description}\n"
                    f"Overall objective: {session.current_task.objective}"
                )
            else:
                user_content = session.current_task.objective
        else:
            user_content = "Awaiting task assignment."

        return [{"role": "user", "content": user_content}]
137
src/main.py
Normal file
@@ -0,0 +1,137 @@
"""Agentic Microservice — FastAPI application entry point.

Wires together all components: Redis storage, model adapters, MCP client,
context engine, orchestrator, and SSE streaming.
"""

from __future__ import annotations

import logging
import pathlib
from contextlib import asynccontextmanager

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse
from fastapi.staticfiles import StaticFiles

from .adapters.claude_adapter import ClaudeAdapter
from .adapters.openai_adapter import OpenAIAdapter
from .api.routes import router, set_dependencies
from .config import settings
from .context.engine import ContextEngine
from .mcp.client import MCPClient
from .memory.store import MemoryStore
from .orchestrator.engine import OrchestratorEngine
from .storage.redis import RedisStorage
from .streaming.sse import SSEEmitter

logging.basicConfig(
    level=logging.DEBUG if settings.debug else logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
logger = logging.getLogger(__name__)

# Global instances (initialized in lifespan)
redis_storage = RedisStorage()
mcp_client = MCPClient()
sse_emitter = SSEEmitter(redis_storage=redis_storage)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifecycle: startup and shutdown."""
    logger.info("Starting %s v%s", settings.service_name, settings.service_version)

    # 1. Connect Redis
    await redis_storage.connect()

    # Wire SSE emitter to Redis for event persistence (re-set after connect)
    sse_emitter.set_storage(redis_storage)

    # 2. Initialize model adapter (based on configured provider)
    if settings.default_model_provider == "openai":
        model_adapter = OpenAIAdapter()
        logger.info("Using OpenAI adapter (model: %s)", settings.default_model_id)
    else:
        model_adapter = ClaudeAdapter()
        logger.info("Using Claude adapter (model: %s)", settings.default_model_id)

    # 3. Initialize memory store (uses same Redis connection)
    memory_store = MemoryStore(redis_storage.client)

    # 4. Initialize context engine (with memory store for knowledge base)
    context_engine = ContextEngine(memory_store=memory_store)

    # 5. Start MCP client (if configured)
    if settings.mcp_server_command:
        try:
            await mcp_client.start()
            logger.info("MCP client started with %d tools", len(mcp_client.tools))
        except Exception as e:
            logger.warning("MCP client failed to start: %s — continuing without MCP", e)

    # 6. Initialize orchestrator
    orchestrator = OrchestratorEngine(
        model_adapter=model_adapter,
        context_engine=context_engine,
        mcp_client=mcp_client,
        memory_store=memory_store,
        sse_emitter=sse_emitter,
    )

    # 7. Wire dependencies into API routes
    set_dependencies(
        storage=redis_storage,
        orchestrator=orchestrator,
        sse_emitter=sse_emitter,
        context_engine=context_engine,
        memory_store=memory_store,
    )

    logger.info("All systems initialized. Serving on %s:%d", settings.host, settings.port)

    yield

    # Shutdown
    logger.info("Shutting down...")
    await mcp_client.stop()
    await redis_storage.disconnect()
    logger.info("Shutdown complete.")


app = FastAPI(
    title=settings.service_name,
    version=settings.service_version,
    lifespan=lifespan,
)

# CORS — wildcard origins with credentials is permissive; restrict in production
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Mount API routes
app.include_router(router, prefix="/api/v1")


# Health check
@app.get("/health")
async def health() -> dict[str, str]:
    return {"status": "ok", "service": settings.service_name}


# Root redirect
@app.get("/")
async def root():
    return RedirectResponse(url="/dashboard/")


# Dashboard static files (mounted AFTER API routes)
_dashboard_dir = pathlib.Path(__file__).resolve().parent.parent / "dashboard"
if _dashboard_dir.is_dir():
    app.mount("/dashboard", StaticFiles(directory=str(_dashboard_dir), html=True), name="dashboard")
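
# Local run sketch (assumed invocation — the actual host/port come from settings):
#
#   uvicorn src.main:app --host 0.0.0.0 --port 8000
#
# The API is then served under /api/v1 and the dashboard (if present) at /dashboard/.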
3
src/mcp/__init__.py
Normal file
@@ -0,0 +1,3 @@
from .client import MCPClient

__all__ = ["MCPClient"]
BIN
src/mcp/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
src/mcp/__pycache__/client.cpython-312.pyc
Normal file
Binary file not shown.
291
src/mcp/client.py
Normal file
@@ -0,0 +1,291 @@
"""MCP (Model Context Protocol) client — stdio transport.
|
||||
|
||||
Manages subprocess lifecycle, JSON-RPC request/response, timeouts,
|
||||
and a tool registry populated from the server's capabilities.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from ..config import settings
|
||||
from ..models.tools import ToolDefinition
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MCPClientError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class MCPClient:
|
||||
"""Stdio-based MCP client with full lifecycle management."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
command: str | None = None,
|
||||
args: list[str] | None = None,
|
||||
timeout: float | None = None,
|
||||
startup_timeout: float | None = None,
|
||||
env: dict[str, str] | None = None,
|
||||
) -> None:
|
||||
self._command = command or settings.mcp_server_command
|
||||
self._args = args if args is not None else list(settings.mcp_server_args)
|
||||
self._timeout = timeout or settings.mcp_timeout_seconds
|
||||
self._startup_timeout = startup_timeout or settings.mcp_startup_timeout_seconds
|
||||
# Inherit current env + any overrides (passes ACAI_* vars to MCP server)
|
||||
self._env = {**os.environ, **(env or {})}
|
||||
self._process: asyncio.subprocess.Process | None = None
|
||||
self._tools: dict[str, ToolDefinition] = {}
|
||||
self._pending: dict[str, asyncio.Future[dict[str, Any]]] = {}
|
||||
self._reader_task: asyncio.Task[None] | None = None
|
||||
self._running = False
|
||||
|
||||
@property
|
||||
def tools(self) -> dict[str, ToolDefinition]:
|
||||
return dict(self._tools)
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
return self._running and self._process is not None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the MCP server subprocess and discover tools."""
|
||||
if not self._command:
|
||||
logger.warning("No MCP server command configured — skipping start")
|
||||
return
|
||||
|
||||
logger.info("Starting MCP server: %s %s", self._command, self._args)
|
||||
self._process = await asyncio.create_subprocess_exec(
|
||||
self._command,
|
||||
*self._args,
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
env=self._env,
|
||||
)
|
||||
self._running = True
|
||||
self._reader_task = asyncio.create_task(self._read_loop())
|
||||
|
||||
# Initialize
|
||||
try:
|
||||
init_result = await asyncio.wait_for(
|
||||
self._send_request("initialize", {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {},
|
||||
"clientInfo": {"name": "agentic-microservice", "version": "1.0.0"},
|
||||
}),
|
||||
timeout=self._startup_timeout,
|
||||
)
|
||||
logger.info("MCP initialized: %s", init_result)
|
||||
|
||||
# Send initialized notification
|
||||
await self._send_notification("notifications/initialized", {})
|
||||
|
||||
# Discover tools
|
||||
tools_result = await asyncio.wait_for(
|
||||
self._send_request("tools/list", {}),
|
||||
timeout=self._startup_timeout,
|
||||
)
|
||||
self._register_tools(tools_result)
|
||||
logger.info("Discovered %d MCP tools", len(self._tools))
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("MCP server startup timed out")
|
||||
await self.stop()
|
||||
raise MCPClientError("MCP server startup timed out")
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Gracefully stop the MCP server."""
|
||||
self._running = False
|
||||
if self._reader_task:
|
||||
self._reader_task.cancel()
|
||||
try:
|
||||
await self._reader_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
if self._process:
|
||||
try:
|
||||
if self._process.stdin:
|
||||
self._process.stdin.close()
|
||||
self._process.terminate()
|
||||
await asyncio.wait_for(self._process.wait(), timeout=5.0)
|
||||
except (asyncio.TimeoutError, ProcessLookupError):
|
||||
self._process.kill()
|
||||
self._process = None
|
||||
|
||||
# Cancel any pending requests
|
||||
for fut in self._pending.values():
|
||||
if not fut.done():
|
||||
fut.cancel()
|
||||
self._pending.clear()
|
||||
self._tools.clear()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Tool execution
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def call_tool(
|
||||
self, tool_name: str, arguments: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""Call a tool on the MCP server with timeout."""
|
||||
if not self.is_running:
|
||||
raise MCPClientError("MCP client is not running")
|
||||
|
||||
if tool_name not in self._tools:
|
||||
raise MCPClientError(f"Unknown tool: {tool_name}")
|
||||
|
||||
try:
|
||||
result = await asyncio.wait_for(
|
||||
self._send_request("tools/call", {
|
||||
"name": tool_name,
|
||||
"arguments": arguments,
|
||||
}),
|
||||
timeout=self._timeout,
|
||||
)
|
||||
return result
|
||||
except asyncio.TimeoutError:
|
||||
raise MCPClientError(
|
||||
f"Tool '{tool_name}' timed out after {self._timeout}s"
|
||||
)
|
||||
|
||||
def get_tool_definitions(self) -> list[dict[str, Any]]:
|
||||
"""Return tool definitions in a format suitable for model adapters."""
|
||||
definitions: list[dict[str, Any]] = []
|
||||
for tool in self._tools.values():
|
||||
definitions.append({
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"input_schema": tool.input_schema,
|
||||
})
|
||||
return definitions
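
    # Usage sketch (hypothetical server command and tool name — real values
    # come from your configuration and the server's tools/list response):
    #
    #   client = MCPClient(command="python", args=["-m", "my_mcp_server"])
    #   await client.start()
    #   result = await client.call_tool("read_file", {"path": "README.md"})
    #   await client.stop()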

    # ------------------------------------------------------------------
    # JSON-RPC transport
    # ------------------------------------------------------------------

    async def _send_request(
        self, method: str, params: dict[str, Any]
    ) -> dict[str, Any]:
        """Send a JSON-RPC request and await the response."""
        request_id = uuid.uuid4().hex[:12]
        message = {
            "jsonrpc": "2.0",
            "id": request_id,
            "method": method,
            "params": params,
        }

        loop = asyncio.get_running_loop()
        future: asyncio.Future[dict[str, Any]] = loop.create_future()
        self._pending[request_id] = future

        await self._write_message(message)

        try:
            return await future
        finally:
            self._pending.pop(request_id, None)

    async def _send_notification(
        self, method: str, params: dict[str, Any]
    ) -> None:
        """Send a JSON-RPC notification (no response expected)."""
        message = {
            "jsonrpc": "2.0",
            "method": method,
            "params": params,
        }
        await self._write_message(message)

    async def _write_message(self, message: dict[str, Any]) -> None:
        """Write a JSON-RPC message to the server's stdin."""
        if not self._process or not self._process.stdin:
            raise MCPClientError("MCP process stdin not available")

        data = json.dumps(message) + "\n"
        self._process.stdin.write(data.encode())
        await self._process.stdin.drain()
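
    # Wire-format sketch: each message is one line of JSON terminated by "\n".
    # A tools/call request as produced by _send_request looks like (tool name
    # and arguments are hypothetical; the id is a random 12-char hex string):
    #
    #   {"jsonrpc": "2.0", "id": "a1b2c3d4e5f6", "method": "tools/call",
    #    "params": {"name": "read_file", "arguments": {"path": "src/main.py"}}}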

    async def _read_loop(self) -> None:
        """Continuously read JSON-RPC responses from stdout."""
        if not self._process or not self._process.stdout:
            return

        try:
            while self._running:
                line = await self._process.stdout.readline()
                if not line:
                    logger.warning("MCP server stdout closed")
                    break

                line_str = line.decode().strip()
                if not line_str:
                    continue

                try:
                    message = json.loads(line_str)
                except json.JSONDecodeError:
                    logger.debug("Non-JSON MCP output: %s", line_str[:200])
                    continue

                self._handle_message(message)

        except asyncio.CancelledError:
            pass
        except Exception:
            logger.exception("MCP read loop error")
        finally:
            self._running = False

    def _handle_message(self, message: dict[str, Any]) -> None:
        """Route an incoming JSON-RPC message."""
        msg_id = message.get("id")

        if msg_id and msg_id in self._pending:
            future = self._pending[msg_id]
            if future.done():
                return

            if "error" in message:
                future.set_exception(
                    MCPClientError(
                        f"MCP error {message['error'].get('code')}: "
                        f"{message['error'].get('message')}"
                    )
                )
            else:
                future.set_result(message.get("result", {}))
        elif "method" in message:
            # Server-initiated notification — log it
            logger.debug(
                "MCP notification: %s", message.get("method")
            )

    # ------------------------------------------------------------------
    # Tool registry
    # ------------------------------------------------------------------

    def _register_tools(self, tools_result: dict[str, Any]) -> None:
        """Parse tools/list response and populate the registry."""
        raw_tools = tools_result.get("tools", [])
        for t in raw_tools:
            name = t.get("name", "")
            if not name:
                continue
            self._tools[name] = ToolDefinition(
                name=name,
                description=t.get("description", ""),
                input_schema=t.get("inputSchema", {}),
                server_name="mcp",
            )
3
src/memory/__init__.py
Normal file
@@ -0,0 +1,3 @@
from .store import MemoryStore

__all__ = ["MemoryStore"]
BIN
src/memory/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
src/memory/__pycache__/store.cpython-312.pyc
Normal file
Binary file not shown.
190
src/memory/store.py
Normal file
@@ -0,0 +1,190 @@
"""Persistent memory store backed by Redis.
|
||||
|
||||
Supports three memory tiers:
|
||||
1. Rules/documents — persistent, always loaded
|
||||
2. Artifact summaries — per-session, loaded on demand
|
||||
3. Optional embeddings — for semantic search
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import redis.asyncio as redis
|
||||
|
||||
from ..config import settings
|
||||
from ..models.artifacts import ArtifactSummary
|
||||
from ..models.context import MemoryDocument, MemoryType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MemoryStore:
|
||||
"""Async memory store with Redis backend."""
|
||||
|
||||
def __init__(self, redis_client: redis.Redis) -> None:
|
||||
self._r = redis_client
|
||||
self._prefix = settings.redis_key_prefix
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Key helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _key(self, *parts: str) -> str:
|
||||
return ":".join([self._prefix, *parts])

    # ------------------------------------------------------------------
    # Rules & documents (persistent memory)
    # ------------------------------------------------------------------

    async def store_document(self, doc: MemoryDocument) -> None:
        key = self._key("memory", doc.namespace, doc.memory_id)
        await self._r.set(key, doc.model_dump_json())
        # Index by namespace
        await self._r.sadd(self._key("memory", doc.namespace, "_index"), doc.memory_id)
        # Index by type
        await self._r.sadd(
            self._key("memory", "_type", doc.memory_type.value), doc.memory_id
        )
        # Index by tags
        for tag in doc.tags:
            await self._r.sadd(self._key("memory", "_tag", tag), doc.memory_id)

    async def get_document(
        self, memory_id: str, namespace: str = "global"
    ) -> MemoryDocument | None:
        key = self._key("memory", namespace, memory_id)
        data = await self._r.get(key)
        if data:
            return MemoryDocument.model_validate_json(data)
        return None

    async def list_documents(
        self,
        namespace: str = "global",
        memory_type: MemoryType | None = None,
        tags: list[str] | None = None,
    ) -> list[MemoryDocument]:
        """List documents with optional type/tag filters."""
        # Start with namespace index
        ids = await self._r.smembers(self._key("memory", namespace, "_index"))
        id_set = {mid.decode() if isinstance(mid, bytes) else mid for mid in ids}

        # Intersect with type filter
        if memory_type:
            type_ids = await self._r.smembers(
                self._key("memory", "_type", memory_type.value)
            )
            type_set = {mid.decode() if isinstance(mid, bytes) else mid for mid in type_ids}
            id_set &= type_set

        # Intersect with tag filter
        if tags:
            for tag in tags:
                tag_ids = await self._r.smembers(self._key("memory", "_tag", tag))
                tag_set = {
                    mid.decode() if isinstance(mid, bytes) else mid for mid in tag_ids
                }
                id_set &= tag_set

        docs: list[MemoryDocument] = []
        for mid in id_set:
            doc = await self.get_document(mid, namespace)
            if doc:
                docs.append(doc)
        return docs

    async def delete_document(
        self, memory_id: str, namespace: str = "global"
    ) -> bool:
        doc = await self.get_document(memory_id, namespace)
        key = self._key("memory", namespace, memory_id)
        deleted = await self._r.delete(key)
        await self._r.srem(self._key("memory", namespace, "_index"), memory_id)
        # Also remove the id from the type/tag indexes so they don't go stale
        if doc:
            await self._r.srem(self._key("memory", "_type", doc.memory_type.value), memory_id)
            for tag in doc.tags:
                await self._r.srem(self._key("memory", "_tag", tag), memory_id)
        return bool(deleted)

    # ------------------------------------------------------------------
    # Artifact summaries (per-session)
    # ------------------------------------------------------------------

    async def store_artifact(
        self, session_id: str, artifact: ArtifactSummary
    ) -> None:
        key = self._key("session", session_id, "artifacts")
        await self._r.hset(key, artifact.artifact_id, artifact.model_dump_json())
        await self._r.expire(key, settings.session_ttl_seconds)

    async def get_artifact(
        self, session_id: str, artifact_id: str
    ) -> ArtifactSummary | None:
        key = self._key("session", session_id, "artifacts")
        data = await self._r.hget(key, artifact_id)
        if data:
            return ArtifactSummary.model_validate_json(data)
        return None

    async def list_artifacts(self, session_id: str) -> list[ArtifactSummary]:
        key = self._key("session", session_id, "artifacts")
        all_data = await self._r.hgetall(key)
        return [
            ArtifactSummary.model_validate_json(v) for v in all_data.values()
        ]

    # ------------------------------------------------------------------
    # Optional embeddings
    # ------------------------------------------------------------------

    async def store_embedding(
        self, memory_id: str, embedding: list[float], namespace: str = "global"
    ) -> None:
        """Store an embedding vector for a memory document."""
        key = self._key("embeddings", namespace, memory_id)
        await self._r.set(key, json.dumps(embedding))

    async def get_embedding(
        self, memory_id: str, namespace: str = "global"
    ) -> list[float] | None:
        key = self._key("embeddings", namespace, memory_id)
        data = await self._r.get(key)
        if data:
            return json.loads(data)
        return None

    async def search_by_similarity(
        self,
        query_embedding: list[float],
        namespace: str = "global",
        top_k: int = 5,
    ) -> list[tuple[str, float]]:
        """Brute-force cosine similarity search over stored embeddings.

        For production, swap this with Redis Vector Search (RediSearch)
        or a dedicated vector DB.
        """
        pattern = self._key("embeddings", namespace, "*")
        results: list[tuple[str, float]] = []

        async for key in self._r.scan_iter(match=pattern, count=100):
            key_str = key.decode() if isinstance(key, bytes) else key
            memory_id = key_str.rsplit(":", 1)[-1]
            data = await self._r.get(key)
            if not data:
                continue
            stored = json.loads(data)
            score = self._cosine_similarity(query_embedding, stored)
            results.append((memory_id, score))

        results.sort(key=lambda x: x[1], reverse=True)
        return results[:top_k]

    @staticmethod
    def _cosine_similarity(a: list[float], b: list[float]) -> float:
        if len(a) != len(b) or not a:
            return 0.0
        dot = sum(x * y for x, y in zip(a, b))
        mag_a = sum(x * x for x in a) ** 0.5
        mag_b = sum(x * x for x in b) ** 0.5
        if mag_a == 0 or mag_b == 0:
            return 0.0
        return dot / (mag_a * mag_b)
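
    # Usage sketch (hypothetical ids and 3-dim vectors for brevity):
    #
    #   await store.store_embedding("doc-a", [0.1, 0.9, 0.0], namespace="knowledge")
    #   hits = await store.search_by_similarity([0.1, 0.8, 0.1], namespace="knowledge")
    #   # -> [("doc-a", ≈0.99), ...]  — (id, cosine score) pairs, best first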
19
src/models/__init__.py
Normal file
@@ -0,0 +1,19 @@
from .session import SessionState, TaskState
from .context import ContextPackage, MemoryDocument, ContextSection
from .agent import AgentProfile, SubAgentDefinition, AgentRole
from .artifacts import ArtifactSummary
from .tools import ToolExecution, ToolDefinition

__all__ = [
    "SessionState",
    "TaskState",
    "ContextPackage",
    "MemoryDocument",
    "ContextSection",
    "AgentProfile",
    "SubAgentDefinition",
    "AgentRole",
    "ArtifactSummary",
    "ToolExecution",
    "ToolDefinition",
]
BIN
src/models/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
src/models/__pycache__/agent.cpython-312.pyc
Normal file
Binary file not shown.
BIN
src/models/__pycache__/artifacts.cpython-312.pyc
Normal file
Binary file not shown.
BIN
src/models/__pycache__/context.cpython-312.pyc
Normal file
Binary file not shown.
BIN
src/models/__pycache__/session.cpython-312.pyc
Normal file
Binary file not shown.
BIN
src/models/__pycache__/tools.cpython-312.pyc
Normal file
Binary file not shown.
49
src/models/agent.py
Normal file
@@ -0,0 +1,49 @@
"""Agent profile and subagent definition models."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class AgentRole(StrEnum):
|
||||
ORCHESTRATOR = "orchestrator"
|
||||
PLANNER = "planner"
|
||||
CODER = "coder"
|
||||
COLLECTOR = "collector"
|
||||
REVIEWER = "reviewer"
|
||||
|
||||
|
||||
class AgentProfile(BaseModel):
|
||||
"""Describes the identity and capabilities of an agent."""
|
||||
|
||||
role: AgentRole
|
||||
name: str
|
||||
system_prompt: str
|
||||
allowed_tools: list[str] = Field(default_factory=list)
|
||||
model_id: str | None = None
|
||||
temperature: float | None = None
|
||||
max_tokens: int | None = None
|
||||
context_sections: list[str] = Field(
|
||||
default_factory=lambda: [
|
||||
"immutable_rules",
|
||||
"project_profile",
|
||||
"knowledge_base",
|
||||
"task_state",
|
||||
"working_context",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class SubAgentDefinition(BaseModel):
|
||||
"""A runnable subagent configuration within the orchestrator."""
|
||||
|
||||
agent_id: str
|
||||
profile: AgentProfile
|
||||
input_schema: dict[str, Any] = Field(default_factory=dict)
|
||||
output_schema: dict[str, Any] = Field(default_factory=dict)
|
||||
max_steps: int = 10
|
||||
requires_approval: bool = False
|
||||
description: str = ""
|
||||
28
src/models/artifacts.py
Normal file
@@ -0,0 +1,28 @@
"""Artifact summary model."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ArtifactSummary(BaseModel):
|
||||
"""A summarised artifact produced during execution.
|
||||
|
||||
Raw content is NEVER sent to the model. Only the summary is included
|
||||
in the context package.
|
||||
"""
|
||||
|
||||
artifact_id: str
|
||||
session_id: str
|
||||
task_id: str
|
||||
artifact_type: str # e.g. "code", "analysis", "plan", "test_result"
|
||||
title: str
|
||||
summary: str # Compact human-readable summary
|
||||
facts: list[str] = Field(default_factory=list) # Extracted factual claims
|
||||
source_tool: str = "" # Which tool produced this
|
||||
char_count: int = 0 # Size of the original content
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
67
src/models/context.py
Normal file
@@ -0,0 +1,67 @@
"""Context package and memory document models."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ContextSectionType(StrEnum):
|
||||
IMMUTABLE_RULES = "immutable_rules"
|
||||
PROJECT_PROFILE = "project_profile"
|
||||
KNOWLEDGE_BASE = "knowledge_base"
|
||||
TASK_STATE = "task_state"
|
||||
ARTIFACT_MEMORY = "artifact_memory"
|
||||
WORKING_CONTEXT = "working_context"
|
||||
|
||||
|
||||
class ContextSection(BaseModel):
|
||||
"""A discrete section of the assembled context."""
|
||||
|
||||
section_type: ContextSectionType
|
||||
content: str
|
||||
priority: int = 0 # Higher = more important, kept during compaction
|
||||
token_estimate: int = 0
|
||||
|
||||
|
||||
class ContextPackage(BaseModel):
|
||||
"""The fully assembled context sent to the model. Never includes raw tool output."""
|
||||
|
||||
sections: list[ContextSection] = Field(default_factory=list)
|
||||
system_prompt: str = ""
|
||||
messages: list[dict[str, Any]] = Field(default_factory=list)
|
||||
total_token_estimate: int = 0
|
||||
|
||||
def to_messages(self) -> list[dict[str, Any]]:
|
||||
"""Produce the final messages list for the model adapter."""
|
||||
result: list[dict[str, Any]] = []
|
||||
if self.system_prompt:
|
||||
result.append({"role": "system", "content": self.system_prompt})
|
||||
result.extend(self.messages)
|
||||
return result
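
    # Shape sketch of the result (content abbreviated):
    #   [{"role": "system", "content": "# Immutable Rules\n..."},
    #    {"role": "user", "content": "Execute this step: ..."}]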

class MemoryType(StrEnum):
    RULE = "rule"
    DOCUMENT = "document"
    ARTIFACT = "artifact"
    EMBEDDING = "embedding"


class MemoryDocument(BaseModel):
    """A single piece of persistent memory."""

    memory_id: str
    memory_type: MemoryType
    namespace: str = "global"
    title: str
    content: str
    summary: str = ""
    tags: list[str] = Field(default_factory=list)
    embedding: list[float] | None = None
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    metadata: dict[str, Any] = Field(default_factory=dict)
107
src/models/session.py
Normal file
@@ -0,0 +1,107 @@
"""Session and task state models."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SessionStatus(StrEnum):
|
||||
ACTIVE = "active"
|
||||
IDLE = "idle"
|
||||
EXECUTING = "executing"
|
||||
COMPLETED = "completed"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
class TaskStatus(StrEnum):
|
||||
PENDING = "pending"
|
||||
PLANNING = "planning"
|
||||
EXECUTING = "executing"
|
||||
REVIEWING = "reviewing"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class TaskStep(BaseModel):
|
||||
"""A single planned step inside a task."""
|
||||
|
||||
step_id: str = Field(default_factory=lambda: uuid.uuid4().hex[:8])
|
||||
description: str
|
||||
agent_role: str = "coder"
|
||||
status: TaskStatus = TaskStatus.PENDING
|
||||
result_summary: str = ""
|
||||
tools_used: list[str] = Field(default_factory=list)
|
||||
started_at: datetime | None = None
|
||||
completed_at: datetime | None = None
|
||||
|
||||
|
||||
class TaskState(BaseModel):
|
||||
"""Represents the current task being worked on within a session."""
|
||||
|
||||
task_id: str = Field(default_factory=lambda: uuid.uuid4().hex[:12])
|
||||
objective: str
|
||||
status: TaskStatus = TaskStatus.PENDING
|
||||
plan: list[TaskStep] = Field(default_factory=list)
|
||||
current_step_index: int = 0
|
||||
facts_extracted: list[str] = Field(default_factory=list)
|
||||
constraints: list[str] = Field(default_factory=list)
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
def current_step(self) -> TaskStep | None:
|
||||
if 0 <= self.current_step_index < len(self.plan):
|
||||
return self.plan[self.current_step_index]
|
||||
return None
|
||||
|
||||
def advance(self) -> bool:
|
||||
step = self.current_step()
|
||||
if step:
|
||||
step.status = TaskStatus.COMPLETED
|
||||
step.completed_at = datetime.now(timezone.utc)
|
||||
self.current_step_index += 1
|
||||
self.updated_at = datetime.now(timezone.utc)
|
||||
return self.current_step_index < len(self.plan)
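
    # Note: advance() marks the current step completed, moves the pointer, and
    # returns True while steps remain — a 3-step plan yields True, True, False.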

    def mark_failed(self, reason: str) -> None:
        self.status = TaskStatus.FAILED
        step = self.current_step()
        if step:
            step.status = TaskStatus.FAILED
            step.result_summary = reason
        self.updated_at = datetime.now(timezone.utc)


class SessionState(BaseModel):
    """Top-level session state persisted in Redis."""

    session_id: str = Field(default_factory=lambda: uuid.uuid4().hex)
    status: SessionStatus = SessionStatus.IDLE
    project_profile: dict[str, Any] = Field(default_factory=dict)
    immutable_rules: list[str] = Field(default_factory=list)
    current_task: TaskState | None = None
    completed_tasks: list[str] = Field(default_factory=list)
    turn_count: int = 0
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    metadata: dict[str, Any] = Field(default_factory=dict)

    def begin_task(self, objective: str) -> TaskState:
        task = TaskState(objective=objective)
        self.current_task = task
        self.status = SessionStatus.EXECUTING
        self.turn_count += 1
        self.updated_at = datetime.now(timezone.utc)
        return task

    def complete_task(self) -> None:
        if self.current_task:
            self.current_task.status = TaskStatus.COMPLETED
            self.completed_tasks.append(self.current_task.task_id)
            self.current_task = None
        self.status = SessionStatus.IDLE
        self.updated_at = datetime.now(timezone.utc)
41
src/models/tools.py
Normal file
@@ -0,0 +1,41 @@
"""Tool execution and definition models."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ToolExecutionStatus(StrEnum):
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
TIMEOUT = "timeout"
|
||||
|
||||
|
||||
class ToolDefinition(BaseModel):
|
||||
"""Schema describing an available tool."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
input_schema: dict[str, Any] = Field(default_factory=dict)
|
||||
server_name: str = "" # MCP server that provides this tool
|
||||
|
||||
|
||||
class ToolExecution(BaseModel):
|
||||
"""Record of a single tool invocation."""
|
||||
|
||||
execution_id: str
|
||||
tool_name: str
|
||||
arguments: dict[str, Any] = Field(default_factory=dict)
|
||||
status: ToolExecutionStatus = ToolExecutionStatus.PENDING
|
||||
result_summary: str = "" # Summarised result — raw output is NEVER stored here
|
||||
error: str = ""
|
||||
duration_ms: float = 0.0
|
||||
started_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
completed_at: datetime | None = None
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
3
src/orchestrator/__init__.py
Normal file
@@ -0,0 +1,3 @@
from .engine import OrchestratorEngine

__all__ = ["OrchestratorEngine"]
BIN
src/orchestrator/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
src/orchestrator/__pycache__/engine.cpython-312.pyc
Normal file
Binary file not shown.
BIN
src/orchestrator/__pycache__/router.cpython-312.pyc
Normal file
Binary file not shown.
6
src/orchestrator/agents/__init__.py
Normal file
@@ -0,0 +1,6 @@
from .planner import PlannerAgent
from .coder import CoderAgent
from .collector import CollectorAgent
from .reviewer import ReviewerAgent

__all__ = ["PlannerAgent", "CoderAgent", "CollectorAgent", "ReviewerAgent"]
BIN
src/orchestrator/agents/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
src/orchestrator/agents/__pycache__/base.cpython-312.pyc
Normal file
Binary file not shown.
BIN
src/orchestrator/agents/__pycache__/coder.cpython-312.pyc
Normal file
Binary file not shown.
BIN
src/orchestrator/agents/__pycache__/collector.cpython-312.pyc
Normal file
Binary file not shown.
BIN
src/orchestrator/agents/__pycache__/planner.cpython-312.pyc
Normal file
Binary file not shown.
BIN
src/orchestrator/agents/__pycache__/reviewer.cpython-312.pyc
Normal file
Binary file not shown.
241
src/orchestrator/agents/base.py
Normal file
@@ -0,0 +1,241 @@
"""Base subagent class with shared execution logic."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, AsyncIterator
|
||||
|
||||
from ...adapters.base import ModelAdapter, ModelConfig, StreamChunk
|
||||
from ...context.engine import ContextEngine
|
||||
from ...mcp.client import MCPClient
|
||||
from ...memory.store import MemoryStore
|
||||
from ...models.agent import AgentProfile
|
||||
from ...models.artifacts import ArtifactSummary
|
||||
from ...models.session import SessionState
|
||||
from ...models.tools import ToolExecution, ToolExecutionStatus
|
||||
from ...streaming.sse import SSEEmitter, EventType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseAgent:
|
||||
"""Base class for all subagents."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
profile: AgentProfile,
|
||||
model_adapter: ModelAdapter,
|
||||
context_engine: ContextEngine,
|
||||
mcp_client: MCPClient,
|
||||
memory_store: MemoryStore,
|
||||
sse_emitter: SSEEmitter,
|
||||
) -> None:
|
||||
self.profile = profile
|
||||
self.model = model_adapter
|
||||
self.context = context_engine
|
||||
self.mcp = mcp_client
|
||||
self.memory = memory_store
|
||||
self.sse = sse_emitter
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
session: SessionState,
|
||||
max_steps: int = 10,
|
||||
) -> dict[str, Any]:
|
||||
"""Run the agent's execution loop.
|
||||
|
||||
Returns a result dict with keys: content, artifacts, tool_executions.
|
||||
"""
|
||||
artifacts: list[ArtifactSummary] = await self.memory.list_artifacts(
|
||||
session.session_id
|
||||
)
|
||||
tool_executions: list[ToolExecution] = []
|
||||
accumulated_content = ""
|
||||
working_items: list[dict[str, Any]] = []
|
||||
|
||||
for step in range(max_steps):
|
||||
# Build context — NEVER includes raw tool output
|
||||
ctx = await self.context.build_context(
|
||||
session=session,
|
||||
agent=self.profile,
|
||||
artifacts=artifacts,
|
||||
working_items=working_items,
|
||||
)
|
||||
|
||||
# Prepare tool definitions
|
||||
tool_defs = self._get_allowed_tools()
|
||||
|
||||
# Stream model response
|
||||
config = ModelConfig(
|
||||
model_id=self.profile.model_id or "",
|
||||
max_tokens=self.profile.max_tokens or 4096,
|
||||
temperature=self.profile.temperature or 0.3,
|
||||
)
|
||||
|
||||
full_text = ""
|
||||
tool_calls: list[dict[str, Any]] = []
|
||||
current_tool: dict[str, Any] = {}
|
||||
|
||||
async for chunk in self.model.stream(
|
||||
messages=ctx.to_messages(),
|
||||
tools=tool_defs if tool_defs else None,
|
||||
config=config,
|
||||
):
|
||||
if chunk.delta:
|
||||
full_text += chunk.delta
|
||||
await self.sse.emit(
|
||||
EventType.AGENT_DELTA,
|
||||
{
|
||||
"agent": self.profile.role,
|
||||
"delta": chunk.delta,
|
||||
"step": step,
|
||||
},
|
||||
session_id=session.session_id,
|
||||
)
|
||||
|
||||
if chunk.tool_name and not current_tool.get("name"):
|
||||
current_tool = {
|
||||
"id": chunk.tool_call_id,
|
||||
"name": chunk.tool_name,
|
||||
"arguments": "",
|
||||
}
|
||||
await self.sse.emit(
|
||||
EventType.TOOL_STARTED,
|
||||
{"tool": chunk.tool_name, "step": step},
|
||||
session_id=session.session_id,
|
||||
)
|
||||
|
||||
if chunk.tool_arguments and current_tool:
|
||||
current_tool["arguments"] += chunk.tool_arguments
|
||||
|
||||
if chunk.finish_reason == "tool_use" and current_tool.get("name"):
|
||||
# Parse arguments
|
||||
try:
|
||||
args = json.loads(current_tool["arguments"]) if current_tool["arguments"] else {}
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
current_tool["parsed_arguments"] = args
|
||||
tool_calls.append(current_tool)
|
||||
current_tool = {}
|
||||
|
||||
if chunk.finish_reason == "end_turn":
|
||||
break
|
||||
|
||||
accumulated_content += full_text
|
||||
|
||||
# If no tool calls, we're done
|
||||
if not tool_calls:
|
||||
break
|
||||
|
||||
# Execute tool calls
|
||||
for tc in tool_calls:
|
||||
tool_exec = await self._execute_tool(
|
||||
session=session,
|
||||
tool_name=tc["name"],
|
||||
arguments=tc.get("parsed_arguments", {}),
|
||||
artifacts=artifacts,
|
||||
)
|
||||
tool_executions.append(tool_exec)
|
||||
|
||||
# Add summarised result to working context (NEVER raw)
|
||||
working_items.append({
|
||||
"role": "tool_result",
|
||||
"content": f"[{tc['name']}] {tool_exec.result_summary}",
|
||||
})
|
||||
|
||||
return {
|
||||
"content": accumulated_content,
|
||||
"artifacts": artifacts,
|
||||
"tool_executions": tool_executions,
|
||||
}
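
    # Loop sketch: each iteration rebuilds the context from summaries only,
    # streams one model turn, runs any requested tools, and feeds their
    # *summaries* back in via working_items; the loop ends after a turn
    # with no tool calls (or when max_steps is exhausted).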

    async def _execute_tool(
        self,
        session: SessionState,
        tool_name: str,
        arguments: dict[str, Any],
        artifacts: list[ArtifactSummary],
    ) -> ToolExecution:
        """Execute a tool and summarise the result."""
        exec_id = uuid.uuid4().hex[:12]
        tool_exec = ToolExecution(
            execution_id=exec_id,
            tool_name=tool_name,
            arguments=arguments,
            status=ToolExecutionStatus.RUNNING,
        )

        start = time.monotonic()
        try:
            if self.mcp.is_running and tool_name in self.mcp.tools:
                result = await self.mcp.call_tool(tool_name, arguments)
                raw_output = self._extract_mcp_output(result)
            else:
                raw_output = f"Tool '{tool_name}' not available via MCP."

            duration = (time.monotonic() - start) * 1000

            # Summarise — raw output NEVER enters context
            task_id = session.current_task.task_id if session.current_task else "none"
            artifact = self.context.summarize_tool_output(
                tool_name=tool_name,
                raw_output=raw_output,
                session_id=session.session_id,
                task_id=task_id,
            )

            # Store artifact
            await self.memory.store_artifact(session.session_id, artifact)
            artifacts.append(artifact)

            tool_exec.status = ToolExecutionStatus.COMPLETED
            tool_exec.result_summary = artifact.summary
            tool_exec.duration_ms = duration

            await self.sse.emit(
                EventType.TOOL_COMPLETED,
                {
                    "tool": tool_name,
                    "status": "completed",
                    "summary": artifact.summary[:200],
                },
                session_id=session.session_id,
            )

        except Exception as e:
            tool_exec.status = ToolExecutionStatus.FAILED
            tool_exec.error = str(e)
            tool_exec.duration_ms = (time.monotonic() - start) * 1000
            logger.error("Tool execution failed: %s — %s", tool_name, e)

            await self.sse.emit(
                EventType.TOOL_COMPLETED,
                {"tool": tool_name, "status": "failed", "error": str(e)},
                session_id=session.session_id,
            )

        return tool_exec

    def _get_allowed_tools(self) -> list[dict[str, Any]]:
        """Return tool definitions filtered by this agent's allowed_tools."""
        if not self.mcp.is_running:
            return []
        all_tools = self.mcp.get_tool_definitions()
        if not self.profile.allowed_tools:
            return all_tools  # No filter → all tools
        return [t for t in all_tools if t["name"] in self.profile.allowed_tools]

    @staticmethod
    def _extract_mcp_output(result: dict[str, Any]) -> str:
        """Extract text content from MCP tool result."""
        content = result.get("content", [])
        if isinstance(content, list):
            parts: list[str] = []
            for item in content:
                if isinstance(item, dict) and item.get("type") == "text":
                    parts.append(item.get("text", ""))
            return "\n".join(parts) if parts else json.dumps(result)
        return str(content)
46
src/orchestrator/agents/coder.py
Normal file
@@ -0,0 +1,46 @@
"""Coder agent — executes implementation steps using tools."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ...models.agent import AgentProfile, AgentRole
|
||||
from .base import BaseAgent
|
||||
|
||||
CODER_SYSTEM_PROMPT = """Eres un Agente Programador. Tu rol es ejecutar tareas de implementación usando las herramientas disponibles.
|
||||
|
||||
## Instrucciones
|
||||
- Concéntrate en la descripción del paso actual.
|
||||
- Usa herramientas para lograr la tarea.
|
||||
- Sé preciso y minucioso.
|
||||
- Reporta lo que lograste, problemas encontrados y hechos relevantes.
|
||||
- NO produzcas explicaciones innecesarias — produce resultados.
|
||||
- Responde SIEMPRE en español.
|
||||
|
||||
## Uso de herramientas
|
||||
- Usa herramientas cuando necesites leer archivos, escribir código o ejecutar comandos.
|
||||
- Los resultados de herramientas se te presentarán resumidos — no verás la salida cruda.
|
||||
- Si necesitas más detalle de un resultado, solicita rehidratación.
|
||||
"""
|
||||
|
||||
|
||||
def create_coder_profile() -> AgentProfile:
|
||||
return AgentProfile(
|
||||
role=AgentRole.CODER,
|
||||
name="coder",
|
||||
system_prompt=CODER_SYSTEM_PROMPT,
|
||||
allowed_tools=[], # All tools allowed
|
||||
temperature=0.2,
|
||||
max_tokens=4096,
|
||||
context_sections=[
|
||||
"immutable_rules",
|
||||
"project_profile",
|
||||
"knowledge_base",
|
||||
"task_state",
|
||||
"artifact_memory",
|
||||
"working_context",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class CoderAgent(BaseAgent):
|
||||
"""Executes implementation steps."""
|
||||
pass
|
||||
46
src/orchestrator/agents/collector.py
Normal file
@@ -0,0 +1,46 @@
"""Collector agent — gathers context and information."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ...models.agent import AgentProfile, AgentRole
|
||||
from .base import BaseAgent
|
||||
|
||||
COLLECTOR_SYSTEM_PROMPT = """Eres un Agente Recolector de Contexto. Tu rol es recopilar información necesaria para una tarea.
|
||||
|
||||
## Instrucciones
|
||||
- Lee archivos, busca en el código, explora documentación.
|
||||
- Produce un resumen estructurado de lo que encontraste.
|
||||
- Extrae hechos clave, restricciones y dependencias.
|
||||
- NO modifiques nada — solo observa y reporta.
|
||||
- Responde SIEMPRE en español.
|
||||
|
||||
## Formato de salida
|
||||
Produce un resumen estructurado:
|
||||
1. Archivos relevantes y sus propósitos
|
||||
2. Patrones o convenciones encontrados
|
||||
3. Dependencias o restricciones
|
||||
4. Recomendaciones para el paso de implementación
|
||||
"""
|
||||
|
||||
|
||||
def create_collector_profile() -> AgentProfile:
|
||||
return AgentProfile(
|
||||
role=AgentRole.COLLECTOR,
|
||||
name="collector",
|
||||
system_prompt=COLLECTOR_SYSTEM_PROMPT,
|
||||
allowed_tools=[], # All tools allowed (read-only preferred)
|
||||
temperature=0.1,
|
||||
max_tokens=2048,
|
||||
context_sections=[
|
||||
"immutable_rules",
|
||||
"project_profile",
|
||||
"knowledge_base",
|
||||
"task_state",
|
||||
"working_context",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class CollectorAgent(BaseAgent):
|
||||
"""Gathers context and information for tasks."""
|
||||
pass
|
||||
107
src/orchestrator/agents/planner.py
Normal file
@@ -0,0 +1,107 @@
"""Planner agent — decomposes objectives into executable plans."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from ...models.agent import AgentProfile, AgentRole
|
||||
from ...models.session import SessionState, TaskStep, TaskStatus
|
||||
from .base import BaseAgent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PLANNER_SYSTEM_PROMPT = """Eres un Agente Planificador. Tu rol es descomponer un objetivo en un plan de ejecución estructurado.
|
||||
|
||||
## Instrucciones
|
||||
- Analiza el objetivo y divídelo en pasos concretos y ordenados.
|
||||
- Cada paso debe ser ejecutable de forma independiente por un agente especializado.
|
||||
- Asigna cada paso al rol de agente apropiado: coder, collector o reviewer.
|
||||
- Responde SIEMPRE en español.
|
||||
|
||||
## Formato de salida
|
||||
Devuelve SOLO un objeto JSON:
|
||||
{
|
||||
"plan": [
|
||||
{"description": "descripción del paso", "agent_role": "coder|collector|reviewer"},
|
||||
...
|
||||
],
|
||||
"constraints": ["restricciones o notas importantes"],
|
||||
"facts": ["hechos establecidos del análisis"]
|
||||
}
|
||||
|
||||
NO incluyas comentarios fuera del JSON."""
|
||||
|
||||
|
||||
def create_planner_profile() -> AgentProfile:
|
||||
return AgentProfile(
|
||||
role=AgentRole.PLANNER,
|
||||
name="planner",
|
||||
system_prompt=PLANNER_SYSTEM_PROMPT,
|
||||
allowed_tools=[], # Planner doesn't use tools
|
||||
temperature=0.2,
|
||||
max_tokens=2048,
|
||||
context_sections=[
|
||||
"immutable_rules",
|
||||
"project_profile",
|
||||
"knowledge_base",
|
||||
"task_state",
|
||||
"artifact_memory",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class PlannerAgent(BaseAgent):
|
||||
"""Generates execution plans from objectives."""
|
||||
|
||||
async def plan(self, session: SessionState) -> list[TaskStep]:
|
||||
"""Generate a plan and return TaskSteps."""
|
||||
result = await self.execute(session, max_steps=1)
|
||||
content = result["content"].strip()
|
||||
|
||||
# Parse the JSON plan from the model output
|
||||
try:
|
||||
# Try to extract JSON from the content
|
||||
json_str = content
|
||||
if "```" in content:
|
||||
# Extract from code block
|
||||
start = content.find("{")
|
||||
end = content.rfind("}") + 1
|
||||
if start >= 0 and end > start:
|
||||
json_str = content[start:end]
|
||||
|
||||
parsed = json.loads(json_str)
|
||||
steps: list[TaskStep] = []
|
||||
|
||||
for item in parsed.get("plan", []):
|
||||
steps.append(
|
||||
TaskStep(
|
||||
description=item.get("description", ""),
|
||||
agent_role=item.get("agent_role", "coder"),
|
||||
status=TaskStatus.PENDING,
|
||||
)
|
||||
)
|
||||
|
||||
# Extract constraints and facts into task state
|
||||
if session.current_task:
|
||||
session.current_task.constraints.extend(
|
||||
parsed.get("constraints", [])
|
||||
)
|
||||
session.current_task.facts_extracted.extend(
|
||||
parsed.get("facts", [])
|
||||
)
|
||||
|
||||
return steps
|
||||
|
||||
except (json.JSONDecodeError, KeyError) as e:
|
||||
logger.warning("Failed to parse planner output: %s", e)
|
||||
# Fallback: single step with the full objective
|
||||
return [
|
||||
TaskStep(
|
||||
description=session.current_task.objective
|
||||
if session.current_task
|
||||
else "Execute task",
|
||||
agent_role="coder",
|
||||
)
|
||||
]
|
||||
47
src/orchestrator/agents/reviewer.py
Normal file
@@ -0,0 +1,47 @@
"""Reviewer agent — validates outputs and provides feedback."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ...models.agent import AgentProfile, AgentRole
|
||||
from .base import BaseAgent
|
||||
|
||||
REVIEWER_SYSTEM_PROMPT = """Eres un Agente Revisor. Tu rol es validar el trabajo realizado por otros agentes.
|
||||
|
||||
## Instrucciones
|
||||
- Revisa los artefactos producidos en esta sesión.
|
||||
- Verifica corrección, completitud y calidad.
|
||||
- Identifica problemas, bugs o piezas faltantes.
|
||||
- Proporciona retroalimentación accionable.
|
||||
- Responde SIEMPRE en español.
|
||||
|
||||
## Formato de salida
|
||||
Produce una revisión estructurada:
|
||||
1. **Estado**: APROBADO | NECESITA_CAMBIOS | RECHAZADO
|
||||
2. **Problemas**: Lista de problemas encontrados
|
||||
3. **Sugerencias**: Mejoras a considerar
|
||||
4. **Hechos**: Nuevos hechos establecidos durante la revisión
|
||||
"""
|
||||
|
||||
|
||||
def create_reviewer_profile() -> AgentProfile:
|
||||
return AgentProfile(
|
||||
role=AgentRole.REVIEWER,
|
||||
name="reviewer",
|
||||
system_prompt=REVIEWER_SYSTEM_PROMPT,
|
||||
allowed_tools=[], # All tools allowed
|
||||
temperature=0.1,
|
||||
max_tokens=2048,
|
||||
context_sections=[
|
||||
"immutable_rules",
|
||||
"project_profile",
|
||||
"knowledge_base",
|
||||
"task_state",
|
||||
"artifact_memory",
|
||||
"working_context",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class ReviewerAgent(BaseAgent):
|
||||
"""Reviews and validates work products."""
|
||||
pass
|
||||
295
src/orchestrator/engine.py
Normal file
@@ -0,0 +1,295 @@
"""Orchestrator Engine — the main execution loop.
|
||||
|
||||
Flow: message → plan → route → execute steps → summarize → update → stream
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from ..adapters.base import ModelAdapter
|
||||
from ..config import settings
|
||||
from ..context.engine import ContextEngine
|
||||
from ..mcp.client import MCPClient
|
||||
from ..memory.store import MemoryStore
|
||||
from ..models.agent import AgentRole
|
||||
from ..models.session import SessionState, SessionStatus, TaskStatus
|
||||
from ..streaming.sse import SSEEmitter, EventType
|
||||
from .agents.coder import CoderAgent, create_coder_profile
|
||||
from .agents.collector import CollectorAgent, create_collector_profile
|
||||
from .agents.planner import PlannerAgent, create_planner_profile
|
||||
from .agents.reviewer import ReviewerAgent, create_reviewer_profile
|
||||
from .router import route_step
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OrchestratorEngine:
|
||||
"""Drives the full execution lifecycle for a session message."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_adapter: ModelAdapter,
|
||||
context_engine: ContextEngine,
|
||||
mcp_client: MCPClient,
|
||||
memory_store: MemoryStore,
|
||||
sse_emitter: SSEEmitter,
|
||||
) -> None:
|
||||
self.model = model_adapter
|
||||
self.context = context_engine
|
||||
self.mcp = mcp_client
|
||||
self.memory = memory_store
|
||||
self.sse = sse_emitter
|
||||
|
||||
# Pre-built agent profiles
|
||||
self._profiles = {
|
||||
AgentRole.PLANNER: create_planner_profile(),
|
||||
AgentRole.CODER: create_coder_profile(),
|
||||
AgentRole.COLLECTOR: create_collector_profile(),
|
||||
AgentRole.REVIEWER: create_reviewer_profile(),
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def process_message(
|
||||
self,
|
||||
session: SessionState,
|
||||
message: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Process a user message through the full orchestration pipeline.
|
||||
|
||||
Pipeline: plan → execute steps → review → complete
|
||||
|
||||
Handles errors gracefully: failed steps are marked and skipped,
|
||||
the session always returns to idle/error — never stuck in executing.
|
||||
"""
|
||||
task = None
|
||||
try:
|
||||
return await asyncio.wait_for(
|
||||
self._run_pipeline(session, message),
|
||||
timeout=settings.max_execution_timeout_seconds,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("Execution timed out for session %s", session.session_id)
|
||||
if session.current_task:
|
||||
session.current_task.mark_failed("Execution timed out")
|
||||
session.status = SessionStatus.ERROR
|
||||
await self.sse.emit(
|
||||
EventType.ERROR,
|
||||
{"error": "execution_timeout", "message": "Task exceeded maximum execution time"},
|
||||
session_id=session.session_id,
|
||||
)
|
||||
return self._error_result(session, "Execution timed out")
|
||||
except Exception as e:
|
||||
logger.exception("Unhandled error in pipeline for session %s", session.session_id)
|
||||
if session.current_task:
|
||||
session.current_task.mark_failed(str(e))
|
||||
session.status = SessionStatus.ERROR
|
||||
await self.sse.emit(
|
||||
EventType.ERROR,
|
||||
{"error": "pipeline_error", "message": str(e)},
|
||||
session_id=session.session_id,
|
||||
)
|
||||
return self._error_result(session, str(e))
|
||||
|
||||
async def _run_pipeline(
|
||||
self,
|
||||
session: SessionState,
|
||||
message: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Inner pipeline — wrapped by process_message for error handling."""
|
||||
|
||||
await self.sse.emit(
|
||||
EventType.EXECUTION_STARTED,
|
||||
{"session_id": session.session_id, "message": message[:200]},
|
||||
session_id=session.session_id,
|
||||
)
|
||||
|
||||
# 1. Create task from message
|
||||
task = session.begin_task(objective=message)
|
||||
|
||||
# 2. Plan
|
||||
task.status = TaskStatus.PLANNING
|
||||
try:
|
||||
planner = self._create_agent(AgentRole.PLANNER)
|
||||
plan_steps = await planner.plan(session)
|
||||
task.plan = plan_steps
|
||||
task.status = TaskStatus.EXECUTING
|
||||
except Exception as e:
|
||||
logger.error("Planning failed: %s", e)
|
||||
task.mark_failed(f"Planning failed: {e}")
|
||||
session.status = SessionStatus.ERROR
|
||||
await self.sse.emit(
|
||||
EventType.ERROR,
|
||||
{"error": "planning_failed", "message": str(e)},
|
||||
session_id=session.session_id,
|
||||
)
|
||||
return self._error_result(session, f"Planning failed: {e}")
|
||||
|
||||
logger.info(
|
||||
"Plan created with %d steps for task %s",
|
||||
len(plan_steps),
|
||||
task.task_id,
|
||||
)
|
||||
|
||||
# 3. Execute each step — failures are logged and skipped
|
||||
results: list[dict[str, Any]] = []
|
||||
failed_steps: list[int] = []
|
||||
|
||||
for i, step in enumerate(task.plan):
|
||||
if i >= settings.max_execution_steps:
|
||||
logger.warning("Max execution steps reached")
|
||||
break
|
||||
|
||||
task.current_step_index = i
|
||||
step.status = TaskStatus.EXECUTING
|
||||
step.started_at = datetime.now(timezone.utc)
|
||||
|
||||
role = route_step(step)
|
||||
agent = self._create_agent(role)
|
||||
|
||||
await self.sse.emit(
|
||||
EventType.SUBAGENT_ASSIGNED,
|
||||
{
|
||||
"step": i + 1,
|
||||
"total_steps": len(task.plan),
|
||||
"agent": role.value,
|
||||
"description": step.description,
|
||||
},
|
||||
session_id=session.session_id,
|
||||
)
|
||||
|
||||
try:
|
||||
step_result = await agent.execute(
|
||||
session=session,
|
||||
max_steps=settings.subagent_max_steps,
|
||||
)
|
||||
results.append(step_result)
|
||||
|
||||
step.status = TaskStatus.COMPLETED
|
||||
step.completed_at = datetime.now(timezone.utc)
|
||||
step.result_summary = (step_result.get("content", ""))[:500]
|
||||
step.tools_used = [
|
||||
te.tool_name for te in step_result.get("tool_executions", [])
|
||||
]
|
||||
|
||||
for artifact in step_result.get("artifacts", []):
|
||||
task.facts_extracted.extend(artifact.facts[:5])
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Step %d failed: %s", i + 1, e)
|
||||
step.status = TaskStatus.FAILED
|
||||
step.completed_at = datetime.now(timezone.utc)
|
||||
step.result_summary = f"Error: {e}"
|
||||
failed_steps.append(i + 1)
|
||||
|
||||
await self.sse.emit(
|
||||
EventType.ERROR,
|
||||
{"error": "step_failed", "step": i + 1, "message": str(e)},
|
||||
session_id=session.session_id,
|
||||
)
|
||||
# Continue with next step — don't block the pipeline
|
||||
|
||||
# 4. Review (if plan had more than 1 step and at least one succeeded)
|
||||
review_result: dict[str, Any] = {}
|
||||
if len(task.plan) > 1 and results:
|
||||
task.status = TaskStatus.REVIEWING
|
||||
try:
|
||||
reviewer = self._create_agent(AgentRole.REVIEWER)
|
||||
review_result = await reviewer.execute(
|
||||
session=session,
|
||||
max_steps=2,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Review failed: %s", e)
|
||||
review_result = {"content": f"Review skipped due to error: {e}"}
|
||||
|
||||
# 5. Complete — session ALWAYS returns to idle
|
||||
session.complete_task()
|
||||
|
||||
final_content = self._assemble_response(results, review_result)
|
||||
status = "completed" if not failed_steps else "partial"
|
||||
|
||||
await self.sse.emit(
|
||||
EventType.EXECUTION_COMPLETED,
|
||||
{
|
||||
"session_id": session.session_id,
|
||||
"task_id": task.task_id,
|
||||
"steps_completed": len(results),
|
||||
"steps_failed": failed_steps,
|
||||
"status": status,
|
||||
},
|
||||
session_id=session.session_id,
|
||||
)
|
||||
|
||||
return {
|
||||
"session_id": session.session_id,
|
||||
"task_id": task.task_id,
|
||||
"content": final_content,
|
||||
"steps_completed": len(results),
|
||||
"steps_failed": failed_steps,
|
||||
"artifacts_count": sum(
|
||||
len(r.get("artifacts", [])) for r in results
|
||||
),
|
||||
"review": review_result.get("content", ""),
|
||||
"status": status,
|
||||
}
|
||||
|
||||
def _error_result(self, session: SessionState, error: str) -> dict[str, Any]:
|
||||
"""Build a standardized error response."""
|
||||
task_id = session.current_task.task_id if session.current_task else "none"
|
||||
return {
|
||||
"session_id": session.session_id,
|
||||
"task_id": task_id,
|
||||
"content": f"Error: {error}",
|
||||
"steps_completed": 0,
|
||||
"steps_failed": [],
|
||||
"artifacts_count": 0,
|
||||
"review": "",
|
||||
"status": "error",
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Internals
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _create_agent(self, role: AgentRole) -> PlannerAgent | CoderAgent | CollectorAgent | ReviewerAgent:
|
||||
"""Instantiate a subagent for the given role."""
|
||||
profile = self._profiles[role]
|
||||
agent_cls = {
|
||||
AgentRole.PLANNER: PlannerAgent,
|
||||
AgentRole.CODER: CoderAgent,
|
||||
AgentRole.COLLECTOR: CollectorAgent,
|
||||
AgentRole.REVIEWER: ReviewerAgent,
|
||||
}[role]
|
||||
|
||||
return agent_cls(
|
||||
profile=profile,
|
||||
model_adapter=self.model,
|
||||
context_engine=self.context,
|
||||
mcp_client=self.mcp,
|
||||
memory_store=self.memory,
|
||||
sse_emitter=self.sse,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _assemble_response(
|
||||
results: list[dict[str, Any]],
|
||||
review_result: dict[str, Any],
|
||||
) -> str:
|
||||
"""Combine step results into a coherent final response."""
|
||||
parts: list[str] = []
|
||||
for i, r in enumerate(results):
|
||||
content = r.get("content", "").strip()
|
||||
if content:
|
||||
parts.append(f"### Step {i + 1}\n{content}")
|
||||
|
||||
if review_result.get("content"):
|
||||
parts.append(f"### Review\n{review_result['content']}")
|
||||
|
||||
return "\n\n".join(parts) if parts else "Task completed."
|
||||
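As a usage sketch (not part of the commit), wiring the engine looks roughly like this inside an async startup path. Only the OrchestratorEngine signatures come from the code above; the adapter, context engine, MCP client, memory store, and emitter instances are assumed to be constructed elsewhere (e.g. a ClaudeAdapter from src/adapters):

# Hypothetical wiring; the dependency instances are assumptions.
engine = OrchestratorEngine(
    model_adapter=adapter,
    context_engine=context_engine,
    mcp_client=mcp_client,
    memory_store=memory_store,
    sse_emitter=sse_emitter,
)
result = await engine.process_message(session, "Add a /health endpoint")
print(result["status"], result["steps_failed"])  # "completed" [] on success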
60
src/orchestrator/router.py
Normal file
@@ -0,0 +1,60 @@
"""Agent router — selects the right subagent for each step."""

from __future__ import annotations

import logging

from ..models.agent import AgentRole
from ..models.session import TaskStep

logger = logging.getLogger(__name__)

# Keyword-based routing hints
_ROLE_KEYWORDS: dict[AgentRole, list[str]] = {
    AgentRole.COLLECTOR: [
        "gather", "collect", "read", "explore", "search", "find",
        "discover", "analyze", "investigate", "research", "scan",
        "understand", "review existing",
    ],
    AgentRole.CODER: [
        "implement", "write", "create", "build", "code", "fix",
        "modify", "refactor", "add", "update", "generate", "develop",
        "edit", "change", "configure", "set up",
    ],
    AgentRole.REVIEWER: [
        "review", "validate", "check", "verify", "test", "audit",
        "inspect", "evaluate", "assess", "confirm",
    ],
}


def route_step(step: TaskStep) -> AgentRole:
    """Determine which agent role should handle this step.

    Uses the step's declared agent_role if valid; otherwise falls back
    to keyword-based routing over the step description.
    """
    # Respect explicit assignment
    declared = step.agent_role.lower()
    try:
        return AgentRole(declared)
    except ValueError:
        pass

    # Keyword-based fallback
    desc_lower = step.description.lower()
    scores: dict[AgentRole, int] = {role: 0 for role in _ROLE_KEYWORDS}

    for role, keywords in _ROLE_KEYWORDS.items():
        for kw in keywords:
            if kw in desc_lower:
                scores[role] += 1

    best = max(scores, key=lambda r: scores[r])
    if scores[best] > 0:
        logger.info(
            "Routed step '%s' to %s (score=%d)",
            step.description[:60], best, scores[best],
        )
        return best

    # Default to coder
    return AgentRole.CODER
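As a quick illustration (not part of the commit), the fallback path in route_step can be exercised like this; it assumes TaskStep's other fields have defaults, as the planner fallback above suggests:

# Invented examples. An unknown declared role falls through to keyword
# scoring over the description; no hits at all defaults to the coder.
step = TaskStep(description="Investigate and analyze the failing tests", agent_role="unknown")
assert route_step(step) == AgentRole.COLLECTOR  # "investigate" + "analyze" outscore "test"

step = TaskStep(description="Polish the release notes", agent_role="unknown")
assert route_step(step) == AgentRole.CODER  # no keyword hits, default role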
3
src/storage/__init__.py
Normal file
@@ -0,0 +1,3 @@
from .redis import RedisStorage

__all__ = ["RedisStorage"]
BIN
src/storage/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
src/storage/__pycache__/redis.cpython-312.pyc
Normal file
Binary file not shown.
151
src/storage/redis.py
Normal file
@@ -0,0 +1,151 @@
"""Redis storage layer for session persistence.

Key structure:
    {prefix}:session:{id}            — SessionState JSON
    {prefix}:session:{id}:artifacts  — Hash of ArtifactSummary by artifact_id
    {prefix}:session:{id}:events     — List of SSE event JSONs
    {prefix}:session:{id}:lock       — Execution lock (SET NX)
    {prefix}:sessions:index          — Set of all active session IDs
"""

from __future__ import annotations

import json
import logging
from contextlib import asynccontextmanager
from typing import Any, AsyncIterator

import redis.asyncio as redis

from ..config import settings
from ..models.session import SessionState

logger = logging.getLogger(__name__)


class RedisStorage:
    """Async Redis storage for session state."""

    def __init__(self, redis_client: redis.Redis | None = None) -> None:
        self._r = redis_client
        self._prefix = settings.redis_key_prefix

    async def connect(self) -> None:
        """Create the Redis connection if one was not provided."""
        if self._r is None:
            self._r = redis.from_url(
                settings.redis_url,
                decode_responses=True,
            )
        await self._r.ping()
        logger.info("Connected to Redis at %s", settings.redis_url)

    async def disconnect(self) -> None:
        if self._r:
            await self._r.aclose()

    @property
    def client(self) -> redis.Redis:
        if self._r is None:
            raise RuntimeError("Redis not connected. Call connect() first.")
        return self._r

    # ------------------------------------------------------------------
    # Key helpers
    # ------------------------------------------------------------------

    def _key(self, *parts: str) -> str:
        return ":".join([self._prefix, *parts])

    # ------------------------------------------------------------------
    # Session CRUD
    # ------------------------------------------------------------------

    async def create_session(self, session: SessionState) -> None:
        key = self._key("session", session.session_id)
        await self.client.set(
            key,
            session.model_dump_json(),
            ex=settings.session_ttl_seconds,
        )
        await self.client.sadd(
            self._key("sessions", "index"), session.session_id
        )

    async def get_session(self, session_id: str) -> SessionState | None:
        key = self._key("session", session_id)
        data = await self.client.get(key)
        if data:
            return SessionState.model_validate_json(data)
        return None

    async def update_session(self, session: SessionState) -> None:
        key = self._key("session", session.session_id)
        await self.client.set(
            key,
            session.model_dump_json(),
            ex=settings.session_ttl_seconds,
        )

    async def delete_session(self, session_id: str) -> bool:
        pipe = self.client.pipeline()
        pipe.delete(self._key("session", session_id))
        pipe.delete(self._key("session", session_id, "artifacts"))
        pipe.delete(self._key("session", session_id, "events"))
        pipe.srem(self._key("sessions", "index"), session_id)
        results = await pipe.execute()
        return bool(results[0])

    async def list_sessions(self) -> list[str]:
        members = await self.client.smembers(self._key("sessions", "index"))
        return list(members)

    # ------------------------------------------------------------------
    # Event log
    # ------------------------------------------------------------------

    async def append_event(
        self, session_id: str, event: dict[str, Any]
    ) -> None:
        key = self._key("session", session_id, "events")
        length = await self.client.rpush(key, json.dumps(event))
        # Set the TTL when the first event is written
        if length == 1:
            await self.client.expire(key, settings.session_ttl_seconds)
        # Trim only when well over the cap, so LTRIM runs infrequently
        # rather than on every push
        if length > 600:
            await self.client.ltrim(key, -500, -1)

    async def get_events(
        self, session_id: str, start: int = 0, end: int = -1
    ) -> list[dict[str, Any]]:
        key = self._key("session", session_id, "events")
        raw = await self.client.lrange(key, start, end)
        return [json.loads(r) for r in raw]

    # ------------------------------------------------------------------
    # Execution lock (prevents concurrent messages on the same session)
    # ------------------------------------------------------------------

    @asynccontextmanager
    async def session_lock(
        self, session_id: str, timeout: int = 300
    ) -> AsyncIterator[bool]:
        """Acquire an exclusive execution lock for a session.

        Uses SET NX with auto-expiry to prevent deadlocks if the process
        crashes mid-execution.

        Usage:
            async with storage.session_lock(session_id) as acquired:
                if not acquired:
                    raise HTTPException(409, "Session busy")
                # ... execute ...
        """
        key = self._key("session", session_id, "lock")
        acquired = await self.client.set(key, "1", nx=True, ex=timeout)
        try:
            yield bool(acquired)
        finally:
            if acquired:
                await self.client.delete(key)
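To show how these pieces compose, here is a hedged usage sketch for an async context. The SessionState constructor argument is an assumption; every storage call is defined above:

# Hypothetical usage. SessionState(session_id=...) is an assumed
# constructor; the RedisStorage calls are the ones defined above.
storage = RedisStorage()
await storage.connect()

session = SessionState(session_id="abc123")
await storage.create_session(session)

async with storage.session_lock(session.session_id) as acquired:
    if not acquired:
        raise RuntimeError("Session busy")  # mirrors the docstring's 409 example
    await storage.append_event(session.session_id, {"type": "execution.started"})
    await storage.update_session(session)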
3
src/streaming/__init__.py
Normal file
@@ -0,0 +1,3 @@
from .sse import SSEEmitter, EventType

__all__ = ["SSEEmitter", "EventType"]
BIN
src/streaming/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
src/streaming/__pycache__/sse.cpython-312.pyc
Normal file
Binary file not shown.
186
src/streaming/sse.py
Normal file
@@ -0,0 +1,186 @@
"""SSE (Server-Sent Events) streaming system.

Supports structured event types, per-session event queues,
and Redis-backed persistence.
"""

from __future__ import annotations

import asyncio
import json
import logging
import time
from datetime import datetime, timezone
from enum import StrEnum
from typing import Any, AsyncIterator

from ..config import settings

logger = logging.getLogger(__name__)


class EventType(StrEnum):
    SESSION_CREATED = "session.created"
    EXECUTION_STARTED = "execution.started"
    AGENT_DELTA = "agent.delta"
    TOOL_STARTED = "tool.started"
    TOOL_COMPLETED = "tool.completed"
    SUBAGENT_ASSIGNED = "subagent.assigned"
    EXECUTION_COMPLETED = "execution.completed"
    ERROR = "error"
    KEEPALIVE = "keepalive"


class SSEEvent:
    """A single SSE event."""

    __slots__ = ("event_type", "data", "session_id", "timestamp", "event_id")

    def __init__(
        self,
        event_type: EventType,
        data: dict[str, Any],
        session_id: str,
        event_id: str = "",
    ) -> None:
        self.event_type = event_type
        self.data = data
        self.session_id = session_id
        self.timestamp = datetime.now(timezone.utc).isoformat()
        self.event_id = event_id or f"{int(time.time() * 1000)}"

    def format_sse(self) -> str:
        """Format as the SSE wire protocol (terminated by a blank line)."""
        payload = {
            "type": self.event_type.value,
            "data": self.data,
            "timestamp": self.timestamp,
        }
        lines = [
            f"id: {self.event_id}",
            f"event: {self.event_type.value}",
            f"data: {json.dumps(payload)}",
            "",
            "",
        ]
        return "\n".join(lines)

    def to_dict(self) -> dict[str, Any]:
        return {
            "event_id": self.event_id,
            "type": self.event_type.value,
            "data": self.data,
            "timestamp": self.timestamp,
        }


class SSEEmitter:
    """Manages per-session SSE event queues with Redis persistence."""

    def __init__(self, redis_storage=None) -> None:
        self._queues: dict[str, list[asyncio.Queue[SSEEvent | None]]] = {}
        self._history: dict[str, list[SSEEvent]] = {}
        self._max_history = 500
        self._storage = redis_storage

    def set_storage(self, redis_storage) -> None:
        """Set the Redis storage backend (called after startup)."""
        self._storage = redis_storage

    async def emit(
        self,
        event_type: EventType,
        data: dict[str, Any],
        session_id: str,
    ) -> None:
        """Emit an event to all listeners of a session."""
        event = SSEEvent(
            event_type=event_type,
            data=data,
            session_id=session_id,
        )

        # Store in memory
        if session_id not in self._history:
            self._history[session_id] = []
        history = self._history[session_id]
        history.append(event)
        if len(history) > self._max_history:
            self._history[session_id] = history[-self._max_history:]

        # Persist to Redis (skip keepalives — they're noise)
        if self._storage and event_type != EventType.KEEPALIVE:
            try:
                await self._storage.append_event(session_id, event.to_dict())
                logger.debug("Persisted event %s for session %s", event_type.value, session_id[:8])
            except Exception as e:
                logger.warning(
                    "Failed to persist event to Redis: %s — storage type: %s",
                    e, type(self._storage).__name__,
                )

        # Push to all active queues for this session
        queues = self._queues.get(session_id, [])
        for q in queues:
            try:
                q.put_nowait(event)
            except asyncio.QueueFull:
                logger.warning("SSE queue full for session %s", session_id)

    async def subscribe(
        self, session_id: str
    ) -> AsyncIterator[str]:
        """Subscribe to SSE events for a session.

        Yields formatted SSE strings. Sends keepalives to detect
        disconnected clients.
        """
        queue: asyncio.Queue[SSEEvent | None] = asyncio.Queue(maxsize=256)

        if session_id not in self._queues:
            self._queues[session_id] = []
        self._queues[session_id].append(queue)

        try:
            while True:
                try:
                    event = await asyncio.wait_for(
                        queue.get(),
                        timeout=settings.sse_keepalive_seconds,
                    )
                    if event is None:
                        break
                    yield event.format_sse()
                except asyncio.TimeoutError:
                    # Send a keepalive so broken connections surface
                    keepalive = SSEEvent(
                        event_type=EventType.KEEPALIVE,
                        data={},
                        session_id=session_id,
                    )
                    yield keepalive.format_sse()
        finally:
            if queue in self._queues.get(session_id, []):
                self._queues[session_id].remove(queue)

    async def get_history(self, session_id: str) -> list[dict[str, Any]]:
        """Return event history — from Redis if available, falling back to memory."""
        # Try Redis first (persistent)
        if self._storage:
            try:
                events = await self._storage.get_events(session_id)
                if events:
                    return events
            except Exception as e:
                logger.warning("Failed to read events from Redis: %s", e)

        # Fall back to in-memory history
        events = self._history.get(session_id, [])
        return [e.to_dict() for e in events]

    def cleanup_session(self, session_id: str) -> None:
        """Remove all queues and in-memory history for a session."""
        for q in self._queues.get(session_id, []):
            try:
                q.put_nowait(None)
            except asyncio.QueueFull:
                pass
        self._queues.pop(session_id, None)
        self._history.pop(session_id, None)
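Finally, a sketch of the HTTP wiring for subscribe. FastAPI is an assumption here (this commit only hints at it via the HTTPException example in redis.py); StreamingResponse accepts the async generator directly, since subscribe yields SSE-formatted strings:

# Hypothetical HTTP layer: FastAPI is assumed, not confirmed by this
# commit. subscribe() backs a text/event-stream response directly.
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()
emitter = SSEEmitter()

@app.get("/sessions/{session_id}/events")
async def stream_events(session_id: str) -> StreamingResponse:
    return StreamingResponse(
        emitter.subscribe(session_id),
        media_type="text/event-stream",
    )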