"""Redis storage layer for session persistence. Key structure: {prefix}:session:{id} — SessionState JSON {prefix}:session:{id}:artifacts — Hash of ArtifactSummary by artifact_id {prefix}:session:{id}:events — List of SSE event JSONs {prefix}:session:{id}:lock — Execution lock (SETNX) {prefix}:sessions:index — Set of all active session IDs """ from __future__ import annotations import json import logging from contextlib import asynccontextmanager from typing import Any, AsyncIterator import redis.asyncio as redis from ..config import settings from ..models.session import SessionState logger = logging.getLogger(__name__) class RedisStorage: """Async Redis storage for session state.""" def __init__(self, redis_client: redis.Redis | None = None) -> None: self._r = redis_client self._prefix = settings.redis_key_prefix async def connect(self) -> None: """Create Redis connection if not provided.""" if self._r is None: self._r = redis.from_url( settings.redis_url, decode_responses=True, ) await self._r.ping() logger.info("Connected to Redis at %s", settings.redis_url) async def disconnect(self) -> None: if self._r: await self._r.aclose() @property def client(self) -> redis.Redis: if self._r is None: raise RuntimeError("Redis not connected. Call connect() first.") return self._r # ------------------------------------------------------------------ # Key helpers # ------------------------------------------------------------------ def _key(self, *parts: str) -> str: return ":".join([self._prefix, *parts]) # ------------------------------------------------------------------ # Session CRUD # ------------------------------------------------------------------ async def create_session(self, session: SessionState) -> None: key = self._key("session", session.session_id) await self.client.set( key, session.model_dump_json(), ex=settings.session_ttl_seconds, ) await self.client.sadd( self._key("sessions", "index"), session.session_id ) async def get_session(self, session_id: str) -> SessionState | None: key = self._key("session", session_id) data = await self.client.get(key) if data: return SessionState.model_validate_json(data) return None async def update_session(self, session: SessionState) -> None: key = self._key("session", session.session_id) await self.client.set( key, session.model_dump_json(), ex=settings.session_ttl_seconds, ) async def delete_session(self, session_id: str) -> bool: pipe = self.client.pipeline() pipe.delete(self._key("session", session_id)) pipe.delete(self._key("session", session_id, "artifacts")) pipe.delete(self._key("session", session_id, "events")) pipe.srem(self._key("sessions", "index"), session_id) results = await pipe.execute() return bool(results[0]) async def list_sessions(self) -> list[str]: members = await self.client.smembers(self._key("sessions", "index")) return list(members) # ------------------------------------------------------------------ # Event log # ------------------------------------------------------------------ async def append_event( self, session_id: str, event: dict[str, Any] ) -> None: key = self._key("session", session_id, "events") length = await self.client.rpush(key, json.dumps(event)) # Set TTL on first event if length == 1: await self.client.expire(key, settings.session_ttl_seconds) # Trim only when significantly over cap (avoids race conditions) if length > 600: await self.client.ltrim(key, -500, -1) async def get_events( self, session_id: str, start: int = 0, end: int = -1 ) -> list[dict[str, Any]]: key = self._key("session", session_id, "events") raw = await self.client.lrange(key, start, end) return [json.loads(r) for r in raw] # ------------------------------------------------------------------ # Execution lock (prevents concurrent messages on same session) # ------------------------------------------------------------------ @asynccontextmanager async def session_lock( self, session_id: str, timeout: int = 300 ) -> AsyncIterator[bool]: """Acquire an exclusive execution lock for a session. Uses SETNX with auto-expiry to prevent deadlocks if the process crashes mid-execution. Usage: async with storage.session_lock(session_id) as acquired: if not acquired: raise HTTPException(409, "Session busy") # ... execute ... """ key = self._key("session", session_id, "lock") acquired = await self.client.set(key, "1", nx=True, ex=timeout) try: yield bool(acquired) finally: if acquired: await self.client.delete(key) async def clear_session_lock(self, session_id: str) -> None: """Borra el lock de ejecución de una sesión de forma incondicional. Usado por el endpoint de abort para liberar un lock huérfano (de una ejecución previa que crasheó antes de soltarlo) y no bloquear el siguiente mensaje hasta que expire el TTL. """ key = self._key("session", session_id, "lock") await self.client.delete(key)