Initial commit
This commit is contained in:
295
src/orchestrator/engine.py
Normal file
295
src/orchestrator/engine.py
Normal file
@@ -0,0 +1,295 @@
|
||||
"""Orchestrator Engine — the main execution loop.
|
||||
|
||||
Flow: message → plan → route → execute steps → summarize → update → stream
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from ..adapters.base import ModelAdapter
|
||||
from ..config import settings
|
||||
from ..context.engine import ContextEngine
|
||||
from ..mcp.client import MCPClient
|
||||
from ..memory.store import MemoryStore
|
||||
from ..models.agent import AgentRole
|
||||
from ..models.session import SessionState, SessionStatus, TaskStatus
|
||||
from ..streaming.sse import SSEEmitter, EventType
|
||||
from .agents.coder import CoderAgent, create_coder_profile
|
||||
from .agents.collector import CollectorAgent, create_collector_profile
|
||||
from .agents.planner import PlannerAgent, create_planner_profile
|
||||
from .agents.reviewer import ReviewerAgent, create_reviewer_profile
|
||||
from .router import route_step
|
||||
|
||||
# Module-level logger keyed by this module's ``__name__``.
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OrchestratorEngine:
    """Drives the full execution lifecycle for a session message.

    A message becomes a task, the planner agent produces a plan, each plan
    step is executed by the agent chosen by ``route_step``, multi-step plans
    are optionally reviewed, and the step outputs are assembled into a single
    response dict. Progress and errors are reported through the SSE emitter.
    """

    def __init__(
        self,
        model_adapter: ModelAdapter,
        context_engine: ContextEngine,
        mcp_client: MCPClient,
        memory_store: MemoryStore,
        sse_emitter: SSEEmitter,
    ) -> None:
        """Store the shared collaborators and build one profile per role.

        Args:
            model_adapter: Passed to every agent as ``model_adapter``.
            context_engine: Passed to every agent as ``context_engine``.
            mcp_client: Passed to every agent as ``mcp_client``.
            memory_store: Passed to every agent as ``memory_store``.
            sse_emitter: Used by the engine for events and passed to agents.
        """
        self.model = model_adapter
        self.context = context_engine
        self.mcp = mcp_client
        self.memory = memory_store
        self.sse = sse_emitter

        # Pre-built agent profiles, reused each time an agent is created.
        self._profiles = {
            AgentRole.PLANNER: create_planner_profile(),
            AgentRole.CODER: create_coder_profile(),
            AgentRole.COLLECTOR: create_collector_profile(),
            AgentRole.REVIEWER: create_reviewer_profile(),
        }

    # ------------------------------------------------------------------
    # Public
    # ------------------------------------------------------------------

    async def process_message(
        self,
        session: SessionState,
        message: str,
    ) -> dict[str, Any]:
        """Process a user message through the full orchestration pipeline.

        Pipeline: plan → execute steps → review → complete

        The pipeline runs under ``settings.max_execution_timeout_seconds``.
        A timeout or any unhandled exception marks the current task failed
        (when there is one), sets the session status to ``ERROR``, emits an
        ``ERROR`` event and returns a standardized error result instead of
        raising.

        Args:
            session: Session the message belongs to; mutated in place.
            message: Raw user message, used as the task objective.

        Returns:
            The pipeline result dict on success, or the dict built by
            ``_error_result`` on timeout / failure.
        """
        try:
            return await asyncio.wait_for(
                self._run_pipeline(session, message),
                timeout=settings.max_execution_timeout_seconds,
            )
        except asyncio.TimeoutError:
            logger.error("Execution timed out for session %s", session.session_id)
            return await self._fail_session(
                session,
                task=session.current_task,
                error_code="execution_timeout",
                event_message="Task exceeded maximum execution time",
                failure_reason="Execution timed out",
            )
        except Exception as e:
            logger.exception("Unhandled error in pipeline for session %s", session.session_id)
            return await self._fail_session(
                session,
                task=session.current_task,
                error_code="pipeline_error",
                event_message=str(e),
                failure_reason=str(e),
            )

    async def _run_pipeline(
        self,
        session: SessionState,
        message: str,
    ) -> dict[str, Any]:
        """Inner pipeline — wrapped by process_message for error handling.

        Args:
            session: Session to run the task in; mutated in place.
            message: User message; becomes the task objective.

        Returns:
            A result dict with ``session_id``, ``task_id``, ``content``,
            ``steps_completed``, ``steps_failed``, ``artifacts_count``,
            ``review`` and ``status`` (``"completed"`` when no step failed,
            otherwise ``"partial"``), or the error result when planning fails.
        """
        await self.sse.emit(
            EventType.EXECUTION_STARTED,
            # Only a 200-character preview of the message goes into the event.
            {"session_id": session.session_id, "message": message[:200]},
            session_id=session.session_id,
        )

        # 1. Create task from message
        task = session.begin_task(objective=message)

        # 2. Plan
        task.status = TaskStatus.PLANNING
        try:
            planner = self._create_agent(AgentRole.PLANNER)
            plan_steps = await planner.plan(session)
            task.plan = plan_steps
            task.status = TaskStatus.EXECUTING
        except Exception as e:
            logger.exception("Planning failed: %s", e)
            return await self._fail_session(
                session,
                task=task,
                error_code="planning_failed",
                event_message=str(e),
                failure_reason=f"Planning failed: {e}",
            )

        logger.info(
            "Plan created with %d steps for task %s",
            len(plan_steps),
            task.task_id,
        )

        # 3. Execute each step — failures are logged and skipped
        results: list[dict[str, Any]] = []
        # 1-based plan step number of each entry in ``results``; kept so the
        # assembled response labels match ``steps_failed`` and the SSE events.
        completed_steps: list[int] = []
        failed_steps: list[int] = []

        for i, step in enumerate(task.plan):
            if i >= settings.max_execution_steps:
                # Steps beyond the limit are left untouched.
                logger.warning("Max execution steps reached")
                break

            step_number = i + 1
            task.current_step_index = i
            step.status = TaskStatus.EXECUTING
            step.started_at = datetime.now(timezone.utc)

            role = route_step(step)
            agent = self._create_agent(role)

            await self.sse.emit(
                EventType.SUBAGENT_ASSIGNED,
                {
                    "step": step_number,
                    "total_steps": len(task.plan),
                    "agent": role.value,
                    "description": step.description,
                },
                session_id=session.session_id,
            )

            try:
                step_result = await agent.execute(
                    session=session,
                    max_steps=settings.subagent_max_steps,
                )

                step.result_summary = (step_result.get("content", ""))[:500]
                step.tools_used = [
                    te.tool_name for te in step_result.get("tool_executions", [])
                ]
                for artifact in step_result.get("artifacts", []):
                    # At most five facts per artifact are carried to the task.
                    task.facts_extracted.extend(artifact.facts[:5])

                step.completed_at = datetime.now(timezone.utc)
                step.status = TaskStatus.COMPLETED

                # Record the result only once bookkeeping succeeded, so a step
                # is never counted as both completed and failed.
                results.append(step_result)
                completed_steps.append(step_number)
            except Exception as e:
                logger.exception("Step %d failed: %s", step_number, e)
                step.status = TaskStatus.FAILED
                step.completed_at = datetime.now(timezone.utc)
                step.result_summary = f"Error: {e}"
                failed_steps.append(step_number)

                await self.sse.emit(
                    EventType.ERROR,
                    {"error": "step_failed", "step": step_number, "message": str(e)},
                    session_id=session.session_id,
                )
                # Continue with next step — don't block the pipeline

        # 4. Review (if plan had more than 1 step and at least one succeeded)
        review_result: dict[str, Any] = {}
        if len(task.plan) > 1 and results:
            task.status = TaskStatus.REVIEWING
            try:
                reviewer = self._create_agent(AgentRole.REVIEWER)
                review_result = await reviewer.execute(
                    session=session,
                    max_steps=2,
                )
            except Exception as e:
                # A failed review never fails the task; it is reported inline.
                logger.exception("Review failed: %s", e)
                review_result = {"content": f"Review skipped due to error: {e}"}

        # 5. Complete the task on the session
        session.complete_task()

        final_content = self._assemble_response(results, review_result, completed_steps)
        status = "completed" if not failed_steps else "partial"

        await self.sse.emit(
            EventType.EXECUTION_COMPLETED,
            {
                "session_id": session.session_id,
                "task_id": task.task_id,
                "steps_completed": len(results),
                "steps_failed": failed_steps,
                "status": status,
            },
            session_id=session.session_id,
        )

        return {
            "session_id": session.session_id,
            "task_id": task.task_id,
            "content": final_content,
            "steps_completed": len(results),
            "steps_failed": failed_steps,
            "artifacts_count": sum(
                len(r.get("artifacts", [])) for r in results
            ),
            "review": review_result.get("content", ""),
            "status": status,
        }

    async def _fail_session(
        self,
        session: SessionState,
        *,
        task: Any,
        error_code: str,
        event_message: str,
        failure_reason: str,
    ) -> dict[str, Any]:
        """Mark the task failed, flag the session, emit ERROR, build the result.

        Args:
            session: Session whose status is set to ``SessionStatus.ERROR``.
            task: Task to mark failed, or ``None`` when there is none.
            error_code: Value of the ``"error"`` key in the emitted event.
            event_message: Value of the ``"message"`` key in the emitted event.
            failure_reason: Text passed to ``mark_failed`` and ``_error_result``.

        Returns:
            The standardized error result for ``failure_reason``.
        """
        if task is not None:
            task.mark_failed(failure_reason)
        # Status is updated before emitting so the session state is already
        # consistent if the emitter itself raises.
        session.status = SessionStatus.ERROR
        await self.sse.emit(
            EventType.ERROR,
            {"error": error_code, "message": event_message},
            session_id=session.session_id,
        )
        return self._error_result(session, failure_reason)

    def _error_result(self, session: SessionState, error: str) -> dict[str, Any]:
        """Build a standardized error response.

        Args:
            session: Session the error belongs to.
            error: Human-readable error text, embedded as ``"Error: <error>"``.

        Returns:
            A dict with the same keys as a successful pipeline result, zeroed
            counters, ``status == "error"`` and ``task_id == "none"`` when the
            session has no current task.
        """
        task_id = session.current_task.task_id if session.current_task else "none"
        return {
            "session_id": session.session_id,
            "task_id": task_id,
            "content": f"Error: {error}",
            "steps_completed": 0,
            "steps_failed": [],
            "artifacts_count": 0,
            "review": "",
            "status": "error",
        }

    # ------------------------------------------------------------------
    # Internals
    # ------------------------------------------------------------------

    def _create_agent(self, role: AgentRole) -> PlannerAgent | CoderAgent | CollectorAgent | ReviewerAgent:
        """Instantiate a subagent for the given role.

        Args:
            role: Role to build an agent for.

        Returns:
            A new agent of the class mapped to ``role``, wired with this
            engine's collaborators and the role's pre-built profile.

        Raises:
            KeyError: If ``role`` has no profile or no agent class.
        """
        profile = self._profiles[role]
        agent_cls = {
            AgentRole.PLANNER: PlannerAgent,
            AgentRole.CODER: CoderAgent,
            AgentRole.COLLECTOR: CollectorAgent,
            AgentRole.REVIEWER: ReviewerAgent,
        }[role]

        return agent_cls(
            profile=profile,
            model_adapter=self.model,
            context_engine=self.context,
            mcp_client=self.mcp,
            memory_store=self.memory,
            sse_emitter=self.sse,
        )

    @staticmethod
    def _assemble_response(
        results: list[dict[str, Any]],
        review_result: dict[str, Any],
        step_numbers: list[int] | None = None,
    ) -> str:
        """Combine step results into a coherent final response.

        Args:
            results: Result dicts of the steps to include, in order.
            review_result: Review result dict; its ``"content"`` (if truthy)
                is appended as a final "Review" section.
            step_numbers: Optional 1-based plan step number for each entry of
                ``results``. When omitted, or when its length does not match
                ``results``, sections are numbered by position.

        Returns:
            Markdown sections joined by blank lines, or ``"Task completed."``
            when there is nothing to show.
        """
        if step_numbers is None or len(step_numbers) != len(results):
            # Positional numbering is the backward-compatible default.
            step_numbers = list(range(1, len(results) + 1))

        parts: list[str] = []
        for number, result in zip(step_numbers, results):
            content = result.get("content", "").strip()
            if content:
                parts.append(f"### Step {number}\n{content}")

        if review_result.get("content"):
            parts.append(f"### Review\n{review_result['content']}")

        return "\n\n".join(parts) if parts else "Task completed."
|
||||
Reference in New Issue
Block a user