"""Orchestrator Engine — single-agent execution.

Flow: message → selected agent (with tools) → response

The agent is determined by the session's agent_id via AgentRegistry.
"""

from __future__ import annotations

import asyncio
import logging
import re
from typing import Any

from ..adapters.base import ModelAdapter
from ..config import settings
from ..context.engine import ContextEngine
from ..context.compactor import estimate_tokens
from ..mcp.manager import MCPManager
from ..memory.store import MemoryStore
from ..models.agent import AgentProfile
from ..models.session import SessionState, SessionStatus, TaskStatus
from ..streaming.sse import SSEEmitter, EventType
from .agents.base import BaseAgent

logger = logging.getLogger(__name__)


class OrchestratorEngine:
    """Drives execution for a session message with the selected agent.

    Responsibilities (per user message):
      * run the agent under a global timeout,
      * emit SSE lifecycle events (started / completed / error),
      * fold the result into session history (recent messages + task history),
      * prune old artifacts and compute token cost.
    """

    def __init__(
        self,
        model_adapter: ModelAdapter,
        context_engine: ContextEngine,
        mcp_client: MCPManager,
        memory_store: MemoryStore,
        sse_emitter: SSEEmitter,
        agent_profile: AgentProfile,
    ) -> None:
        # Collaborators are injected; this class owns none of their lifecycles.
        self.model = model_adapter
        self.context = context_engine
        self.mcp = mcp_client
        self.memory = memory_store
        self.sse = sse_emitter
        self.agent_profile = agent_profile

    # ------------------------------------------------------------------
    # Public
    # ------------------------------------------------------------------
    async def process_message(
        self,
        session: SessionState,
        message: str,
    ) -> dict[str, Any]:
        """Process a user message with the selected agent, under a timeout.

        Wraps ``_run`` in ``asyncio.wait_for`` bounded by
        ``settings.max_execution_timeout_seconds``. On timeout or any
        unhandled exception the current task is marked failed, the session
        is moved to ERROR, an SSE ERROR event is emitted, and a structured
        error result is returned (this method never raises to the caller).

        Args:
            session: Mutable session state; status/task fields are updated.
            message: Raw user message text.

        Returns:
            Result dict from ``_run`` on success, or ``_error_result`` shape
            on timeout/failure.
        """
        try:
            return await asyncio.wait_for(
                self._run(session, message),
                timeout=settings.max_execution_timeout_seconds,
            )
        except asyncio.TimeoutError:
            logger.error("Execution timed out for session %s", session.session_id)
            if session.current_task:
                session.current_task.mark_failed("Execution timed out")
            session.status = SessionStatus.ERROR
            await self.sse.emit(
                EventType.ERROR,
                {"error": "execution_timeout", "message": "Task exceeded maximum execution time"},
                session_id=session.session_id,
            )
            return self._error_result(session, "Execution timed out")
        except Exception as e:
            # Last-resort boundary: log with traceback and degrade to an
            # error result rather than propagating into the transport layer.
            logger.exception("Unhandled error for session %s", session.session_id)
            if session.current_task:
                session.current_task.mark_failed(str(e))
            session.status = SessionStatus.ERROR
            await self.sse.emit(
                EventType.ERROR,
                {"error": "pipeline_error", "message": str(e)},
                session_id=session.session_id,
            )
            return self._error_result(session, str(e))

    async def _run(
        self,
        session: SessionState,
        message: str,
    ) -> dict[str, Any]:
        """Execute: message → agent → response.

        Single-pass pipeline: announce start over SSE, run the agent once,
        then persist history, prune artifacts, compute cost, and announce
        completion. Agent failures are caught here and converted to an
        error result (timeouts are handled one level up).
        """
        await self.sse.emit(
            EventType.EXECUTION_STARTED,
            {
                "session_id": session.session_id,
                "agent_id": session.agent_id,
                # Truncated preview only; full message lives in the session.
                "message": message[:200],
            },
            session_id=session.session_id,
        )

        # Create task
        task = session.begin_task(objective=message)
        task.status = TaskStatus.EXECUTING

        # Execute with the selected agent
        agent = BaseAgent(
            profile=self.agent_profile,
            model_adapter=self.model,
            context_engine=self.context,
            mcp_client=self.mcp,
            memory_store=self.memory,
            sse_emitter=self.sse,
        )
        try:
            result = await agent.execute(
                session=session,
                max_steps=settings.subagent_max_steps,
            )
        except Exception as e:
            logger.error("Execution failed: %s", e)
            task.mark_failed(str(e))
            session.status = SessionStatus.ERROR
            await self.sse.emit(
                EventType.ERROR,
                {"error": "execution_failed", "message": str(e)},
                session_id=session.session_id,
            )
            return self._error_result(session, str(e))

        # Compact to history
        content = result.get("content", "")
        usage = result.get("usage", {"input_tokens": 0, "output_tokens": 0})
        key_data = self._extract_key_data_from_results([result])
        session.recent_messages = self._append_recent_messages(
            session.recent_messages,
            message=message,
            conversation=result.get("conversation", []),
        )
        session.task_history.append(
            self._build_task_history_entry(
                task_id=task.task_id,
                message=message,
                content=content,
                agent_id=session.agent_id,
                facts=task.facts_extracted,
                key_data=key_data,
                tool_executions=result.get("tool_executions", []),
                artifacts_count=len(result.get("artifacts", [])),
            )
        )
        session.task_history = self._trim_task_history(session.task_history)

        # Clean old artifacts: keep only artifacts belonging to the two most
        # recent task-history entries.
        # NOTE(review): reaches into MemoryStore internals (_prefix, _r.hdel)
        # and hard-codes the Redis key layout — a public delete_artifact()
        # API on MemoryStore would be safer; confirm against store module.
        artifacts = await self.memory.list_artifacts(session.session_id)
        recent_task_ids = {t["task_id"] for t in session.task_history[-2:]}
        for artifact in artifacts:
            if artifact.task_id not in recent_task_ids:
                key = f"{self.memory._prefix}:session:{session.session_id}:artifacts"
                await self.memory._r.hdel(key, artifact.artifact_id)

        # Complete
        task.status = TaskStatus.COMPLETED
        session.complete_task()

        # Calculate cost (settings rates are per 1M tokens, USD).
        total_input = usage.get("input_tokens", 0)
        total_output = usage.get("output_tokens", 0)
        cost_usd = (
            (total_input / 1_000_000) * settings.cost_per_1m_input
            + (total_output / 1_000_000) * settings.cost_per_1m_output
        )

        await self.sse.emit(
            EventType.EXECUTION_COMPLETED,
            {
                "session_id": session.session_id,
                "task_id": task.task_id,
                "agent_id": session.agent_id,
                # Single-agent flow: always exactly one step, no failures here.
                "steps_completed": 1,
                "steps_failed": [],
                "status": "completed",
                "usage": usage,
                "total_cost_usd": round(cost_usd, 6),
            },
            session_id=session.session_id,
        )

        logger.info(
            "Task %s completed (agent=%s, %d tools, %d artifacts, %d input tokens)",
            task.task_id,
            session.agent_id,
            len(result.get("tool_executions", [])),
            len(result.get("artifacts", [])),
            total_input,
        )

        return {
            "session_id": session.session_id,
            "task_id": task.task_id,
            "agent_id": session.agent_id,
            "content": content or "Task completed.",
            "steps_completed": 1,
            "steps_failed": [],
            "artifacts_count": len(result.get("artifacts", [])),
            "review": "",
            "status": "completed",
            "usage": usage,
            "total_cost_usd": round(cost_usd, 6),
        }

    def _error_result(self, session: SessionState, error: str) -> dict[str, Any]:
        """Build the error-shaped result dict mirroring the success payload."""
        task_id = session.current_task.task_id if session.current_task else "none"
        return {
            "session_id": session.session_id,
            "task_id": task_id,
            "content": f"Error: {error}",
            "steps_completed": 0,
            "steps_failed": [],
            "artifacts_count": 0,
            "review": "",
            "status": "error",
        }

    @staticmethod
    def _append_recent_messages(
        existing: list[dict[str, Any]],
        message: str,
        conversation: list[dict[str, Any]],
    ) -> list[dict[str, Any]]:
        """Append the current turn (user message + agent conversation) to history.

        Existing entries are re-sanitized so the stored history only ever
        contains well-formed messages; empty/invalid ones are dropped.
        NOTE(review): no length cap is applied here — presumably trimming
        happens elsewhere (context engine?); confirm, or this list grows
        unboundedly across turns.
        """
        merged = [OrchestratorEngine._sanitize_recent_message(m) for m in existing]
        merged = [m for m in merged if m]
        current_turn: list[dict[str, Any]] = []
        if message.strip():
            current_turn.append({"role": "user", "content": message})
        for message_obj in conversation:
            sanitized = OrchestratorEngine._sanitize_recent_message(message_obj)
            if sanitized:
                current_turn.append(sanitized)
        merged.extend(current_turn)
        return merged

    @staticmethod
    def _sanitize_recent_message(message: dict[str, Any]) -> dict[str, Any]:
        """Normalize one message dict; return {} if it should be discarded.

        Keeps only the user/assistant/tool roles and the fields each role is
        allowed to carry (content; assistant tool_calls; tool tool_call_id).
        A message with neither content nor tool_calls is dropped — note this
        also drops a tool message that has only a tool_call_id.
        """
        role = str(message.get("role", "")).strip()
        if role not in {"user", "assistant", "tool"}:
            return {}
        sanitized: dict[str, Any] = {"role": role}
        content = message.get("content", "")
        if isinstance(content, str) and content:
            sanitized["content"] = content
        if role == "assistant":
            tool_calls = message.get("tool_calls")
            if isinstance(tool_calls, list) and tool_calls:
                sanitized["tool_calls"] = tool_calls
        if role == "tool":
            tool_call_id = str(message.get("tool_call_id", "")).strip()
            if tool_call_id:
                sanitized["tool_call_id"] = tool_call_id
        if "content" not in sanitized and "tool_calls" not in sanitized:
            return {}
        return sanitized

    @staticmethod
    def _extract_key_data_from_results(results: list[dict[str, Any]]) -> dict[str, Any]:
        """Extract structured data from tool executions for task history.

        Scans tool-call arguments for table/record, section, and module
        references, deduplicating while preserving first-seen order. Output
        keys (present only when non-empty): "tables" (table → first 10 record
        numbers), "sections" (max 20), "modules" (max 20).
        """
        key_data: dict[str, Any] = {}
        seen_tables: dict[str, list[int]] = {}
        seen_sections: list[str] = []
        seen_modules: list[str] = []
        for result in results:
            for te in result.get("tool_executions", []):
                # te is a tool-execution object with an `arguments` dict.
                args = te.arguments
                table = args.get("tableName", "")
                record = args.get("recordNum")
                if table and record:
                    # Non-numeric records are ignored.
                    # NOTE(review): record number 0 is also skipped by the
                    # truthiness check below — confirm 0 is never a valid id.
                    record_int = int(record) if str(record).isdigit() else None
                    if record_int:
                        seen_tables.setdefault(table, [])
                        if record_int not in seen_tables[table]:
                            seen_tables[table].append(record_int)
                section = args.get("sectionId", "")
                if section and section not in seen_sections:
                    seen_sections.append(section)
                module = args.get("moduleId", "") or args.get("moduleName", "")
                if module and module not in seen_modules:
                    seen_modules.append(module)
        if seen_tables:
            key_data["tables"] = {t: nums[:10] for t, nums in seen_tables.items()}
        if seen_sections:
            key_data["sections"] = seen_sections[:20]
        if seen_modules:
            key_data["modules"] = seen_modules[:20]
        return key_data

    @staticmethod
    def _build_task_history_entry(
        task_id: str,
        message: str,
        content: str,
        agent_id: str,
        facts: list[str],
        key_data: dict[str, Any],
        tool_executions: list[Any],
        artifacts_count: int,
    ) -> dict[str, Any]:
        """Build one compact task-history record for the session.

        The record is deliberately small (truncated summaries, capped lists)
        since it is replayed into future prompts; see
        ``_estimate_task_history_entry_tokens`` for the budget side.
        """
        # Collapse whitespace and truncate to keep the summary single-line.
        message_summary = " ".join(message.strip().split())[:120]
        content_summary = " ".join(content.strip().split())[:160]
        if content_summary:
            summary = f"User: {message_summary} → Agent: {content_summary}"
        else:
            summary = f"User: {message_summary}"
        outcomes = OrchestratorEngine._extract_outcomes(content)
        focus_refs = OrchestratorEngine._extract_focus_refs(
            message=message,
            content=content,
            key_data=key_data,
            outcomes=outcomes,
        )
        # Deduplicated tool names in first-use order.
        tools_used: list[str] = []
        for tool_exec in tool_executions:
            tool_name = getattr(tool_exec, "tool_name", "")
            if tool_name and tool_name not in tools_used:
                tools_used.append(tool_name)
        return {
            "task_id": task_id,
            "objective": message[:200],
            "agent_id": agent_id,
            "status": "completed",
            "steps": 1,
            "facts": facts[-5:],
            "key_data": key_data,
            "tools_used": tools_used[:8],
            "artifacts_count": artifacts_count,
            "summary": summary,
            "outcomes": outcomes,
            "focus_refs": focus_refs,
            "review": "",
        }

    @staticmethod
    def _trim_task_history(history: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Trim history to the configured entry count and token budget.

        Walks newest → oldest, keeping entries until the estimated token
        total would exceed ``settings.task_history_max_tokens``. The newest
        entry is always kept even if it alone exceeds the budget (the
        ``kept and`` guard). Returns entries in original (oldest-first) order.
        """
        if not history:
            return []
        trimmed = history[-settings.task_history_max_entries:]
        kept: list[dict[str, Any]] = []
        total_tokens = 0
        for entry in reversed(trimmed):
            entry_tokens = OrchestratorEngine._estimate_task_history_entry_tokens(entry)
            if kept and total_tokens + entry_tokens > settings.task_history_max_tokens:
                break
            kept.append(entry)
            total_tokens += entry_tokens
        return list(reversed(kept))

    @staticmethod
    def _estimate_task_history_entry_tokens(entry: dict[str, Any]) -> int:
        """Rough token estimate of a history entry's prompt-relevant fields."""
        parts = [
            entry.get("objective", ""),
            entry.get("summary", ""),
            " ".join(entry.get("facts", [])[:5]),
            " ".join(entry.get("tools_used", [])[:5]),
            str(entry.get("key_data", {})),
            " ".join(entry.get("outcomes", [])[:3]),
            str(entry.get("focus_refs", [])[:3]),
        ]
        return estimate_tokens("\n".join(p for p in parts if p))

    @staticmethod
    def _extract_outcomes(content: str) -> list[str]:
        """Pick up to 3 "conclusion-like" lines from the agent's response.

        Heuristic: strip markdown/list prefixes from each line, then prefer
        lines containing Spanish recommendation/conclusion keywords; if fewer
        than needed match, fall back to the first substantial lines (>= 20
        chars). Lines are truncated and deduplicated.
        """
        if not content:
            return []
        normalized_lines = []
        for raw_line in content.splitlines():
            line = raw_line.strip()
            if not line:
                continue
            # Drop leading markdown noise: #, >, bullets, numbering.
            line = re.sub(r"^[#>\-\*\d\.\)\s]+", "", line).strip()
            if not line:
                continue
            normalized_lines.append(line)
        # Keyword list is Spanish (accented and unaccented variants) —
        # the deployment's response language, presumably.
        keywords = (
            "si tuviera que elegir",
            "más flojo",
            "mas flojo",
            "más problem",
            "mas problem",
            "recomiendo",
            "recomendación",
            "recomendacion",
            "prioridad",
            "conclus",
            "debería",
            "deberia",
            "peor",
            "más débil",
            "mas debil",
        )
        outcomes: list[str] = []
        seen: set[str] = set()
        for line in normalized_lines:
            lower = line.lower()
            if any(k in lower for k in keywords):
                trimmed = line[:220]
                if trimmed not in seen:
                    seen.add(trimmed)
                    outcomes.append(trimmed)
                if len(outcomes) >= 3:
                    return outcomes
        # Fallback: pad with the first substantial lines (up to 2 total).
        for line in normalized_lines:
            if len(line) < 20:
                continue
            trimmed = line[:180]
            if trimmed not in seen:
                seen.add(trimmed)
                outcomes.append(trimmed)
            if len(outcomes) >= 2:
                break
        return outcomes[:3]

    @staticmethod
    def _extract_focus_refs(
        message: str,
        content: str,
        key_data: dict[str, Any],
        outcomes: list[str],
    ) -> list[dict[str, str]]:
        """Collect up to 8 entity references the turn focused on.

        Two sources: (1) structured key_data (tables/records/sections/
        modules) added with role "related"; (2) **bold** spans in the outcome
        lines — and, failing that, Spanish "weakest/if I had to choose"
        phrasings in the content — added with role "primary_focus".
        Each ref is {"type", "label", "id", "role"}; duplicates are skipped.
        """
        refs: list[dict[str, str]] = []
        seen: set[tuple[str, str, str]] = set()

        def add_ref(ref_type: str, label: str, ref_id: str = "", role: str = "related") -> None:
            # Dedup on (type, label, id); label falls back to the id.
            label = label.strip()
            ref_id = ref_id.strip()
            if not label and not ref_id:
                return
            key = (ref_type, label, ref_id)
            if key in seen:
                return
            seen.add(key)
            refs.append({
                "type": ref_type,
                "label": label or ref_id,
                "id": ref_id,
                "role": role,
            })

        for table, nums in key_data.get("tables", {}).items():
            add_ref("table", table, table, "related")
            for num in nums[:3]:
                add_ref("record", f"{table} record {num}", f"{table}:{num}", "related")
        for section in key_data.get("sections", [])[:5]:
            add_ref("section", section, section, "related")
        for module in key_data.get("modules", [])[:5]:
            add_ref("module", module, module, "related")

        source_text = "\n".join(outcomes + [content[:1200]])
        # Bold spans inside outcome lines are treated as the primary focus.
        for line in outcomes:
            for match in re.findall(r"\*\*([^*]{2,80})\*\*", line):
                add_ref(
                    OrchestratorEngine._infer_ref_type(match, line, message),
                    match,
                    "",
                    "primary_focus",
                )
        if not any(ref["role"] == "primary_focus" for ref in refs):
            # Fallback patterns: Spanish "if I had to choose ... **X**" and
            # "the weakest ... is **X**" constructions.
            for pattern in (
                r"(?:elegir(?:\s+\*\*uno\*\*)?,?\s+dir[ií]a que\s+\*\*([^*]{2,80})\*\*)",
                r"(?:el [^.\n]{0,40}m[aá]s flojo(?:[^.\n]{0,40})es\s+\*\*([^*]{2,80})\*\*)",
            ):
                match = re.search(pattern, source_text, flags=re.IGNORECASE)
                if match:
                    label = match.group(1).strip()
                    add_ref(
                        OrchestratorEngine._infer_ref_type(label, source_text, message),
                        label,
                        "",
                        "primary_focus",
                    )
                    break
        return refs[:8]

    @staticmethod
    def _infer_ref_type(label: str, context: str, message: str) -> str:
        """Guess a ref type from Spanish/English keywords; "entity" if none."""
        text = f"{label} {context} {message}".lower()
        if any(k in text for k in ("módulo", "modulo")):
            return "module"
        if any(k in text for k in ("página", "pagina", "apartado")):
            return "page"
        if "tabla" in text:
            return "table"
        if any(k in text for k in ("archivo", "file", ".tpl", ".php", ".js", ".css")):
            return "file"
        if any(k in text for k in ("sección", "seccion", "section")):
            return "section"
        return "entity"