Initial commit

This commit is contained in:
Jordan
2026-04-01 23:16:45 +01:00
commit 91cfdaee72
200 changed files with 25589 additions and 0 deletions

3
src/storage/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
# Re-export the concrete Redis backend as the storage package's public API.
from .redis import RedisStorage
__all__ = ["RedisStorage"]

Binary file not shown.

Binary file not shown.

151
src/storage/redis.py Normal file
View File

@@ -0,0 +1,151 @@
"""Redis storage layer for session persistence.
Key structure:
{prefix}:session:{id} — SessionState JSON
{prefix}:session:{id}:artifacts — Hash of ArtifactSummary by artifact_id
{prefix}:session:{id}:events — List of SSE event JSONs
{prefix}:session:{id}:lock — Execution lock (SETNX)
{prefix}:sessions:index — Set of all active session IDs
"""
from __future__ import annotations

import json
import logging
import uuid
from contextlib import asynccontextmanager
from typing import Any, AsyncIterator

import redis.asyncio as redis

from ..config import settings
from ..models.session import SessionState
logger = logging.getLogger(__name__)
class RedisStorage:
    """Async Redis storage for session state.

    Provides session CRUD, an append-only per-session event log, and an
    exclusive per-session execution lock, all on top of ``redis.asyncio``.
    """

    # Lua compare-and-delete: remove the lock key only if it still holds
    # our token. Prevents deleting a lock that expired and was re-acquired
    # by another process (standard single-instance Redis lock pattern).
    _UNLOCK_SCRIPT = (
        'if redis.call("get", KEYS[1]) == ARGV[1] then '
        'return redis.call("del", KEYS[1]) '
        "else return 0 end"
    )

    def __init__(self, redis_client: redis.Redis | None = None) -> None:
        """Store an optional pre-built client; ``connect()`` creates one lazily."""
        self._r = redis_client
        self._prefix = settings.redis_key_prefix

    async def connect(self) -> None:
        """Create the Redis connection if not provided, then verify with PING."""
        if self._r is None:
            self._r = redis.from_url(
                settings.redis_url,
                decode_responses=True,
            )
        await self._r.ping()
        logger.info("Connected to Redis at %s", settings.redis_url)

    async def disconnect(self) -> None:
        """Close the connection and drop the client reference.

        Resetting ``_r`` to None makes any later ``client`` access fail
        loudly instead of silently returning a closed connection.
        """
        if self._r:
            await self._r.aclose()
            self._r = None

    @property
    def client(self) -> redis.Redis:
        """The live Redis client.

        Raises:
            RuntimeError: if ``connect()`` has not been called (or the
                storage has been disconnected).
        """
        if self._r is None:
            raise RuntimeError("Redis not connected. Call connect() first.")
        return self._r

    # ------------------------------------------------------------------
    # Key helpers
    # ------------------------------------------------------------------
    def _key(self, *parts: str) -> str:
        """Join *parts* onto the configured prefix with ``:`` separators."""
        return ":".join([self._prefix, *parts])

    # ------------------------------------------------------------------
    # Session CRUD
    # ------------------------------------------------------------------
    async def create_session(self, session: SessionState) -> None:
        """Persist a new session (with TTL) and add it to the session index."""
        key = self._key("session", session.session_id)
        await self.client.set(
            key,
            session.model_dump_json(),
            ex=settings.session_ttl_seconds,
        )
        await self.client.sadd(
            self._key("sessions", "index"), session.session_id
        )

    async def get_session(self, session_id: str) -> SessionState | None:
        """Return the session, or None if it is missing or expired."""
        data = await self.client.get(self._key("session", session_id))
        if data:
            return SessionState.model_validate_json(data)
        return None

    async def update_session(self, session: SessionState) -> None:
        """Overwrite the session JSON and refresh its TTL."""
        key = self._key("session", session.session_id)
        await self.client.set(
            key,
            session.model_dump_json(),
            ex=settings.session_ttl_seconds,
        )

    async def delete_session(self, session_id: str) -> bool:
        """Delete the session and all its associated keys.

        Returns:
            True if the session key itself existed and was deleted.
        """
        pipe = self.client.pipeline()
        pipe.delete(self._key("session", session_id))
        pipe.delete(self._key("session", session_id, "artifacts"))
        pipe.delete(self._key("session", session_id, "events"))
        pipe.srem(self._key("sessions", "index"), session_id)
        results = await pipe.execute()
        return bool(results[0])

    async def list_sessions(self) -> list[str]:
        """Return the IDs of all sessions registered in the index set."""
        members = await self.client.smembers(self._key("sessions", "index"))
        return list(members)

    # ------------------------------------------------------------------
    # Event log
    # ------------------------------------------------------------------
    async def append_event(
        self, session_id: str, event: dict[str, Any]
    ) -> None:
        """Append *event* (JSON-encoded) to the session's event list.

        The TTL is refreshed on every append — not just the first — so an
        active session's event log cannot expire while the session key is
        still being kept alive by ``update_session``. The list is trimmed
        only once it is well over the 500-entry cap, keeping trims rare
        and tolerant of concurrent appenders.
        """
        key = self._key("session", session_id, "events")
        length = await self.client.rpush(key, json.dumps(event))
        await self.client.expire(key, settings.session_ttl_seconds)
        if length > 600:
            await self.client.ltrim(key, -500, -1)

    async def get_events(
        self, session_id: str, start: int = 0, end: int = -1
    ) -> list[dict[str, Any]]:
        """Return decoded events in ``[start, end]`` (Redis LRANGE semantics,
        so both indices are inclusive and negatives count from the tail)."""
        key = self._key("session", session_id, "events")
        raw = await self.client.lrange(key, start, end)
        return [json.loads(r) for r in raw]

    # ------------------------------------------------------------------
    # Execution lock (prevents concurrent messages on same session)
    # ------------------------------------------------------------------
    @asynccontextmanager
    async def session_lock(
        self, session_id: str, timeout: int = 300
    ) -> AsyncIterator[bool]:
        """Acquire an exclusive execution lock for a session.

        Uses SET NX with auto-expiry to prevent deadlocks if the process
        crashes mid-execution. The lock value is a random token and the
        release is an atomic compare-and-delete, so if the lock expires
        and is re-acquired by another process we never delete the other
        holder's lock.

        Usage:
            async with storage.session_lock(session_id) as acquired:
                if not acquired:
                    raise HTTPException(409, "Session busy")
                # ... execute ...
        """
        key = self._key("session", session_id, "lock")
        token = uuid.uuid4().hex
        acquired = await self.client.set(key, token, nx=True, ex=timeout)
        try:
            yield bool(acquired)
        finally:
            if acquired:
                # Only deletes the key if it still holds our token.
                await self.client.eval(self._UNLOCK_SCRIPT, 1, key, token)