Initial commit
This commit is contained in:
3
src/storage/__init__.py
Normal file
3
src/storage/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .redis import RedisStorage
|
||||
|
||||
__all__ = ["RedisStorage"]
|
||||
BIN
src/storage/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
src/storage/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
src/storage/__pycache__/redis.cpython-312.pyc
Normal file
BIN
src/storage/__pycache__/redis.cpython-312.pyc
Normal file
Binary file not shown.
151
src/storage/redis.py
Normal file
151
src/storage/redis.py
Normal file
@@ -0,0 +1,151 @@
|
||||
"""Redis storage layer for session persistence.
|
||||
|
||||
Key structure:
|
||||
{prefix}:session:{id} — SessionState JSON
|
||||
{prefix}:session:{id}:artifacts — Hash of ArtifactSummary by artifact_id
|
||||
{prefix}:session:{id}:events — List of SSE event JSONs
|
||||
{prefix}:session:{id}:lock — Execution lock (SETNX)
|
||||
{prefix}:sessions:index — Set of all active session IDs
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any, AsyncIterator
|
||||
|
||||
import redis.asyncio as redis
|
||||
|
||||
from ..config import settings
|
||||
from ..models.session import SessionState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RedisStorage:
|
||||
"""Async Redis storage for session state."""
|
||||
|
||||
def __init__(self, redis_client: redis.Redis | None = None) -> None:
|
||||
self._r = redis_client
|
||||
self._prefix = settings.redis_key_prefix
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Create Redis connection if not provided."""
|
||||
if self._r is None:
|
||||
self._r = redis.from_url(
|
||||
settings.redis_url,
|
||||
decode_responses=True,
|
||||
)
|
||||
await self._r.ping()
|
||||
logger.info("Connected to Redis at %s", settings.redis_url)
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
if self._r:
|
||||
await self._r.aclose()
|
||||
|
||||
@property
|
||||
def client(self) -> redis.Redis:
|
||||
if self._r is None:
|
||||
raise RuntimeError("Redis not connected. Call connect() first.")
|
||||
return self._r
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Key helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _key(self, *parts: str) -> str:
|
||||
return ":".join([self._prefix, *parts])
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Session CRUD
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def create_session(self, session: SessionState) -> None:
|
||||
key = self._key("session", session.session_id)
|
||||
await self.client.set(
|
||||
key,
|
||||
session.model_dump_json(),
|
||||
ex=settings.session_ttl_seconds,
|
||||
)
|
||||
await self.client.sadd(
|
||||
self._key("sessions", "index"), session.session_id
|
||||
)
|
||||
|
||||
async def get_session(self, session_id: str) -> SessionState | None:
|
||||
key = self._key("session", session_id)
|
||||
data = await self.client.get(key)
|
||||
if data:
|
||||
return SessionState.model_validate_json(data)
|
||||
return None
|
||||
|
||||
async def update_session(self, session: SessionState) -> None:
|
||||
key = self._key("session", session.session_id)
|
||||
await self.client.set(
|
||||
key,
|
||||
session.model_dump_json(),
|
||||
ex=settings.session_ttl_seconds,
|
||||
)
|
||||
|
||||
async def delete_session(self, session_id: str) -> bool:
|
||||
pipe = self.client.pipeline()
|
||||
pipe.delete(self._key("session", session_id))
|
||||
pipe.delete(self._key("session", session_id, "artifacts"))
|
||||
pipe.delete(self._key("session", session_id, "events"))
|
||||
pipe.srem(self._key("sessions", "index"), session_id)
|
||||
results = await pipe.execute()
|
||||
return bool(results[0])
|
||||
|
||||
async def list_sessions(self) -> list[str]:
|
||||
members = await self.client.smembers(self._key("sessions", "index"))
|
||||
return list(members)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Event log
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def append_event(
|
||||
self, session_id: str, event: dict[str, Any]
|
||||
) -> None:
|
||||
key = self._key("session", session_id, "events")
|
||||
length = await self.client.rpush(key, json.dumps(event))
|
||||
# Set TTL on first event
|
||||
if length == 1:
|
||||
await self.client.expire(key, settings.session_ttl_seconds)
|
||||
# Trim only when significantly over cap (avoids race conditions)
|
||||
if length > 600:
|
||||
await self.client.ltrim(key, -500, -1)
|
||||
|
||||
async def get_events(
|
||||
self, session_id: str, start: int = 0, end: int = -1
|
||||
) -> list[dict[str, Any]]:
|
||||
key = self._key("session", session_id, "events")
|
||||
raw = await self.client.lrange(key, start, end)
|
||||
return [json.loads(r) for r in raw]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Execution lock (prevents concurrent messages on same session)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@asynccontextmanager
|
||||
async def session_lock(
|
||||
self, session_id: str, timeout: int = 300
|
||||
) -> AsyncIterator[bool]:
|
||||
"""Acquire an exclusive execution lock for a session.
|
||||
|
||||
Uses SETNX with auto-expiry to prevent deadlocks if the process
|
||||
crashes mid-execution.
|
||||
|
||||
Usage:
|
||||
async with storage.session_lock(session_id) as acquired:
|
||||
if not acquired:
|
||||
raise HTTPException(409, "Session busy")
|
||||
# ... execute ...
|
||||
"""
|
||||
key = self._key("session", session_id, "lock")
|
||||
acquired = await self.client.set(key, "1", nx=True, ex=timeout)
|
||||
try:
|
||||
yield bool(acquired)
|
||||
finally:
|
||||
if acquired:
|
||||
await self.client.delete(key)
|
||||
Reference in New Issue
Block a user