Files
agenticSystem/src/orchestrator/engine.py
Jordan Diaz 237dc00379 nah
2026-04-09 20:46:03 +00:00

520 lines
18 KiB
Python

"""Orchestrator Engine — single-agent execution.
Flow: message → selected agent (with tools) → response
The agent is determined by the session's agent_id via AgentRegistry.
"""
from __future__ import annotations
import asyncio
import logging
import re
from typing import Any
from ..adapters.base import ModelAdapter
from ..config import settings
from ..context.engine import ContextEngine
from ..context.compactor import estimate_tokens
from ..mcp.manager import MCPManager
from ..memory.store import MemoryStore
from ..models.agent import AgentProfile
from ..models.session import SessionState, SessionStatus, TaskStatus
from ..streaming.sse import SSEEmitter, EventType
from .agents.base import BaseAgent
logger = logging.getLogger(__name__)
class OrchestratorEngine:
    """Runs each session message through the session's selected agent."""

    def __init__(
        self,
        model_adapter: ModelAdapter,
        context_engine: ContextEngine,
        mcp_client: MCPManager,
        memory_store: MemoryStore,
        sse_emitter: SSEEmitter,
        agent_profile: AgentProfile,
    ) -> None:
        """Wire up injected collaborators.

        Every dependency is supplied by the caller, so the engine stays
        free of construction and configuration concerns.
        """
        self.agent_profile = agent_profile
        self.model = model_adapter
        self.context = context_engine
        self.mcp = mcp_client
        self.memory = memory_store
        self.sse = sse_emitter
# ------------------------------------------------------------------
# Public
# ------------------------------------------------------------------
async def process_message(
    self,
    session: SessionState,
    message: str,
) -> dict[str, Any]:
    """Process a user message. Single agent execution with timeout."""

    async def _fail(reason: str, code: str, detail: str) -> dict[str, Any]:
        # Shared failure path: mark the task failed, flag the session,
        # notify the SSE stream, and build the error-shaped result.
        if session.current_task:
            session.current_task.mark_failed(reason)
        session.status = SessionStatus.ERROR
        await self.sse.emit(
            EventType.ERROR,
            {"error": code, "message": detail},
            session_id=session.session_id,
        )
        return self._error_result(session, reason)

    try:
        return await asyncio.wait_for(
            self._run(session, message),
            timeout=settings.max_execution_timeout_seconds,
        )
    except asyncio.TimeoutError:
        logger.error("Execution timed out for session %s", session.session_id)
        return await _fail(
            "Execution timed out",
            "execution_timeout",
            "Task exceeded maximum execution time",
        )
    except Exception as exc:
        logger.exception("Unhandled error for session %s", session.session_id)
        return await _fail(str(exc), "pipeline_error", str(exc))
