Files
agenticSystem/src/storage/redis.py
Jordan Diaz 36318c61ea fix(chat): permitir abortar/preemptar ejecución en curso de una sesión
Antes, al parar el agente y mandar un mensaje nuevo, la ejecución previa
seguía viva reteniendo el session_lock: el mensaje nuevo recibía "busy" y el
stream mostraba la ejecución anterior. La tarea detached (create_task) no se
guardaba en ningún sitio y era imposible cancelarla.

- _running_executions: registro de la tarea asyncio por session_id.
- _cancel_running_execution(): cancela y espera a que libere el lock.
- send_message: preempt — un mensaje nuevo cancela la ejecución previa.
- _execute_and_persist: maneja CancelledError dejando la sesión en ACTIVE.
- POST /sessions/{id}/abort: cancela, cierra el stream SSE y limpia el lock.
- RedisStorage.clear_session_lock(): libera locks huérfanos.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-02 17:50:46 +00:00

162 lines
5.7 KiB
Python

"""Redis storage layer for session persistence.

Key structure:
    {prefix}:session:{id}            — SessionState JSON
    {prefix}:session:{id}:artifacts  — Hash of ArtifactSummary by artifact_id
    {prefix}:session:{id}:events     — List of SSE event JSONs (length-capped)
    {prefix}:session:{id}:lock       — Execution lock (SET NX with expiry)
    {prefix}:sessions:index          — Set of session IDs; entries are removed
                                       by delete_session, not when a session
                                       key expires via TTL
"""
from __future__ import annotations

import json
import logging
import secrets
from contextlib import asynccontextmanager
from typing import Any, AsyncIterator

import redis.asyncio as redis

