Initial commit
This commit is contained in:
241
src/orchestrator/agents/base.py
Normal file
241
src/orchestrator/agents/base.py
Normal file
@@ -0,0 +1,241 @@
|
||||
"""Base subagent class with shared execution logic."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, AsyncIterator
|
||||
|
||||
from ...adapters.base import ModelAdapter, ModelConfig, StreamChunk
|
||||
from ...context.engine import ContextEngine
|
||||
from ...mcp.client import MCPClient
|
||||
from ...memory.store import MemoryStore
|
||||
from ...models.agent import AgentProfile
|
||||
from ...models.artifacts import ArtifactSummary
|
||||
from ...models.session import SessionState
|
||||
from ...models.tools import ToolExecution, ToolExecutionStatus
|
||||
from ...streaming.sse import SSEEmitter, EventType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseAgent:
    """Base class for all subagents.

    Holds the collaborators every agent needs (model adapter, context
    engine, MCP tool client, memory store, SSE emitter) and provides the
    shared execute/tool-dispatch loop defined on this class.
    """

    def __init__(
        self,
        profile: AgentProfile,
        model_adapter: ModelAdapter,
        context_engine: ContextEngine,
        mcp_client: MCPClient,
        memory_store: MemoryStore,
        sse_emitter: SSEEmitter,
    ) -> None:
        # Static agent configuration (role, model_id, allowed_tools, ...).
        self.profile = profile
        # LLM adapter used for streaming completions.
        self.model = model_adapter
        # Builds prompt context and summarises raw tool output into artifacts.
        self.context = context_engine
        # MCP client through which tools are listed and invoked.
        self.mcp = mcp_client
        # Persistent per-session artifact store.
        self.memory = memory_store
        # Server-sent-events emitter for progress/streaming updates.
        self.sse = sse_emitter
|
||||
|
||||
    async def execute(
        self,
        session: SessionState,
        max_steps: int = 10,
    ) -> dict[str, Any]:
        """Run the agent's execution loop.

        Each step builds a fresh context, streams one model turn, and then
        executes any tool calls the model made; summarised tool results are
        fed back as working items for the next step. The loop ends when a
        turn produces no tool calls or ``max_steps`` is reached.

        Args:
            session: Session whose id scopes artifacts and SSE events.
            max_steps: Upper bound on model turns (default 10).

        Returns a result dict with keys: content, artifacts, tool_executions.
        """
        # Start from the artifacts already stored for this session.
        artifacts: list[ArtifactSummary] = await self.memory.list_artifacts(
            session.session_id
        )
        tool_executions: list[ToolExecution] = []
        accumulated_content = ""
        # Summarised tool results accumulated across steps and re-fed into
        # context on the next iteration.
        working_items: list[dict[str, Any]] = []

        for step in range(max_steps):
            # Build context — NEVER includes raw tool output
            ctx = await self.context.build_context(
                session=session,
                agent=self.profile,
                artifacts=artifacts,
                working_items=working_items,
            )

            # Prepare tool definitions
            tool_defs = self._get_allowed_tools()

            # Stream model response; profile values fall back to defaults
            # when unset (empty model_id, 4096 tokens, temperature 0.3).
            config = ModelConfig(
                model_id=self.profile.model_id or "",
                max_tokens=self.profile.max_tokens or 4096,
                temperature=self.profile.temperature or 0.3,
            )

            full_text = ""
            # Completed tool calls for this turn.
            tool_calls: list[dict[str, Any]] = []
            # Tool call currently being streamed (name first, then argument
            # fragments); empty dict when none is in flight.
            current_tool: dict[str, Any] = {}

            async for chunk in self.model.stream(
                messages=ctx.to_messages(),
                tools=tool_defs if tool_defs else None,
                config=config,
            ):
                if chunk.delta:
                    full_text += chunk.delta
                    # Forward each text delta to the client as it arrives.
                    await self.sse.emit(
                        EventType.AGENT_DELTA,
                        {
                            "agent": self.profile.role,
                            "delta": chunk.delta,
                            "step": step,
                        },
                        session_id=session.session_id,
                    )

                # A tool_name chunk opens a new in-flight tool call — but only
                # if none is already open, so argument fragments are not
                # misattributed mid-call.
                if chunk.tool_name and not current_tool.get("name"):
                    current_tool = {
                        "id": chunk.tool_call_id,
                        "name": chunk.tool_name,
                        "arguments": "",
                    }
                    await self.sse.emit(
                        EventType.TOOL_STARTED,
                        {"tool": chunk.tool_name, "step": step},
                        session_id=session.session_id,
                    )

                # Accumulate streamed JSON argument fragments.
                if chunk.tool_arguments and current_tool:
                    current_tool["arguments"] += chunk.tool_arguments

                # NOTE(review): assumes the adapter normalises finish reasons
                # to "tool_use"/"end_turn" — confirm against ModelAdapter.
                if chunk.finish_reason == "tool_use" and current_tool.get("name"):
                    # Parse arguments; malformed/empty JSON degrades to {}.
                    try:
                        args = json.loads(current_tool["arguments"]) if current_tool["arguments"] else {}
                    except json.JSONDecodeError:
                        args = {}
                    current_tool["parsed_arguments"] = args
                    tool_calls.append(current_tool)
                    current_tool = {}

                if chunk.finish_reason == "end_turn":
                    break

            accumulated_content += full_text

            # If no tool calls, we're done
            if not tool_calls:
                break

            # Execute tool calls
            for tc in tool_calls:
                tool_exec = await self._execute_tool(
                    session=session,
                    tool_name=tc["name"],
                    arguments=tc.get("parsed_arguments", {}),
                    artifacts=artifacts,
                )
                tool_executions.append(tool_exec)

                # Add summarised result to working context (NEVER raw)
                working_items.append({
                    "role": "tool_result",
                    "content": f"[{tc['name']}] {tool_exec.result_summary}",
                })

        return {
            "content": accumulated_content,
            "artifacts": artifacts,
            "tool_executions": tool_executions,
        }
|
||||
|
||||
    async def _execute_tool(
        self,
        session: SessionState,
        tool_name: str,
        arguments: dict[str, Any],
        artifacts: list[ArtifactSummary],
    ) -> ToolExecution:
        """Execute a tool and summarise the result.

        Args:
            session: Current session; supplies the session and task ids.
            tool_name: Name of the tool to invoke via MCP.
            arguments: Parsed JSON arguments for the call.
            artifacts: Mutable session artifact list; the summary artifact
                produced here is appended to it in place.

        Returns:
            A ToolExecution record carrying status, result summary or error,
            and duration. Failures are recorded on the record, never raised.
        """
        exec_id = uuid.uuid4().hex[:12]
        tool_exec = ToolExecution(
            execution_id=exec_id,
            tool_name=tool_name,
            arguments=arguments,
            status=ToolExecutionStatus.RUNNING,
        )

        start = time.monotonic()
        try:
            if self.mcp.is_running and tool_name in self.mcp.tools:
                result = await self.mcp.call_tool(tool_name, arguments)
                raw_output = self._extract_mcp_output(result)
            else:
                # Unknown/unavailable tool degrades to a descriptive message
                # instead of raising, so the agent loop can continue.
                raw_output = f"Tool '{tool_name}' not available via MCP."

            # Duration (ms) measured up to the raw tool call, before
            # summarisation/storage.
            duration = (time.monotonic() - start) * 1000

            # Summarise — raw output NEVER enters context
            task_id = session.current_task.task_id if session.current_task else "none"
            artifact = self.context.summarize_tool_output(
                tool_name=tool_name,
                raw_output=raw_output,
                session_id=session.session_id,
                task_id=task_id,
            )

            # Store artifact
            await self.memory.store_artifact(session.session_id, artifact)
            artifacts.append(artifact)

            tool_exec.status = ToolExecutionStatus.COMPLETED
            tool_exec.result_summary = artifact.summary
            tool_exec.duration_ms = duration

            await self.sse.emit(
                EventType.TOOL_COMPLETED,
                {
                    "tool": tool_name,
                    "status": "completed",
                    # Truncated so SSE payloads stay small.
                    "summary": artifact.summary[:200],
                },
                session_id=session.session_id,
            )

        except Exception as e:
            # Any failure in the block above (MCP call, summarisation,
            # storage, SSE emit) marks the execution failed; the error is
            # recorded on the record rather than propagated.
            tool_exec.status = ToolExecutionStatus.FAILED
            tool_exec.error = str(e)
            tool_exec.duration_ms = (time.monotonic() - start) * 1000
            logger.error("Tool execution failed: %s — %s", tool_name, e)

            await self.sse.emit(
                EventType.TOOL_COMPLETED,
                {"tool": tool_name, "status": "failed", "error": str(e)},
                session_id=session.session_id,
            )

        return tool_exec
|
||||
|
||||
def _get_allowed_tools(self) -> list[dict[str, Any]]:
|
||||
"""Return tool definitions filtered by this agent's allowed_tools."""
|
||||
if not self.mcp.is_running:
|
||||
return []
|
||||
all_tools = self.mcp.get_tool_definitions()
|
||||
if not self.profile.allowed_tools:
|
||||
return all_tools # No filter → all tools
|
||||
return [t for t in all_tools if t["name"] in self.profile.allowed_tools]
|
||||
|
||||
@staticmethod
|
||||
def _extract_mcp_output(result: dict[str, Any]) -> str:
|
||||
"""Extract text content from MCP tool result."""
|
||||
content = result.get("content", [])
|
||||
if isinstance(content, list):
|
||||
parts: list[str] = []
|
||||
for item in content:
|
||||
if isinstance(item, dict) and item.get("type") == "text":
|
||||
parts.append(item.get("text", ""))
|
||||
return "\n".join(parts) if parts else json.dumps(result)
|
||||
return str(content)
|
||||
Reference in New Issue
Block a user