async def _run(
    self,
    session: SessionState,
    message: str,
) -> dict[str, Any]:
    """Execute: message → agent → response.

    Runs the session's configured agent once against the incoming
    message, folds the result into the session's recent messages and
    task history, prunes stale artifacts, and emits lifecycle SSE
    events along the way.

    Args:
        session: Mutable session state; task and history fields are
            updated in place.
        message: Raw user message (truncated to 200 chars in events).

    Returns:
        The result payload also mirrored on EXECUTION_COMPLETED, or
        the ``_error_result`` shape if the agent raises.
    """
    await self.sse.emit(
        EventType.EXECUTION_STARTED,
        {
            "session_id": session.session_id,
            "agent_id": session.agent_id,
            "message": message[:200],
        },
        session_id=session.session_id,
    )
    # Create task
    task = session.begin_task(objective=message)
    task.status = TaskStatus.EXECUTING
    # Execute with the selected agent (a fresh BaseAgent per message).
    agent = BaseAgent(
        profile=self.agent_profile,
        model_adapter=self.model,
        context_engine=self.context,
        mcp_client=self.mcp,
        memory_store=self.memory,
        sse_emitter=self.sse,
    )
    try:
        result = await agent.execute(
            session=session,
            max_steps=settings.subagent_max_steps,
        )
    except Exception as e:
        # Agent failure: mark the task failed, flag the session, and
        # surface the error both over SSE and in the returned payload.
        logger.error("Execution failed: %s", e)
        task.mark_failed(str(e))
        session.status = SessionStatus.ERROR
        await self.sse.emit(
            EventType.ERROR,
            {"error": "execution_failed", "message": str(e)},
            session_id=session.session_id,
        )
        return self._error_result(session, str(e))
    # Compact to history
    content = result.get("content", "")
    usage = result.get("usage", {"input_tokens": 0, "output_tokens": 0})
    key_data = self._extract_key_data_from_results([result])
    session.recent_messages = self._append_recent_messages(
        session.recent_messages,
        message=message,
        conversation=result.get("conversation", []),
    )
    session.task_history.append(
        self._build_task_history_entry(
            task_id=task.task_id,
            message=message,
            content=content,
            agent_id=session.agent_id,
            facts=task.facts_extracted,
            key_data=key_data,
            tool_executions=result.get("tool_executions", []),
            artifacts_count=len(result.get("artifacts", [])),
        )
    )
    session.task_history = self._trim_task_history(session.task_history)
    # Clean old artifacts: keep only those belonging to the two most
    # recent history entries.
    # NOTE(review): this reaches into MemoryStore internals (`_prefix`,
    # `_r`) to delete hash fields directly — presumably a Redis client;
    # consider exposing a delete_artifact() API on MemoryStore instead.
    artifacts = await self.memory.list_artifacts(session.session_id)
    recent_task_ids = {t["task_id"] for t in session.task_history[-2:]}
    for artifact in artifacts:
        if artifact.task_id not in recent_task_ids:
            key = f"{self.memory._prefix}:session:{session.session_id}:artifacts"
            await self.memory._r.hdel(key, artifact.artifact_id)
    # Complete
    task.status = TaskStatus.COMPLETED
    session.complete_task()
    # Calculate cost from the configured per-million-token prices.
    total_input = usage.get("input_tokens", 0)
    total_output = usage.get("output_tokens", 0)
    cost_usd = (
        (total_input / 1_000_000) * settings.cost_per_1m_input
        + (total_output / 1_000_000) * settings.cost_per_1m_output
    )
    await self.sse.emit(
        EventType.EXECUTION_COMPLETED,
        {
            "session_id": session.session_id,
            "task_id": task.task_id,
            "agent_id": session.agent_id,
            # Single-agent flow: always exactly one "step".
            "steps_completed": 1,
            "steps_failed": [],
            "status": "completed",
            "usage": usage,
            "total_cost_usd": round(cost_usd, 6),
        },
        session_id=session.session_id,
    )
    logger.info(
        "Task %s completed (agent=%s, %d tools, %d artifacts, %d input tokens)",
        task.task_id,
        session.agent_id,
        len(result.get("tool_executions", [])),
        len(result.get("artifacts", [])),
        total_input,
    )
    return {
        "session_id": session.session_id,
        "task_id": task.task_id,
        "agent_id": session.agent_id,
        "content": content or "Task completed.",
        "steps_completed": 1,
        "steps_failed": [],
        "artifacts_count": len(result.get("artifacts", [])),
        "review": "",
        "status": "completed",
        "usage": usage,
        "total_cost_usd": round(cost_usd, 6),
    }
