diff --git a/src/api/routes.py b/src/api/routes.py index 7fcf5e2..e96c27f 100644 --- a/src/api/routes.py +++ b/src/api/routes.py @@ -94,6 +94,45 @@ class SessionResponse(BaseModel): _deps: dict[str, Any] = {} +# ------------------------------------------------------------------ +# Registro de ejecuciones en curso (para abort / preempt) +# ------------------------------------------------------------------ +# El envío de mensajes en modo stream arranca una tarea asyncio "detached" +# (create_task) que corre independiente de la conexión SSE del cliente. Sin una +# referencia a esa tarea era imposible cancelarla: si el usuario paraba el +# stream en el frontend, la tarea seguía viva reteniendo el session_lock, y el +# siguiente mensaje recibía "busy" mientras el stream mostraba la ejecución +# anterior. Guardamos la tarea por session_id para poder cancelarla (abort +# explícito del usuario o preempt al llegar un mensaje nuevo). +_running_executions: dict[str, "asyncio.Task[Any]"] = {} + + +async def _cancel_running_execution(session_id: str, *, reason: str) -> bool: + """Cancela la ejecución en curso de una sesión, si la hay. + + Espera a que la tarea termine de desenrollarse para garantizar que su + `finally` libere el session_lock (SETNX en Redis) antes de devolver. Así el + siguiente mensaje puede adquirir el lock de inmediato. Idempotente. + + Devuelve True si había una ejecución activa que se canceló. + """ + task = _running_executions.get(session_id) + if task is None or task.done(): + return False + logger.info("Cancelling running execution for session %s (%s)", session_id, reason) + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + except Exception as e: # noqa: BLE001 — la tarea ya está muriendo + logger.warning("Error while cancelling execution for %s: %s", session_id, e) + finally: + if _running_executions.get(session_id) is task: + _running_executions.pop(session_id, None) + return True + + def set_dependencies( storage: Any, model_adapter: Any, @@ -312,8 +351,26 @@ async def send_message( from ..mcp.manager import MCPManager orchestrator = _build_orchestrator(mcp_manager or MCPManager(), agent_profile) + # Preempt: si ya hay una ejecución en curso para esta sesión (p.ej. el + # usuario paró el stream y mandó un mensaje nuevo), la cancelamos antes de + # arrancar. _cancel_running_execution espera a que libere el session_lock, + # de modo que el create_task de abajo no choque con un "busy". + await _cancel_running_execution(session_id, reason="preempted by new message") + if body.stream: - asyncio.create_task(_execute_and_persist(orchestrator, storage, session, body.message)) + task = asyncio.create_task( + _execute_and_persist(orchestrator, storage, session, body.message) + ) + _running_executions[session_id] = task + # Auto-limpieza del registro al terminar (solo si seguimos siendo la + # tarea activa — un preempt posterior pudo reemplazarla ya). + task.add_done_callback( + lambda t, sid=session_id: ( + _running_executions.pop(sid, None) + if _running_executions.get(sid) is t + else None + ) + ) return { "session_id": session_id, "status": "executing", @@ -337,6 +394,16 @@ async def _execute_and_persist(orchestrator, storage, session, message) -> dict[ try: result = await orchestrator.process_message(session, message) return result + except asyncio.CancelledError: + # Ejecución abortada por el usuario (stop) o preemptada por un + # mensaje nuevo. Dejamos la sesión en estado consistente (NO ERROR) + # para que el siguiente mensaje arranque limpio, y re-lanzamos para + # que el `await task` de la cancelación complete. El `finally` + # persiste el estado y el `session_lock` se libera al salir. + logger.info("Execution cancelled for session %s", session.session_id) + session.status = SessionStatus.ACTIVE + session.current_task = None + raise except Exception as e: session.status = SessionStatus.ERROR logger.exception("Execution failed for session %s", session.session_id) @@ -352,6 +419,52 @@ async def _execute_and_persist(orchestrator, storage, session, message) -> dict[ logger.error("Failed to persist session state: %s", e) +# ------------------------------------------------------------------ +# POST /sessions/{id}/abort — cancela la ejecución en curso +# ------------------------------------------------------------------ + +@router.post("/sessions/{session_id}/abort") +async def abort_session(session_id: str) -> dict[str, Any]: + """Cancela la ejecución en curso de una sesión (botón Stop del chat). + + Cancela la tarea detached (liberando el session_lock), cierra el stream SSE + de los suscriptores y limpia un posible lock huérfano. Idempotente: si no + hay nada en curso devuelve `no_active_execution` sin error. + """ + storage = _get_storage() + session = await storage.get_session(session_id) + if not session: + raise HTTPException(status_code=404, detail="Session not found") + + cancelled = await _cancel_running_execution(session_id, reason="user abort") + + # Cerrar el stream para que los suscriptores SSE (native + claude) terminen + # limpio. EXECUTION_COMPLETED se traduce a un {"type":"done"} en el formato + # claude que consume el frontend. + try: + sse = _get_sse() + await sse.emit( + EventType.EXECUTION_COMPLETED, + {"session_id": session_id, "aborted": True}, + session_id=session_id, + ) + sse.cleanup_session(session_id) + except Exception as e: + logger.warning("Failed to close SSE stream on abort for %s: %s", session_id, e) + + # Defensa: liberar un lock huérfano (p.ej. de una ejecución previa que crasheó + # antes de soltarlo) para no bloquear el siguiente mensaje hasta el TTL. + try: + await storage.clear_session_lock(session_id) + except Exception as e: + logger.warning("Failed to clear session lock on abort for %s: %s", session_id, e) + + return { + "session_id": session_id, + "status": "aborted" if cancelled else "no_active_execution", + } + + # ------------------------------------------------------------------ # GET /sessions/{id}/stream # ------------------------------------------------------------------ diff --git a/src/storage/redis.py b/src/storage/redis.py index 7c9131c..ac0603c 100644 --- a/src/storage/redis.py +++ b/src/storage/redis.py @@ -149,3 +149,13 @@ class RedisStorage: finally: if acquired: await self.client.delete(key) + + async def clear_session_lock(self, session_id: str) -> None: + """Borra el lock de ejecución de una sesión de forma incondicional. + + Usado por el endpoint de abort para liberar un lock huérfano (de una + ejecución previa que crasheó antes de soltarlo) y no bloquear el + siguiente mensaje hasta que expire el TTL. + """ + key = self._key("session", session_id, "lock") + await self.client.delete(key)