diff --git a/src/api/routes.py b/src/api/routes.py index 0873870..1cc117e 100644 --- a/src/api/routes.py +++ b/src/api/routes.py @@ -68,13 +68,15 @@ def set_dependencies( context_engine: Any, memory_store: Any, sse_emitter: Any, - mcp_registry: Any, + claude_emitter: Any = None, + mcp_registry: Any = None, ) -> None: _deps["storage"] = storage _deps["model_adapter"] = model_adapter _deps["context_engine"] = context_engine _deps["memory_store"] = memory_store _deps["sse"] = sse_emitter + _deps["claude_sse"] = claude_emitter _deps["mcp_registry"] = mcp_registry @@ -207,22 +209,33 @@ async def _execute_and_persist(orchestrator, storage, session, message) -> dict[ # ------------------------------------------------------------------ @router.get("/sessions/{session_id}/stream") -async def stream_session(session_id: str) -> StreamingResponse: +async def stream_session(session_id: str, format: str = "native") -> StreamingResponse: storage = _get_storage() session = await storage.get_session(session_id) if not session: raise HTTPException(status_code=404, detail="Session not found") - sse = _get_sse() + headers = { + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + } + if format == "claude": + claude_sse = _deps.get("claude_sse") + if not claude_sse: + raise HTTPException(status_code=501, detail="Claude format emitter not available") + return StreamingResponse( + claude_sse.subscribe(session_id), + media_type="text/event-stream", + headers=headers, + ) + + sse = _get_sse() return StreamingResponse( sse.subscribe(session_id), media_type="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "X-Accel-Buffering": "no", - }, + headers=headers, ) diff --git a/src/main.py b/src/main.py index 8f3d583..5fbcdef 100644 --- a/src/main.py +++ b/src/main.py @@ -27,6 +27,7 @@ from .mcp.registry import MCPRegistry from .memory.store import MemoryStore from .orchestrator.engine import OrchestratorEngine from .storage.redis import RedisStorage +from .streaming.claude_format import ClaudeFormatEmitter, DualEmitter from .streaming.sse import SSEEmitter logging.basicConfig( @@ -38,6 +39,8 @@ logger = logging.getLogger(__name__) # Global instances (initialized in lifespan) redis_storage = RedisStorage() sse_emitter = SSEEmitter(redis_storage=redis_storage) +claude_emitter = ClaudeFormatEmitter() +dual_emitter = DualEmitter(sse_emitter, claude_emitter) mcp_registry = MCPRegistry() @@ -48,7 +51,6 @@ async def lifespan(app: FastAPI): # 1. Connect Redis await redis_storage.connect() - sse_emitter.set_storage(redis_storage) # 2. Initialize model adapter if settings.default_model_provider == "openai": @@ -82,12 +84,14 @@ async def lifespan(app: FastAPI): mcp_registry.load_config() # 6. Wire dependencies (orchestrator is created per-message with session's MCP) + dual_emitter.set_storage(redis_storage) set_dependencies( storage=redis_storage, model_adapter=model_adapter, context_engine=context_engine, memory_store=memory_store, - sse_emitter=sse_emitter, + sse_emitter=dual_emitter, + claude_emitter=claude_emitter, mcp_registry=mcp_registry, ) diff --git a/src/orchestrator/agents/base.py b/src/orchestrator/agents/base.py index b65c339..b8dd7f9 100644 --- a/src/orchestrator/agents/base.py +++ b/src/orchestrator/agents/base.py @@ -115,7 +115,7 @@ class BaseAgent: } await self.sse.emit( EventType.TOOL_STARTED, - {"tool": chunk.tool_name, "step": step}, + {"tool": chunk.tool_name, "tool_call_id": chunk.tool_call_id, "step": step}, session_id=session.session_id, ) @@ -123,6 +123,17 @@ class BaseAgent: tool = active_tools.get(chunk.tool_call_id) if tool: tool["arguments"] += chunk.tool_arguments + await self.sse.emit( + EventType.AGENT_DELTA, + { + "agent": self.profile.role, + "delta": "", + "tool_arguments": chunk.tool_arguments, + "tool_call_id": chunk.tool_call_id, + "step": step, + }, + session_id=session.session_id, + ) if chunk.finish_reason == "tool_use" and chunk.tool_call_id: tool = active_tools.pop(chunk.tool_call_id, None) @@ -200,6 +211,7 @@ class BaseAgent: tool_name=tc["name"], arguments=tc.get("parsed_arguments", {}), artifacts=artifacts, + tool_call_id=tc["id"], ) tool_fingerprints[fp] = tool_exec tool_executions.append(tool_exec) @@ -253,6 +265,7 @@ class BaseAgent: tool_name: str, arguments: dict[str, Any], artifacts: list[ArtifactSummary], + tool_call_id: str = "", ) -> ToolExecution: """Execute a tool and summarise the result.""" exec_id = uuid.uuid4().hex[:12] @@ -299,6 +312,8 @@ class BaseAgent: "tool": tool_name, "status": "completed", "summary": artifact.summary[:200], + "raw_output": raw_output[:4000], + "tool_call_id": tool_call_id, }, session_id=session.session_id, ) @@ -311,7 +326,7 @@ class BaseAgent: await self.sse.emit( EventType.TOOL_COMPLETED, - {"tool": tool_name, "status": "failed", "error": str(e)}, + {"tool": tool_name, "status": "failed", "error": str(e), "tool_call_id": tool_call_id}, session_id=session.session_id, ) diff --git a/src/orchestrator/engine.py b/src/orchestrator/engine.py index 07d8b88..b89ffa8 100644 --- a/src/orchestrator/engine.py +++ b/src/orchestrator/engine.py @@ -223,18 +223,6 @@ class OrchestratorEngine: final_content = self._assemble_response(results, review_result) status = "completed" if not failed_steps else "partial" - await self.sse.emit( - EventType.EXECUTION_COMPLETED, - { - "session_id": session.session_id, - "task_id": task.task_id, - "steps_completed": len(results), - "steps_failed": failed_steps, - "status": status, - }, - session_id=session.session_id, - ) - # Accumulate token usage: planner + all steps + review total_input = planner_usage.get("input_tokens", 0) total_output = planner_usage.get("output_tokens", 0) @@ -250,6 +238,23 @@ class OrchestratorEngine: + (total_output / 1_000_000) * settings.cost_per_1m_output ) + await self.sse.emit( + EventType.EXECUTION_COMPLETED, + { + "session_id": session.session_id, + "task_id": task.task_id, + "steps_completed": len(results), + "steps_failed": failed_steps, + "status": status, + "usage": { + "input_tokens": total_input, + "output_tokens": total_output, + }, + "total_cost_usd": round(cost_usd, 6), + }, + session_id=session.session_id, + ) + return { "session_id": session.session_id, "task_id": task.task_id, diff --git a/src/streaming/claude_format.py b/src/streaming/claude_format.py new file mode 100644 index 0000000..8518a51 --- /dev/null +++ b/src/streaming/claude_format.py @@ -0,0 +1,321 @@ +"""Claude Code CLI compatible SSE format emitter. + +Translates agenticSystem native events into the exact format that +Claude Code CLI produces, so the frontend can consume them without +any changes. Used via ?format=claude on the stream endpoint. + +Wire format: data: {json}\n\n (no event: or id: fields) +""" + +from __future__ import annotations + +import asyncio +import json +import logging +from typing import Any, AsyncIterator + +from .sse import EventType, SSEEmitter + +logger = logging.getLogger(__name__) + + +class ClaudeFormatEmitter: + """Emits events in Claude Code CLI SSE format. + + Maintains per-session state to track block indices and + accumulate content for assistant snapshots. + """ + + def __init__(self) -> None: + self._queues: dict[str, list[asyncio.Queue[str | None]]] = {} + # Per-session state + self._block_counter: dict[str, int] = {} + self._text_block_open: dict[str, bool] = {} + self._text_block_index: dict[str, int] = {} + self._tool_block_index: dict[str, dict[str, int]] = {} # session -> {tool_call_id -> index} + self._content_blocks: dict[str, list[dict[str, Any]]] = {} + self._text_accumulator: dict[str, str] = {} + + def _next_index(self, session_id: str) -> int: + idx = self._block_counter.get(session_id, 0) + self._block_counter[session_id] = idx + 1 + return idx + + def _reset_session(self, session_id: str) -> None: + self._block_counter[session_id] = 0 + self._text_block_open[session_id] = False + self._text_block_index[session_id] = -1 + self._tool_block_index[session_id] = {} + self._content_blocks[session_id] = [] + self._text_accumulator[session_id] = "" + + def _push(self, session_id: str, payload: dict[str, Any]) -> None: + """Push a formatted line to all subscribers of a session.""" + line = f"data: {json.dumps(payload, ensure_ascii=False)}\n\n" + for q in self._queues.get(session_id, []): + try: + q.put_nowait(line) + except asyncio.QueueFull: + logger.warning("Claude SSE queue full for session %s", session_id[:8]) + + def _close_text_block(self, session_id: str) -> None: + """Close the current open text block if any.""" + if self._text_block_open.get(session_id): + idx = self._text_block_index[session_id] + self._push(session_id, { + "type": "stream_event", + "event": {"type": "content_block_stop", "index": idx}, + }) + # Save accumulated text to content blocks + text = self._text_accumulator.get(session_id, "") + if text: + self._content_blocks.setdefault(session_id, []).append({ + "type": "text", "text": text, + }) + self._text_block_open[session_id] = False + self._text_accumulator[session_id] = "" + + def _open_text_block(self, session_id: str) -> None: + """Open a new text block.""" + idx = self._next_index(session_id) + self._text_block_index[session_id] = idx + self._text_block_open[session_id] = True + self._text_accumulator[session_id] = "" + self._push(session_id, { + "type": "stream_event", + "event": { + "type": "content_block_start", + "index": idx, + "content_block": {"type": "text", "text": ""}, + }, + }) + + def _build_assistant_snapshot(self, session_id: str) -> dict[str, Any]: + """Build assistant message snapshot for reconciliation.""" + blocks = list(self._content_blocks.get(session_id, [])) + return { + "type": "assistant", + "message": {"content": blocks}, + "error": False, + } + + async def emit( + self, + event_type: EventType, + data: dict[str, Any], + session_id: str, + ) -> None: + """Translate a native event into Claude Code CLI format.""" + + if event_type == EventType.EXECUTION_STARTED: + self._reset_session(session_id) + self._push(session_id, { + "type": "stream_event", + "event": {"type": "message_start"}, + }) + + elif event_type == EventType.AGENT_DELTA: + delta_text = data.get("delta", "") + tool_args = data.get("tool_arguments", "") + tool_call_id = data.get("tool_call_id", "") + + if delta_text: + # Text streaming + if not self._text_block_open.get(session_id): + self._open_text_block(session_id) + idx = self._text_block_index[session_id] + self._text_accumulator[session_id] = self._text_accumulator.get(session_id, "") + delta_text + self._push(session_id, { + "type": "stream_event", + "event": { + "type": "content_block_delta", + "index": idx, + "delta": {"type": "text_delta", "text": delta_text}, + }, + }) + + elif tool_args and tool_call_id: + # Tool input JSON streaming + tool_indices = self._tool_block_index.get(session_id, {}) + idx = tool_indices.get(tool_call_id) + if idx is not None: + self._push(session_id, { + "type": "stream_event", + "event": { + "type": "content_block_delta", + "index": idx, + "delta": {"type": "input_json_delta", "partial_json": tool_args}, + }, + }) + + elif event_type == EventType.TOOL_STARTED: + tool_name = data.get("tool", "unknown") + tool_call_id = data.get("tool_call_id", "") + + # Close open text block + self._close_text_block(session_id) + + # Open tool_use block + idx = self._next_index(session_id) + self._tool_block_index.setdefault(session_id, {})[tool_call_id] = idx + self._push(session_id, { + "type": "stream_event", + "event": { + "type": "content_block_start", + "index": idx, + "content_block": { + "type": "tool_use", + "name": tool_name, + "id": tool_call_id, + }, + }, + }) + + elif event_type == EventType.TOOL_COMPLETED: + tool_name = data.get("tool", "unknown") + tool_call_id = data.get("tool_call_id", "") + status = data.get("status", "completed") + raw_output = data.get("raw_output", data.get("summary", "")) + is_error = status == "failed" + + # Close tool_use block + tool_indices = self._tool_block_index.get(session_id, {}) + idx = tool_indices.get(tool_call_id) + if idx is not None: + self._push(session_id, { + "type": "stream_event", + "event": {"type": "content_block_stop", "index": idx}, + }) + + # Save tool_use to content blocks for snapshot + self._content_blocks.setdefault(session_id, []).append({ + "type": "tool_use", + "id": tool_call_id, + "name": tool_name, + "input": {}, + }) + + # Emit tool_result + content = data.get("error", raw_output) if is_error else raw_output + self._push(session_id, { + "type": "tool_result", + "tool_use_id": tool_call_id, + "content": content[:4000] if isinstance(content, str) else str(content)[:4000], + "is_error": is_error, + }) + + # Emit assistant snapshot for reconciliation + self._push(session_id, self._build_assistant_snapshot(session_id)) + + elif event_type == EventType.EXECUTION_COMPLETED: + # Close any open text block + self._close_text_block(session_id) + + # Final assistant snapshot + self._push(session_id, self._build_assistant_snapshot(session_id)) + + # Result with usage + usage = data.get("usage", {}) + self._push(session_id, { + "type": "result", + "is_error": False, + "usage": { + "input_tokens": usage.get("input_tokens", 0), + "output_tokens": usage.get("output_tokens", 0), + "cache_read_input_tokens": 0, + "cache_creation_input_tokens": 0, + }, + "total_cost_usd": data.get("total_cost_usd", 0), + }) + + # Done + self._push(session_id, {"type": "done"}) + + elif event_type == EventType.ERROR: + error_msg = data.get("message", str(data.get("error", "Unknown error"))) + + # Close any open block + self._close_text_block(session_id) + + self._push(session_id, { + "type": "result", + "is_error": True, + "result": error_msg, + "usage": {"input_tokens": 0, "output_tokens": 0, "cache_read_input_tokens": 0, "cache_creation_input_tokens": 0}, + "total_cost_usd": 0, + }) + self._push(session_id, {"type": "done"}) + + # Ignore other event types (KEEPALIVE, SESSION_CREATED, SUBAGENT_ASSIGNED) + + async def subscribe(self, session_id: str) -> AsyncIterator[str]: + """Subscribe to Claude-format SSE events for a session.""" + queue: asyncio.Queue[str | None] = asyncio.Queue(maxsize=512) + + if session_id not in self._queues: + self._queues[session_id] = [] + self._queues[session_id].append(queue) + + try: + while True: + try: + line = await asyncio.wait_for(queue.get(), timeout=15.0) + if line is None: + break + yield line + except asyncio.TimeoutError: + yield 'data: {"type":"keepalive"}\n\n' + finally: + if queue in self._queues.get(session_id, []): + self._queues[session_id].remove(queue) + + def cleanup_session(self, session_id: str) -> None: + """Clean up session state and close subscribers.""" + for q in self._queues.get(session_id, []): + try: + q.put_nowait(None) + except asyncio.QueueFull: + pass + self._queues.pop(session_id, None) + self._block_counter.pop(session_id, None) + self._text_block_open.pop(session_id, None) + self._text_block_index.pop(session_id, None) + self._tool_block_index.pop(session_id, None) + self._content_blocks.pop(session_id, None) + self._text_accumulator.pop(session_id, None) + + +class DualEmitter: + """Wraps SSEEmitter (native) + ClaudeFormatEmitter. + + Agents call emit() and both formats are produced. + Duck-type compatible with SSEEmitter. + """ + + def __init__(self, native: SSEEmitter, claude: ClaudeFormatEmitter) -> None: + self.native = native + self.claude = claude + + async def emit( + self, + event_type: EventType, + data: dict[str, Any], + session_id: str, + ) -> None: + await self.native.emit(event_type, data, session_id) + await self.claude.emit(event_type, data, session_id) + + # Delegate native SSE methods for backward compatibility + async def subscribe(self, session_id: str) -> AsyncIterator[str]: + async for line in self.native.subscribe(session_id): + yield line + + async def get_history(self, session_id: str) -> list[dict[str, Any]]: + return await self.native.get_history(session_id) + + def cleanup_session(self, session_id: str) -> None: + self.native.cleanup_session(session_id) + self.claude.cleanup_session(session_id) + + def set_storage(self, redis_storage: Any) -> None: + self.native.set_storage(redis_storage)