def _error_result(self, session: SessionState, error: str) -> dict[str, Any]:
task_id = session.current_task.task_id if session.current_task else "none"
return {
"session_id": session.session_id,
"task_id": task_id,
"content": f"Error: {error}",
"steps_completed": 0,
"steps_failed": [],
"artifacts_count": 0,
"review": "",
"status": "error",
}
@staticmethod
def _append_recent_messages(
existing: list[dict[str, Any]],
message: str,
conversation: list[dict[str, Any]],
) -> list[dict[str, Any]]:
merged = [OrchestratorEngine._sanitize_recent_message(m) for m in existing]
merged = [m for m in merged if m]
current_turn: list[dict[str, Any]] = []
if message.strip():
current_turn.append({"role": "user", "content": message})
for message_obj in conversation:
sanitized = OrchestratorEngine._sanitize_recent_message(message_obj)
if sanitized:
current_turn.append(sanitized)
merged.extend(current_turn)
return merged
@staticmethod
def _sanitize_recent_message(message: dict[str, Any]) -> dict[str, Any]:
role = str(message.get("role", "")).strip()
if role not in {"user", "assistant", "tool"}:
return {}
sanitized: dict[str, Any] = {"role": role}
content = message.get("content", "")
if isinstance(content, str) and content:
sanitized["content"] = content
if role == "assistant":
tool_calls = message.get("tool_calls")
if isinstance(tool_calls, list) and tool_calls:
sanitized["tool_calls"] = tool_calls
if role == "tool":
tool_call_id = str(message.get("tool_call_id", "")).strip()
if tool_call_id:
sanitized["tool_call_id"] = tool_call_id
if "content" not in sanitized and "tool_calls" not in sanitized:
return {}
return sanitized
@staticmethod
def _extract_key_data_from_results(results: list[dict[str, Any]]) -> dict[str, Any]:
"""Extract structured data from tool executions for task history."""
key_data: dict[str, Any] = {}
seen_tables: dict[str, list[int]] = {}
seen_sections: list[str] = []
seen_modules: list[str] = []
for result in results:
for te in result.get("tool_executions", []):
args = te.arguments
table = args.get("tableName", "")
record = args.get("recordNum")
if table and record:
record_int = int(record) if str(record).isdigit() else None
if record_int:
seen_tables.setdefault(table, [])
if record_int not in seen_tables[table]:
seen_tables[table].append(record_int)
section = args.get("sectionId", "")
if section and section not in seen_sections:
seen_sections.append(section)
module = args.get("moduleId", "") or args.get("moduleName", "")
if module and module not in seen_modules:
seen_modules.append(module)
if seen_tables:
key_data["tables"] = {t: nums[:10] for t, nums in seen_tables.items()}
if seen_sections:
key_data["sections"] = seen_sections[:20]
if seen_modules:
key_data["modules"] = seen_modules[:20]
return key_data
@staticmethod
def _build_task_history_entry(
task_id: str,
message: str,
content: str,
agent_id: str,
facts: list[str],
key_data: dict[str, Any],
tool_executions: list[Any],
artifacts_count: int,
) -> dict[str, Any]:
message_summary = " ".join(message.strip().split())[:120]
content_summary = " ".join(content.strip().split())[:160]
if content_summary:
summary = f"User: {message_summary} → Agent: {content_summary}"
else:
summary = f"User: {message_summary}"
outcomes = OrchestratorEngine._extract_outcomes(content)
focus_refs = OrchestratorEngine._extract_focus_refs(
message=message,
content=content,
key_data=key_data,
outcomes=outcomes,
)
tools_used: list[str] = []
for tool_exec in tool_executions:
tool_name = getattr(tool_exec, "tool_name", "")
if tool_name and tool_name not in tools_used:
tools_used.append(tool_name)
return {
"task_id": task_id,
"objective": message[:200],
"agent_id": agent_id,
"status": "completed",
"steps": 1,
"facts": facts[-5:],
"key_data": key_data,
"tools_used": tools_used[:8],
"artifacts_count": artifacts_count,
"summary": summary,
"outcomes": outcomes,
"focus_refs": focus_refs,
"review": "",
}
@staticmethod
def _trim_task_history(history: list[dict[str, Any]]) -> list[dict[str, Any]]:
if not history:
return []
trimmed = history[-settings.task_history_max_entries:]
kept: list[dict[str, Any]] = []
total_tokens = 0
for entry in reversed(trimmed):
entry_tokens = OrchestratorEngine._estimate_task_history_entry_tokens(entry)
if kept and total_tokens + entry_tokens > settings.task_history_max_tokens:
break
kept.append(entry)
total_tokens += entry_tokens
return list(reversed(kept))
@staticmethod
def _estimate_task_history_entry_tokens(entry: dict[str, Any]) -> int:
    """Rough token cost of one history entry, via ``estimate_tokens``.

    Only the capped slices that are actually rendered into context are
    counted (5 facts, 5 tools, 3 outcomes, 3 focus refs).
    """
    facts_text = " ".join(entry.get("facts", [])[:5])
    tools_text = " ".join(entry.get("tools_used", [])[:5])
    outcomes_text = " ".join(entry.get("outcomes", [])[:3])
    pieces = (
        entry.get("objective", ""),
        entry.get("summary", ""),
        facts_text,
        tools_text,
        str(entry.get("key_data", {})),
        outcomes_text,
        str(entry.get("focus_refs", [])[:3]),
    )
    return estimate_tokens("\n".join(piece for piece in pieces if piece))