from ..config import settings
from ..models.session import SessionState
# Module-scoped logger, named after this module's dotted import path.
logger: logging.Logger = logging.getLogger(name=__name__)
class RedisStorage:
    """Async Redis storage for session state.

    Wraps a ``redis.asyncio`` client, either injected or created lazily by
    :meth:`connect`. Provides session CRUD, a per-session event log and a
    per-session execution lock. All keys are namespaced by
    ``settings.redis_key_prefix``.
    """

    # Lua script for a safe lock release. It deletes the key only if the key
    # still holds the token this holder wrote. A plain DEL could remove a lock
    # that expired (or was cleared by abort) and was then re-acquired by a
    # different execution.
    _RELEASE_LOCK_SCRIPT = """
if redis.call("GET", KEYS[1]) == ARGV[1] then
    return redis.call("DEL", KEYS[1])
end
return 0
"""

    # Event-log cap. Once the list grows past _EVENTS_TRIM_THRESHOLD entries,
    # it is trimmed back to the newest _EVENTS_KEEP entries. Trimming only when
    # well over the cap keeps LTRIM off the hot path; the original code cites
    # avoiding races with concurrent readers as the reason.
    _EVENTS_TRIM_THRESHOLD = 600
    _EVENTS_KEEP = 500

    def __init__(self, redis_client: redis.Redis | None = None) -> None:
        """Store an optional pre-built client.

        Args:
            redis_client: An existing ``redis.asyncio.Redis`` client. When
                ``None``, :meth:`connect` creates one from
                ``settings.redis_url``.
        """
        self._r = redis_client
        self._prefix = settings.redis_key_prefix

    async def connect(self) -> None:
        """Create the Redis client if none was injected, then PING it."""
        if self._r is None:
            self._r = redis.from_url(
                settings.redis_url,
                decode_responses=True,
            )
        await self._r.ping()
        logger.info("Connected to Redis at %s", settings.redis_url)

    async def disconnect(self) -> None:
        """Close the underlying client, if one exists."""
        if self._r is not None:
            await self._r.aclose()

    @property
    def client(self) -> redis.Redis:
        """Return the connected client.

        Raises:
            RuntimeError: If no client exists yet (call :meth:`connect`).
        """
        if self._r is None:
            raise RuntimeError("Redis not connected. Call connect() first.")
        return self._r

    # ------------------------------------------------------------------
    # Key helpers
    # ------------------------------------------------------------------

    def _key(self, *parts: str) -> str:
        """Join the configured prefix and *parts* with ``:``."""
        return ":".join([self._prefix, *parts])

    # ------------------------------------------------------------------
    # Session CRUD
    # ------------------------------------------------------------------

    async def create_session(self, session: SessionState) -> None:
        """Store a new session with the configured TTL and add it to the index.

        Both writes go through one MULTI/EXEC pipeline. A failure therefore
        cannot leave the session stored but missing from the index.
        """
        pipe = self.client.pipeline(transaction=True)
        pipe.set(
            self._key("session", session.session_id),
            session.model_dump_json(),
            ex=settings.session_ttl_seconds,
        )
        pipe.sadd(self._key("sessions", "index"), session.session_id)
        await pipe.execute()

    async def get_session(self, session_id: str) -> SessionState | None:
        """Load a session, or return ``None`` if the key is missing or expired."""
        key = self._key("session", session_id)
        data = await self.client.get(key)
        if data:
            return SessionState.model_validate_json(data)
        return None

    async def update_session(self, session: SessionState) -> None:
        """Overwrite a session's JSON and refresh its TTL.

        The index set is not touched.
        """
        key = self._key("session", session.session_id)
        await self.client.set(
            key,
            session.model_dump_json(),
            ex=settings.session_ttl_seconds,
        )

    async def delete_session(self, session_id: str) -> bool:
        """Delete a session and all of its per-session keys.

        The per-session keys are the state, artifacts, events and execution
        lock. The session ID is also removed from the index.

        Returns:
            ``True`` if the session state key existed and was deleted.
        """
        pipe = self.client.pipeline(transaction=True)
        pipe.delete(self._key("session", session_id))
        pipe.delete(self._key("session", session_id, "artifacts"))
        pipe.delete(self._key("session", session_id, "events"))
        # Without this, the lock would linger until its own TTL expires.
        pipe.delete(self._key("session", session_id, "lock"))
        pipe.srem(self._key("sessions", "index"), session_id)
        results = await pipe.execute()
        return bool(results[0])

    async def list_sessions(self) -> list[str]:
        """Return every session ID in the index set.

        IDs are removed from the index only by :meth:`delete_session`. A
        session whose state key expired via TTL may therefore still be listed
        here.
        """
        members = await self.client.smembers(self._key("sessions", "index"))
        return list(members)

    # ------------------------------------------------------------------
    # Event log
    # ------------------------------------------------------------------

    async def append_event(
        self, session_id: str, event: dict[str, Any]
    ) -> None:
        """Append a JSON-encoded event to the session's event list.

        RPUSH and EXPIRE run in one MULTI/EXEC pipeline, so the list's TTL is
        refreshed on every append. This keeps the log alive for as long as
        the session is active, matching how :meth:`update_session` refreshes
        the session key. It also means a crash can never leave the list
        without a TTL.
        """
        key = self._key("session", session_id, "events")
        pipe = self.client.pipeline(transaction=True)
        pipe.rpush(key, json.dumps(event))
        pipe.expire(key, settings.session_ttl_seconds)
        length, _ = await pipe.execute()
        # Trim only when significantly over the cap; see _EVENTS_TRIM_THRESHOLD.
        if length > self._EVENTS_TRIM_THRESHOLD:
            await self.client.ltrim(key, -self._EVENTS_KEEP, -1)

    async def get_events(
        self, session_id: str, start: int = 0, end: int = -1
    ) -> list[dict[str, Any]]:
        """Return decoded events in the inclusive LRANGE window [start, end].

        Negative indices count from the tail, as in Redis LRANGE. Trimming in
        :meth:`append_event` drops old entries, which shifts the absolute
        positions of the remaining ones.
        """
        key = self._key("session", session_id, "events")
        raw = await self.client.lrange(key, start, end)
        return [json.loads(r) for r in raw]

    # ------------------------------------------------------------------
    # Execution lock (prevents concurrent messages on same session)
    # ------------------------------------------------------------------

    @asynccontextmanager
    async def session_lock(
        self, session_id: str, timeout: int = 300
    ) -> AsyncIterator[bool]:
        """Acquire an exclusive execution lock for a session.

        Uses ``SET key <token> NX EX timeout``. The expiry prevents a deadlock
        if the process crashes mid-execution. Each acquisition writes a fresh
        random token. Release is an atomic compare-and-delete, so a holder
        whose lock already expired cannot delete a lock that another
        execution has since acquired. The same applies to a lock cleared
        through :meth:`clear_session_lock`.

        Args:
            session_id: Session whose lock to take.
            timeout: Lock expiry in seconds.

        Yields:
            ``True`` if the lock was acquired, ``False`` if it is already held.

        Usage:
            async with storage.session_lock(session_id) as acquired:
                if not acquired:
                    raise HTTPException(409, "Session busy")
                # ... execute ...
        """
        key = self._key("session", session_id, "lock")
        token = secrets.token_hex(16)
        acquired = await self.client.set(key, token, nx=True, ex=timeout)
        try:
            yield bool(acquired)
        finally:
            if acquired:
                # Deletes the key only if it still holds our token.
                await self.client.eval(
                    self._RELEASE_LOCK_SCRIPT, 1, key, token
                )

    async def clear_session_lock(self, session_id: str) -> None:
        """Unconditionally delete a session's execution lock.

        Used by the abort endpoint to release an orphaned lock, e.g. one left
        by a previous execution that crashed before releasing it. This avoids
        blocking the next message until the TTL expires. The lock is removed
        no matter which execution holds it.
        """
        key = self._key("session", session_id, "lock")
        await self.client.delete(key)