Antes, al parar el agente y mandar un mensaje nuevo, la ejecución previa
seguía viva reteniendo el session_lock: el mensaje nuevo recibía "busy" y el
stream mostraba la ejecución anterior. La tarea detached (create_task) no se
guardaba en ningún sitio y era imposible cancelarla.
- _running_executions: registro de la tarea asyncio por session_id.
- _cancel_running_execution(): cancela y espera a que libere el lock.
- send_message: preempt — un mensaje nuevo cancela la ejecución previa.
- _execute_and_persist: maneja CancelledError dejando la sesión en ACTIVE.
- POST /sessions/{id}/abort: cancela, cierra el stream SSE y limpia el lock.
- RedisStorage.clear_session_lock(): libera locks huérfanos.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
162 lines · 5.7 KiB · Python
"""Redis storage layer for session persistence.

Key structure:
    {prefix}:session:{id}            — SessionState JSON
    {prefix}:session:{id}:artifacts  — Hash of ArtifactSummary by artifact_id
    {prefix}:session:{id}:events     — List of SSE event JSONs
    {prefix}:session:{id}:lock       — Execution lock (SET NX with expiry)
    {prefix}:sessions:index          — Set of session IDs that have been created
                                       and not yet deleted (IDs are not removed
                                       when a session key expires by TTL)
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
from contextlib import asynccontextmanager
|
|
from typing import Any, AsyncIterator
|
|
|
|
import redis.asyncio as redis
|
|
|
|
from ..config import settings
|
|
from ..models.session import SessionState
|
|
|
|
# Module-level logger; RedisStorage.connect() uses it to report a successful
# connection.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class RedisStorage:
    """Async Redis storage for session state, event logs and execution locks.

    Every key is namespaced under ``settings.redis_key_prefix`` via ``_key``.
    """

    # Atomic compare-and-delete used to release an execution lock. The key
    # is deleted only if it still holds the token written by the caller that
    # acquired it. Without this check, an execution whose lock already expired
    # (or was force-cleared by ``clear_session_lock``) would delete the lock
    # now owned by a newer execution of the same session.
    _RELEASE_LOCK_SCRIPT = (
        "if redis.call('get', KEYS[1]) == ARGV[1] then "
        "return redis.call('del', KEYS[1]) "
        "else return 0 end"
    )

    # Event-log cap. Once the list grows past _EVENTS_TRIM_AT entries, it is
    # trimmed back to the newest _EVENTS_KEEP. Trimming happens in batches
    # rather than on every append; per the original comment this "avoids race
    # conditions" (presumably with readers paging through get_events).
    _EVENTS_TRIM_AT = 600
    _EVENTS_KEEP = 500

    def __init__(self, redis_client: redis.Redis | None = None) -> None:
        """Keep an optional pre-built client; otherwise ``connect`` creates one.

        Args:
            redis_client: Existing async Redis client to use. When ``None``,
                ``connect()`` builds one from ``settings.redis_url``.
        """
        self._r = redis_client
        self._prefix = settings.redis_key_prefix

    async def connect(self) -> None:
        """Create the Redis client if none was provided, then PING it.

        Raises:
            Whatever the client's ``ping()`` raises when the server is
            unreachable.
        """
        if self._r is None:
            self._r = redis.from_url(
                settings.redis_url,
                decode_responses=True,
            )
        await self._r.ping()
        # Redis URLs may embed a password; never write it to the logs.
        logger.info(
            "Connected to Redis at %s", self._redact_url(settings.redis_url)
        )

    async def disconnect(self) -> None:
        """Close the Redis client, if one exists.

        NOTE: the client reference is kept, so ``client`` still returns the
        closed client afterwards.
        """
        if self._r:
            await self._r.aclose()

    @property
    def client(self) -> redis.Redis:
        """Return the Redis client.

        Raises:
            RuntimeError: If no client was injected and ``connect()`` has not
                been called.
        """
        if self._r is None:
            raise RuntimeError("Redis not connected. Call connect() first.")
        return self._r

    @staticmethod
    def _redact_url(url: str) -> str:
        """Return *url* with any password in its userinfo replaced by ``***``.

        URLs without a password (including unix-socket URLs) are returned
        unchanged.
        """
        from urllib.parse import urlsplit, urlunsplit

        parts = urlsplit(url)
        if parts.password is None:
            return url
        userinfo, _, hostport = parts.netloc.rpartition("@")
        username = userinfo.partition(":")[0]
        return urlunsplit(parts._replace(netloc=f"{username}:***@{hostport}"))

    # ------------------------------------------------------------------
    # Key helpers
    # ------------------------------------------------------------------

    def _key(self, *parts: str) -> str:
        """Build a namespaced key: ``{prefix}:{part1}:{part2}:...``."""
        return ":".join([self._prefix, *parts])

    # ------------------------------------------------------------------
    # Session CRUD
    # ------------------------------------------------------------------

    async def create_session(self, session: SessionState) -> None:
        """Store *session* with the configured TTL and register it in the index."""
        key = self._key("session", session.session_id)
        await self.client.set(
            key,
            session.model_dump_json(),
            ex=settings.session_ttl_seconds,
        )
        await self.client.sadd(
            self._key("sessions", "index"), session.session_id
        )

    async def get_session(self, session_id: str) -> SessionState | None:
        """Load a session, or return ``None`` if it does not exist or expired."""
        key = self._key("session", session_id)
        data = await self.client.get(key)
        if data:
            return SessionState.model_validate_json(data)
        return None

    async def update_session(self, session: SessionState) -> None:
        """Overwrite *session* and reset its TTL (a sliding expiry)."""
        key = self._key("session", session.session_id)
        await self.client.set(
            key,
            session.model_dump_json(),
            ex=settings.session_ttl_seconds,
        )

    async def delete_session(self, session_id: str) -> bool:
        """Delete a session, its artifacts and events, and drop it from the index.

        The execution lock key is not touched here; it expires on its own TTL.

        Returns:
            ``True`` if the session state key existed and was deleted.
        """
        pipe = self.client.pipeline()
        pipe.delete(self._key("session", session_id))
        pipe.delete(self._key("session", session_id, "artifacts"))
        pipe.delete(self._key("session", session_id, "events"))
        pipe.srem(self._key("sessions", "index"), session_id)
        results = await pipe.execute()
        return bool(results[0])

    async def list_sessions(self) -> list[str]:
        """Return every session ID in the index.

        NOTE: IDs are removed only by ``delete_session``. Sessions whose state
        key expired via TTL may still be listed.
        """
        members = await self.client.smembers(self._key("sessions", "index"))
        return list(members)

    # ------------------------------------------------------------------
    # Event log
    # ------------------------------------------------------------------

    async def append_event(
        self, session_id: str, event: dict[str, Any]
    ) -> None:
        """Append *event* (JSON-encoded) to the session's event log.

        RPUSH and EXPIRE run in one MULTI/EXEC transaction, so the list can
        never be left without a TTL. The TTL is refreshed on every append,
        mirroring how ``update_session`` refreshes the session key; the log
        therefore does not expire while the session is still active.
        """
        key = self._key("session", session_id, "events")
        pipe = self.client.pipeline()
        pipe.rpush(key, json.dumps(event))
        pipe.expire(key, settings.session_ttl_seconds)
        length, _ = await pipe.execute()
        # Trim only when significantly over the cap (see _EVENTS_TRIM_AT).
        if length > self._EVENTS_TRIM_AT:
            await self.client.ltrim(key, -self._EVENTS_KEEP, -1)

    async def get_events(
        self, session_id: str, start: int = 0, end: int = -1
    ) -> list[dict[str, Any]]:
        """Return decoded events in the inclusive LRANGE window ``[start, end]``.

        Negative indices count from the end, so ``-1`` is the newest event.
        """
        key = self._key("session", session_id, "events")
        raw = await self.client.lrange(key, start, end)
        return [json.loads(r) for r in raw]

    # ------------------------------------------------------------------
    # Execution lock (prevents concurrent messages on same session)
    # ------------------------------------------------------------------

    @asynccontextmanager
    async def session_lock(
        self, session_id: str, timeout: int = 300
    ) -> AsyncIterator[bool]:
        """Acquire an exclusive execution lock for a session.

        The lock is ``SET NX`` with an expiry of *timeout* seconds, so a
        crashed process cannot deadlock the session. Each acquisition writes a
        unique token, and release is a compare-and-delete on that token. As a
        result, an execution that outlived its TTL, or whose lock was
        force-cleared, never deletes a lock now held by another execution.

        Usage:
            async with storage.session_lock(session_id) as acquired:
                if not acquired:
                    raise HTTPException(409, "Session busy")
                # ... execute ...

        Yields:
            ``True`` if the lock was acquired, ``False`` if already held.
        """
        import uuid

        key = self._key("session", session_id, "lock")
        token = uuid.uuid4().hex
        acquired = await self.client.set(key, token, nx=True, ex=timeout)
        try:
            yield bool(acquired)
        finally:
            if acquired:
                await self.client.eval(self._RELEASE_LOCK_SCRIPT, 1, key, token)

    async def clear_session_lock(self, session_id: str) -> None:
        """Unconditionally delete a session's execution lock.

        Used by the abort endpoint to release an orphaned lock, for example
        one left by a previous execution that crashed before releasing it, so
        the next message is not blocked until the TTL expires.
        """
        key = self._key("session", session_id, "lock")
        await self.client.delete(key)
|