242 lines
8.4 KiB
Python
242 lines
8.4 KiB
Python
"""Base subagent class with shared execution logic."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
import time
|
|
import uuid
|
|
from typing import Any, AsyncIterator
|
|
|
|
from ...adapters.base import ModelAdapter, ModelConfig, StreamChunk
|
|
from ...context.engine import ContextEngine
|
|
from ...mcp.client import MCPClient
|
|
from ...memory.store import MemoryStore
|
|
from ...models.agent import AgentProfile
|
|
from ...models.artifacts import ArtifactSummary
|
|
from ...models.session import SessionState
|
|
from ...models.tools import ToolExecution, ToolExecutionStatus
|
|
from ...streaming.sse import SSEEmitter, EventType
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class BaseAgent:
    """Base class for all subagents.

    Runs a bounded model/tool loop: each step builds a context, streams one
    model response, executes any tool calls that response requested through
    MCP, and feeds only artifact *summaries* of tool output (never the raw
    output) back into the next step's working context.
    """

    def __init__(
        self,
        profile: AgentProfile,
        model_adapter: ModelAdapter,
        context_engine: ContextEngine,
        mcp_client: MCPClient,
        memory_store: MemoryStore,
        sse_emitter: SSEEmitter,
    ) -> None:
        self.profile = profile
        self.model = model_adapter
        self.context = context_engine
        self.mcp = mcp_client
        self.memory = memory_store
        self.sse = sse_emitter

    async def execute(
        self,
        session: SessionState,
        max_steps: int = 10,
    ) -> dict[str, Any]:
        """Run the agent's execution loop.

        Each step streams one model response. If the response requested tools,
        they are executed and their summaries appended to the working context
        for the next step. The loop stops at the first response without tool
        calls, or after ``max_steps`` steps.

        Returns a result dict with keys: content, artifacts, tool_executions.
        """
        # Copy in case the memory store hands back a list it still owns; new
        # artifacts are stored via ``store_artifact`` *and* appended below.
        artifacts: list[ArtifactSummary] = list(
            await self.memory.list_artifacts(session.session_id)
        )
        tool_executions: list[ToolExecution] = []
        content_parts: list[str] = []
        working_items: list[dict[str, Any]] = []

        # Depends only on the profile, so build it once for every step.
        config = self._model_config()

        for step in range(max_steps):
            # Build context — NEVER includes raw tool output
            ctx = await self.context.build_context(
                session=session,
                agent=self.profile,
                artifacts=artifacts,
                working_items=working_items,
            )

            # Re-read each step: availability depends on the MCP client state.
            tool_defs = self._get_allowed_tools()

            full_text, tool_calls = await self._stream_step(
                session=session,
                step=step,
                messages=ctx.to_messages(),
                tools=tool_defs or None,
                config=config,
            )
            content_parts.append(full_text)

            # If no tool calls, we're done
            if not tool_calls:
                break

            for tc in tool_calls:
                tool_exec = await self._execute_tool(
                    session=session,
                    tool_name=tc["name"],
                    arguments=tc.get("parsed_arguments", {}),
                    artifacts=artifacts,
                )
                tool_executions.append(tool_exec)

                # Add summarised result to working context (NEVER raw)
                working_items.append({
                    "role": "tool_result",
                    "content": self._format_tool_result(tc["name"], tool_exec),
                })
        else:
            # Only reached when every step ended with tool calls, so the
            # results of the final step were never shown to the model.
            if max_steps > 0:
                logger.warning(
                    "Agent %s stopped after max_steps=%d; the last tool "
                    "results were not sent back to the model",
                    self.profile.role,
                    max_steps,
                )

        return {
            "content": "".join(content_parts),
            "artifacts": artifacts,
            "tool_executions": tool_executions,
        }

    async def _stream_step(
        self,
        session: SessionState,
        step: int,
        messages: Any,
        tools: list[dict[str, Any]] | None,
        config: ModelConfig,
    ) -> tuple[str, list[dict[str, Any]]]:
        """Stream one model response and collect its text and tool calls.

        Text deltas and tool starts are forwarded as SSE events as they arrive.
        A tool call is collected once the model signals ``tool_use`` or a
        different tool call begins; a call still open when the stream ends is
        dropped (with a warning), since its arguments may be incomplete.

        Returns ``(full_text, tool_calls)`` where each tool call is a dict with
        ``id``, ``name``, ``arguments`` (raw JSON text) and ``parsed_arguments``.
        """
        text_parts: list[str] = []
        tool_calls: list[dict[str, Any]] = []
        current_tool: dict[str, Any] = {}

        async for chunk in self.model.stream(
            messages=messages,
            tools=tools,
            config=config,
        ):
            if chunk.delta:
                text_parts.append(chunk.delta)
                await self.sse.emit(
                    EventType.AGENT_DELTA,
                    {
                        "agent": self.profile.role,
                        "delta": chunk.delta,
                        "step": step,
                    },
                    session_id=session.session_id,
                )

            if chunk.tool_name and self._starts_new_tool_call(
                current_tool, chunk.tool_call_id
            ):
                # A new call begins before the previous one was finalised:
                # close the previous one so its arguments are not polluted.
                if current_tool.get("name"):
                    tool_calls.append(self._finalize_tool_call(current_tool))
                current_tool = {
                    "id": chunk.tool_call_id,
                    "name": chunk.tool_name,
                    "arguments": "",
                }
                await self.sse.emit(
                    EventType.TOOL_STARTED,
                    {"tool": chunk.tool_name, "step": step},
                    session_id=session.session_id,
                )

            if chunk.tool_arguments and current_tool:
                current_tool["arguments"] += chunk.tool_arguments

            if chunk.finish_reason == "tool_use" and current_tool.get("name"):
                tool_calls.append(self._finalize_tool_call(current_tool))
                current_tool = {}

            if chunk.finish_reason == "end_turn":
                break

        if current_tool.get("name"):
            logger.warning(
                "Dropping unfinished tool call %s: stream ended before "
                "finish_reason 'tool_use'",
                current_tool["name"],
            )

        return "".join(text_parts), tool_calls

    @staticmethod
    def _starts_new_tool_call(current_tool: dict[str, Any], tool_call_id: Any) -> bool:
        """Return True if a chunk naming a tool opens a new call.

        A chunk opens a new call when no call is in progress, or when it
        carries a non-``None`` id different from the in-progress call's id.
        Chunks repeating the same id (or carrying no id) continue the current
        call.
        """
        if not current_tool.get("name"):
            return True
        return tool_call_id is not None and tool_call_id != current_tool.get("id")

    def _finalize_tool_call(self, tool: dict[str, Any]) -> dict[str, Any]:
        """Attach ``parsed_arguments`` to *tool* (``{}`` if malformed) and return it."""
        args = self._parse_tool_arguments(tool["arguments"])
        if args is None:
            logger.warning(
                "Malformed arguments for tool %s; calling it with {}",
                tool["name"],
            )
            args = {}
        tool["parsed_arguments"] = args
        return tool

    @staticmethod
    def _parse_tool_arguments(raw: str) -> dict[str, Any] | None:
        """Decode a streamed tool-argument payload.

        Returns ``{}`` for an empty/whitespace payload, the decoded dict for a
        JSON object, and ``None`` if the payload is invalid JSON or decodes to
        something other than an object (tool arguments must be a mapping).
        """
        if not raw or not raw.strip():
            return {}
        try:
            parsed = json.loads(raw)
        except json.JSONDecodeError:
            return None
        return parsed if isinstance(parsed, dict) else None

    @staticmethod
    def _coalesce(value: Any, default: Any) -> Any:
        """Return *value* unless it is ``None`` (so ``0`` / ``0.0`` are kept)."""
        return default if value is None else value

    def _model_config(self) -> ModelConfig:
        """Build the model configuration from the agent profile."""
        return ModelConfig(
            model_id=self.profile.model_id or "",
            # A zero token budget is unusable, so 0 also falls back here.
            max_tokens=self.profile.max_tokens or 4096,
            # Only fall back when unset: 0.0 is a legitimate temperature.
            temperature=self._coalesce(self.profile.temperature, 0.3),
        )

    @staticmethod
    def _format_tool_result(tool_name: str, tool_exec: ToolExecution) -> str:
        """Render a tool execution for the working context (never raw output).

        Failed executions report their error so the model can react to it.
        """
        if tool_exec.status == ToolExecutionStatus.FAILED:
            return f"[{tool_name}] failed: {tool_exec.error}"
        return f"[{tool_name}] {tool_exec.result_summary}"

    async def _execute_tool(
        self,
        session: SessionState,
        tool_name: str,
        arguments: dict[str, Any],
        artifacts: list[ArtifactSummary],
    ) -> ToolExecution:
        """Execute a tool and summarise the result.

        On success, the raw output is summarised into an artifact, which is
        stored in memory and appended to *artifacts*; the execution is marked
        COMPLETED.

        On any exception, including a tool outside the agent's
        ``allowed_tools``, the execution is marked FAILED with the error
        message. Either outcome is also emitted as a TOOL_COMPLETED SSE event.
        """
        exec_id = uuid.uuid4().hex[:12]
        tool_exec = ToolExecution(
            execution_id=exec_id,
            tool_name=tool_name,
            arguments=arguments,
            status=ToolExecutionStatus.RUNNING,
        )

        start = time.monotonic()
        try:
            if not self._is_tool_allowed(tool_name):
                # allowed_tools only filters the definitions advertised to the
                # model; the model can still name any tool, so enforce it here.
                raise PermissionError(
                    f"Tool '{tool_name}' is not allowed for agent "
                    f"'{self.profile.role}'."
                )

            if self.mcp.is_running and tool_name in self.mcp.tools:
                result = await self.mcp.call_tool(tool_name, arguments)
                raw_output = self._extract_mcp_output(result)
            else:
                # Still summarised and marked completed below, so the model
                # sees this message on its next step.
                raw_output = f"Tool '{tool_name}' not available via MCP."

            # Measures the tool call itself, not summarisation or storage.
            duration = (time.monotonic() - start) * 1000

            # Summarise — raw output NEVER enters context
            task_id = session.current_task.task_id if session.current_task else "none"
            artifact = self.context.summarize_tool_output(
                tool_name=tool_name,
                raw_output=raw_output,
                session_id=session.session_id,
                task_id=task_id,
            )

            # Store artifact
            await self.memory.store_artifact(session.session_id, artifact)
            artifacts.append(artifact)

            tool_exec.status = ToolExecutionStatus.COMPLETED
            tool_exec.result_summary = artifact.summary
            tool_exec.duration_ms = duration

            await self.sse.emit(
                EventType.TOOL_COMPLETED,
                {
                    "tool": tool_name,
                    "status": "completed",
                    "summary": artifact.summary[:200],
                },
                session_id=session.session_id,
            )

        except Exception as e:
            tool_exec.status = ToolExecutionStatus.FAILED
            tool_exec.error = str(e)
            tool_exec.duration_ms = (time.monotonic() - start) * 1000
            logger.error("Tool execution failed: %s — %s", tool_name, e)

            await self.sse.emit(
                EventType.TOOL_COMPLETED,
                {"tool": tool_name, "status": "failed", "error": str(e)},
                session_id=session.session_id,
            )

        return tool_exec

    def _is_tool_allowed(self, tool_name: str) -> bool:
        """Return True if this agent may call *tool_name*.

        An empty/unset ``allowed_tools`` means no restriction.
        """
        allowed = self.profile.allowed_tools
        return not allowed or tool_name in allowed

    def _get_allowed_tools(self) -> list[dict[str, Any]]:
        """Return tool definitions filtered by this agent's allowed_tools.

        Returns an empty list when the MCP client is not running.
        """
        if not self.mcp.is_running:
            return []
        return [
            t for t in self.mcp.get_tool_definitions()
            if self._is_tool_allowed(t["name"])
        ]

    @staticmethod
    def _extract_mcp_output(result: dict[str, Any]) -> str:
        """Extract text content from MCP tool result.

        Joins the ``text`` of every ``{"type": "text"}`` item with newlines.
        If a list ``content`` has no text items, the whole result is
        JSON-encoded instead; a non-list ``content`` is stringified.
        """
        # NOTE(review): MCP tool results may also carry an ``isError`` flag,
        # which is not inspected here — confirm against the MCP client.
        content = result.get("content", [])
        if not isinstance(content, list):
            return str(content)
        texts = [
            item.get("text", "")
            for item in content
            if isinstance(item, dict) and item.get("type") == "text"
        ]
        return "\n".join(texts) if texts else json.dumps(result)
|