@staticmethod
def _extract_outcomes(content: str) -> list[str]:
if not content:
return []
normalized_lines = []
for raw_line in content.splitlines():
line = raw_line.strip()
if not line:
continue
line = re.sub(r"^[#>\-\*\d\.\)\s]+", "", line).strip()
if not line:
continue
normalized_lines.append(line)
keywords = (
"si tuviera que elegir",
"más flojo",
"mas flojo",
"más problem",
"mas problem",
"recomiendo",
"recomendación",
"recomendacion",
"prioridad",
"conclus",
"debería",
"deberia",
"peor",
"más débil",
"mas debil",
)
outcomes: list[str] = []
seen: set[str] = set()
for line in normalized_lines:
lower = line.lower()
if any(k in lower for k in keywords):
trimmed = line[:220]
if trimmed not in seen:
seen.add(trimmed)
outcomes.append(trimmed)
if len(outcomes) >= 3:
return outcomes
for line in normalized_lines:
if len(line) < 20:
continue
trimmed = line[:180]
if trimmed not in seen:
seen.add(trimmed)
outcomes.append(trimmed)
if len(outcomes) >= 2:
break
return outcomes[:3]
@staticmethod
def _extract_focus_refs(
message: str,
content: str,
key_data: dict[str, Any],
outcomes: list[str],
) -> list[dict[str, str]]:
refs: list[dict[str, str]] = []
seen: set[tuple[str, str, str]] = set()
def add_ref(ref_type: str, label: str, ref_id: str = "", role: str = "related") -> None:
label = label.strip()
ref_id = ref_id.strip()
if not label and not ref_id:
return
key = (ref_type, label, ref_id)
if key in seen:
return
seen.add(key)
refs.append({
"type": ref_type,
"label": label or ref_id,
"id": ref_id,
"role": role,
})
for table, nums in key_data.get("tables", {}).items():
add_ref("table", table, table, "related")
for num in nums[:3]:
add_ref("record", f"{table} record {num}", f"{table}:{num}", "related")
for section in key_data.get("sections", [])[:5]:
add_ref("section", section, section, "related")
for module in key_data.get("modules", [])[:5]:
add_ref("module", module, module, "related")
source_text = "\n".join(outcomes + [content[:1200]])
for line in outcomes:
for match in re.findall(r"\*\*([^*]{2,80})\*\*", line):
add_ref(
OrchestratorEngine._infer_ref_type(match, line, message),
match,
"",
"primary_focus",
)
if not any(ref["role"] == "primary_focus" for ref in refs):
for pattern in (
r"(?:elegir(?:\s+\*\*uno\*\*)?,?\s+dir[ií]a que\s+\*\*([^*]{2,80})\*\*)",
r"(?:el [^.\n]{0,40}m[aá]s flojo(?:[^.\n]{0,40})es\s+\*\*([^*]{2,80})\*\*)",
):
match = re.search(pattern, source_text, flags=re.IGNORECASE)
if match:
label = match.group(1).strip()
add_ref(
OrchestratorEngine._infer_ref_type(label, source_text, message),
label,
"",
"primary_focus",
)
break
return refs[:8]
@staticmethod
def _infer_ref_type(label: str, context: str, message: str) -> str:
text = f"{label} {context} {message}".lower()
if any(k in text for k in ("módulo", "modulo")):
return "module"
if any(k in text for k in ("página", "pagina", "apartado")):
return "page"
if "tabla" in text:
return "table"
if any(k in text for k in ("archivo", "file", ".tpl", ".php", ".js", ".css")):
return "file"
if any(k in text for k in ("sección", "seccion", "section")):
return "section"
return "entity"