"""Orchestrator Engine — single-agent execution.

Flow: message → selected agent (with tools) → response

The agent is determined by the session's agent_id via AgentRegistry.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
import re
|
|
from typing import Any
|
|
|
|
from ..adapters.base import ModelAdapter
|
|
from ..config import settings
|
|
from ..context.engine import ContextEngine
|
|
from ..context.compactor import estimate_tokens
|
|
from ..mcp.manager import MCPManager
|
|
from ..memory.store import MemoryStore
|
|
from ..models.agent import AgentProfile
|
|
from ..models.session import SessionState, SessionStatus, TaskStatus
|
|
from ..streaming.sse import SSEEmitter, EventType
|
|
from .agents.base import BaseAgent
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class OrchestratorEngine:
    """Coordinates execution of a session message with its selected agent.

    Wires together the model adapter, context engine, MCP tool manager,
    memory store and SSE emitter, then runs one agent per incoming message.
    """

    def __init__(
        self,
        model_adapter: ModelAdapter,
        context_engine: ContextEngine,
        mcp_client: MCPManager,
        memory_store: MemoryStore,
        sse_emitter: SSEEmitter,
        agent_profile: AgentProfile,
    ) -> None:
        # Collaborators are kept under the short aliases used throughout
        # the class (self.model, self.context, self.mcp, ...).
        self.agent_profile = agent_profile
        self.sse = sse_emitter
        self.memory = memory_store
        self.mcp = mcp_client
        self.context = context_engine
        self.model = model_adapter
|
|
|
# ------------------------------------------------------------------
|
|
# Public
|
|
# ------------------------------------------------------------------
|
|
|
|
async def process_message(
|
|
self,
|
|
session: SessionState,
|
|
message: str,
|
|
) -> dict[str, Any]:
|
|
"""Process a user message. Single agent execution with timeout."""
|
|
try:
|
|
return await asyncio.wait_for(
|
|
self._run(session, message),
|
|
timeout=settings.max_execution_timeout_seconds,
|
|
)
|
|
except asyncio.TimeoutError:
|
|
logger.error("Execution timed out for session %s", session.session_id)
|
|
if session.current_task:
|
|
session.current_task.mark_failed("Execution timed out")
|
|
session.status = SessionStatus.ERROR
|
|
await self.sse.emit(
|
|
EventType.ERROR,
|
|
{"error": "execution_timeout", "message": "Task exceeded maximum execution time"},
|
|
session_id=session.session_id,
|
|
)
|
|
return self._error_result(session, "Execution timed out")
|
|
except Exception as e:
|
|
logger.exception("Unhandled error for session %s", session.session_id)
|
|
if session.current_task:
|
|
session.current_task.mark_failed(str(e))
|
|
session.status = SessionStatus.ERROR
|
|
await self.sse.emit(
|
|
EventType.ERROR,
|
|
{"error": "pipeline_error", "message": str(e)},
|
|
session_id=session.session_id,
|
|
)
|
|
return self._error_result(session, str(e))
|
|
|
|
    async def _run(
        self,
        session: SessionState,
        message: str,
    ) -> dict[str, Any]:
        """Execute one message end-to-end: message → agent → response.

        Emits EXECUTION_STARTED, runs the session's agent once, folds the
        result into session history (recent messages plus a trimmed task
        history), prunes artifacts from older tasks, computes token cost,
        emits EXECUTION_COMPLETED, and returns a structured result payload.
        On agent failure the task and session are marked errored and an
        error payload is returned instead.
        """

        await self.sse.emit(
            EventType.EXECUTION_STARTED,
            {
                "session_id": session.session_id,
                "agent_id": session.agent_id,
                # Truncated preview only; the full message still reaches the agent.
                "message": message[:200],
            },
            session_id=session.session_id,
        )

        # Create a task for this message and mark it as executing.
        task = session.begin_task(objective=message)
        task.status = TaskStatus.EXECUTING

        # A fresh agent instance is built per message from the session's profile.
        agent = BaseAgent(
            profile=self.agent_profile,
            model_adapter=self.model,
            context_engine=self.context,
            mcp_client=self.mcp,
            memory_store=self.memory,
            sse_emitter=self.sse,
        )

        try:
            result = await agent.execute(
                session=session,
                max_steps=settings.subagent_max_steps,
            )
        except Exception as e:
            # Agent-level failure: record it, flag the session, notify the stream.
            logger.error("Execution failed: %s", e)
            task.mark_failed(str(e))
            session.status = SessionStatus.ERROR
            await self.sse.emit(
                EventType.ERROR,
                {"error": "execution_failed", "message": str(e)},
                session_id=session.session_id,
            )
            return self._error_result(session, str(e))

        # Fold the agent result into session history.
        content = result.get("content", "")
        usage = result.get("usage", {"input_tokens": 0, "output_tokens": 0})
        key_data = self._extract_key_data_from_results([result])
        session.recent_messages = self._append_recent_messages(
            session.recent_messages,
            message=message,
            conversation=result.get("conversation", []),
        )

        # Append a compact summary entry, then re-trim to the configured budgets.
        session.task_history.append(
            self._build_task_history_entry(
                task_id=task.task_id,
                message=message,
                content=content,
                agent_id=session.agent_id,
                facts=task.facts_extracted,
                key_data=key_data,
                tool_executions=result.get("tool_executions", []),
                artifacts_count=len(result.get("artifacts", [])),
            )
        )
        session.task_history = self._trim_task_history(session.task_history)

        # Prune artifacts not belonging to the two most recent history entries.
        # NOTE(review): this reaches into MemoryStore private internals
        # (`_prefix`, `_r`) — presumably a Redis hash per session; consider
        # exposing a public delete-artifact API instead. TODO confirm.
        artifacts = await self.memory.list_artifacts(session.session_id)
        recent_task_ids = {t["task_id"] for t in session.task_history[-2:]}
        for artifact in artifacts:
            if artifact.task_id not in recent_task_ids:
                key = f"{self.memory._prefix}:session:{session.session_id}:artifacts"
                await self.memory._r.hdel(key, artifact.artifact_id)

        # Mark the task completed on both the task and the session.
        task.status = TaskStatus.COMPLETED
        session.complete_task()

        # Token cost: settings prices are per 1M tokens (input/output priced separately).
        total_input = usage.get("input_tokens", 0)
        total_output = usage.get("output_tokens", 0)
        cost_usd = (
            (total_input / 1_000_000) * settings.cost_per_1m_input
            + (total_output / 1_000_000) * settings.cost_per_1m_output
        )

        await self.sse.emit(
            EventType.EXECUTION_COMPLETED,
            {
                "session_id": session.session_id,
                "task_id": task.task_id,
                "agent_id": session.agent_id,
                # Single-agent flow: always exactly one step, no partial failures.
                "steps_completed": 1,
                "steps_failed": [],
                "status": "completed",
                "usage": usage,
                "total_cost_usd": round(cost_usd, 6),
            },
            session_id=session.session_id,
        )

        logger.info(
            "Task %s completed (agent=%s, %d tools, %d artifacts, %d input tokens)",
            task.task_id,
            session.agent_id,
            len(result.get("tool_executions", [])),
            len(result.get("artifacts", [])),
            total_input,
        )

        return {
            "session_id": session.session_id,
            "task_id": task.task_id,
            "agent_id": session.agent_id,
            "content": content or "Task completed.",
            "steps_completed": 1,
            "steps_failed": [],
            "artifacts_count": len(result.get("artifacts", [])),
            "review": "",
            "status": "completed",
            "usage": usage,
            "total_cost_usd": round(cost_usd, 6),
        }
|
|
|
|
def _error_result(self, session: SessionState, error: str) -> dict[str, Any]:
|
|
task_id = session.current_task.task_id if session.current_task else "none"
|
|
return {
|
|
"session_id": session.session_id,
|
|
"task_id": task_id,
|
|
"content": f"Error: {error}",
|
|
"steps_completed": 0,
|
|
"steps_failed": [],
|
|
"artifacts_count": 0,
|
|
"review": "",
|
|
"status": "error",
|
|
}
|
|
|
|
@staticmethod
|
|
def _append_recent_messages(
|
|
existing: list[dict[str, Any]],
|
|
message: str,
|
|
conversation: list[dict[str, Any]],
|
|
) -> list[dict[str, Any]]:
|
|
merged = [OrchestratorEngine._sanitize_recent_message(m) for m in existing]
|
|
merged = [m for m in merged if m]
|
|
|
|
current_turn: list[dict[str, Any]] = []
|
|
if message.strip():
|
|
current_turn.append({"role": "user", "content": message})
|
|
|
|
for message_obj in conversation:
|
|
sanitized = OrchestratorEngine._sanitize_recent_message(message_obj)
|
|
if sanitized:
|
|
current_turn.append(sanitized)
|
|
|
|
merged.extend(current_turn)
|
|
return merged
|
|
|
|
@staticmethod
|
|
def _sanitize_recent_message(message: dict[str, Any]) -> dict[str, Any]:
|
|
role = str(message.get("role", "")).strip()
|
|
if role not in {"user", "assistant", "tool"}:
|
|
return {}
|
|
|
|
sanitized: dict[str, Any] = {"role": role}
|
|
content = message.get("content", "")
|
|
if isinstance(content, str) and content:
|
|
sanitized["content"] = content
|
|
|
|
if role == "assistant":
|
|
tool_calls = message.get("tool_calls")
|
|
if isinstance(tool_calls, list) and tool_calls:
|
|
sanitized["tool_calls"] = tool_calls
|
|
|
|
if role == "tool":
|
|
tool_call_id = str(message.get("tool_call_id", "")).strip()
|
|
if tool_call_id:
|
|
sanitized["tool_call_id"] = tool_call_id
|
|
|
|
if "content" not in sanitized and "tool_calls" not in sanitized:
|
|
return {}
|
|
return sanitized
|
|
|
|
@staticmethod
|
|
def _extract_key_data_from_results(results: list[dict[str, Any]]) -> dict[str, Any]:
|
|
"""Extract structured data from tool executions for task history."""
|
|
key_data: dict[str, Any] = {}
|
|
seen_tables: dict[str, list[int]] = {}
|
|
seen_sections: list[str] = []
|
|
seen_modules: list[str] = []
|
|
|
|
for result in results:
|
|
for te in result.get("tool_executions", []):
|
|
args = te.arguments
|
|
table = args.get("tableName", "")
|
|
record = args.get("recordNum")
|
|
if table and record:
|
|
record_int = int(record) if str(record).isdigit() else None
|
|
if record_int:
|
|
seen_tables.setdefault(table, [])
|
|
if record_int not in seen_tables[table]:
|
|
seen_tables[table].append(record_int)
|
|
section = args.get("sectionId", "")
|
|
if section and section not in seen_sections:
|
|
seen_sections.append(section)
|
|
module = args.get("moduleId", "") or args.get("moduleName", "")
|
|
if module and module not in seen_modules:
|
|
seen_modules.append(module)
|
|
|
|
if seen_tables:
|
|
key_data["tables"] = {t: nums[:10] for t, nums in seen_tables.items()}
|
|
if seen_sections:
|
|
key_data["sections"] = seen_sections[:20]
|
|
if seen_modules:
|
|
key_data["modules"] = seen_modules[:20]
|
|
return key_data
|
|
|
|
@staticmethod
|
|
def _build_task_history_entry(
|
|
task_id: str,
|
|
message: str,
|
|
content: str,
|
|
agent_id: str,
|
|
facts: list[str],
|
|
key_data: dict[str, Any],
|
|
tool_executions: list[Any],
|
|
artifacts_count: int,
|
|
) -> dict[str, Any]:
|
|
message_summary = " ".join(message.strip().split())[:120]
|
|
content_summary = " ".join(content.strip().split())[:160]
|
|
if content_summary:
|
|
summary = f"User: {message_summary} → Agent: {content_summary}"
|
|
else:
|
|
summary = f"User: {message_summary}"
|
|
|
|
outcomes = OrchestratorEngine._extract_outcomes(content)
|
|
focus_refs = OrchestratorEngine._extract_focus_refs(
|
|
message=message,
|
|
content=content,
|
|
key_data=key_data,
|
|
outcomes=outcomes,
|
|
)
|
|
tools_used: list[str] = []
|
|
for tool_exec in tool_executions:
|
|
tool_name = getattr(tool_exec, "tool_name", "")
|
|
if tool_name and tool_name not in tools_used:
|
|
tools_used.append(tool_name)
|
|
|
|
return {
|
|
"task_id": task_id,
|
|
"objective": message[:200],
|
|
"agent_id": agent_id,
|
|
"status": "completed",
|
|
"steps": 1,
|
|
"facts": facts[-5:],
|
|
"key_data": key_data,
|
|
"tools_used": tools_used[:8],
|
|
"artifacts_count": artifacts_count,
|
|
"summary": summary,
|
|
"outcomes": outcomes,
|
|
"focus_refs": focus_refs,
|
|
"review": "",
|
|
}
|
|
|
|
@staticmethod
|
|
def _trim_task_history(history: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
|
if not history:
|
|
return []
|
|
|
|
trimmed = history[-settings.task_history_max_entries:]
|
|
kept: list[dict[str, Any]] = []
|
|
total_tokens = 0
|
|
|
|
for entry in reversed(trimmed):
|
|
entry_tokens = OrchestratorEngine._estimate_task_history_entry_tokens(entry)
|
|
if kept and total_tokens + entry_tokens > settings.task_history_max_tokens:
|
|
break
|
|
kept.append(entry)
|
|
total_tokens += entry_tokens
|
|
|
|
return list(reversed(kept))
|
|
|
|
@staticmethod
|
|
def _estimate_task_history_entry_tokens(entry: dict[str, Any]) -> int:
|
|
parts = [
|
|
entry.get("objective", ""),
|
|
entry.get("summary", ""),
|
|
" ".join(entry.get("facts", [])[:5]),
|
|
" ".join(entry.get("tools_used", [])[:5]),
|
|
str(entry.get("key_data", {})),
|
|
" ".join(entry.get("outcomes", [])[:3]),
|
|
str(entry.get("focus_refs", [])[:3]),
|
|
]
|
|
return estimate_tokens("\n".join(p for p in parts if p))
|
|
|
|
@staticmethod
|
|
def _extract_outcomes(content: str) -> list[str]:
|
|
if not content:
|
|
return []
|
|
|
|
normalized_lines = []
|
|
for raw_line in content.splitlines():
|
|
line = raw_line.strip()
|
|
if not line:
|
|
continue
|
|
line = re.sub(r"^[#>\-\*\d\.\)\s]+", "", line).strip()
|
|
if not line:
|
|
continue
|
|
normalized_lines.append(line)
|
|
|
|
keywords = (
|
|
"si tuviera que elegir",
|
|
"más flojo",
|
|
"mas flojo",
|
|
"más problem",
|
|
"mas problem",
|
|
"recomiendo",
|
|
"recomendación",
|
|
"recomendacion",
|
|
"prioridad",
|
|
"conclus",
|
|
"debería",
|
|
"deberia",
|
|
"peor",
|
|
"más débil",
|
|
"mas debil",
|
|
)
|
|
|
|
outcomes: list[str] = []
|
|
seen: set[str] = set()
|
|
for line in normalized_lines:
|
|
lower = line.lower()
|
|
if any(k in lower for k in keywords):
|
|
trimmed = line[:220]
|
|
if trimmed not in seen:
|
|
seen.add(trimmed)
|
|
outcomes.append(trimmed)
|
|
if len(outcomes) >= 3:
|
|
return outcomes
|
|
|
|
for line in normalized_lines:
|
|
if len(line) < 20:
|
|
continue
|
|
trimmed = line[:180]
|
|
if trimmed not in seen:
|
|
seen.add(trimmed)
|
|
outcomes.append(trimmed)
|
|
if len(outcomes) >= 2:
|
|
break
|
|
return outcomes[:3]
|
|
|
|
@staticmethod
|
|
def _extract_focus_refs(
|
|
message: str,
|
|
content: str,
|
|
key_data: dict[str, Any],
|
|
outcomes: list[str],
|
|
) -> list[dict[str, str]]:
|
|
refs: list[dict[str, str]] = []
|
|
seen: set[tuple[str, str, str]] = set()
|
|
|
|
def add_ref(ref_type: str, label: str, ref_id: str = "", role: str = "related") -> None:
|
|
label = label.strip()
|
|
ref_id = ref_id.strip()
|
|
if not label and not ref_id:
|
|
return
|
|
key = (ref_type, label, ref_id)
|
|
if key in seen:
|
|
return
|
|
seen.add(key)
|
|
refs.append({
|
|
"type": ref_type,
|
|
"label": label or ref_id,
|
|
"id": ref_id,
|
|
"role": role,
|
|
})
|
|
|
|
for table, nums in key_data.get("tables", {}).items():
|
|
add_ref("table", table, table, "related")
|
|
for num in nums[:3]:
|
|
add_ref("record", f"{table} record {num}", f"{table}:{num}", "related")
|
|
|
|
for section in key_data.get("sections", [])[:5]:
|
|
add_ref("section", section, section, "related")
|
|
|
|
for module in key_data.get("modules", [])[:5]:
|
|
add_ref("module", module, module, "related")
|
|
|
|
source_text = "\n".join(outcomes + [content[:1200]])
|
|
for line in outcomes:
|
|
for match in re.findall(r"\*\*([^*]{2,80})\*\*", line):
|
|
add_ref(
|
|
OrchestratorEngine._infer_ref_type(match, line, message),
|
|
match,
|
|
"",
|
|
"primary_focus",
|
|
)
|
|
|
|
if not any(ref["role"] == "primary_focus" for ref in refs):
|
|
for pattern in (
|
|
r"(?:elegir(?:\s+\*\*uno\*\*)?,?\s+dir[ií]a que\s+\*\*([^*]{2,80})\*\*)",
|
|
r"(?:el [^.\n]{0,40}m[aá]s flojo(?:[^.\n]{0,40})es\s+\*\*([^*]{2,80})\*\*)",
|
|
):
|
|
match = re.search(pattern, source_text, flags=re.IGNORECASE)
|
|
if match:
|
|
label = match.group(1).strip()
|
|
add_ref(
|
|
OrchestratorEngine._infer_ref_type(label, source_text, message),
|
|
label,
|
|
"",
|
|
"primary_focus",
|
|
)
|
|
break
|
|
|
|
return refs[:8]
|
|
|
|
@staticmethod
|
|
def _infer_ref_type(label: str, context: str, message: str) -> str:
|
|
text = f"{label} {context} {message}".lower()
|
|
if any(k in text for k in ("módulo", "modulo")):
|
|
return "module"
|
|
if any(k in text for k in ("página", "pagina", "apartado")):
|
|
return "page"
|
|
if "tabla" in text:
|
|
return "table"
|
|
if any(k in text for k in ("archivo", "file", ".tpl", ".php", ".js", ".css")):
|
|
return "file"
|
|
if any(k in text for k in ("sección", "seccion", "section")):
|
|
return "section"
|
|
return "entity"
|