Initial commit

This commit is contained in:
Jordan
2026-04-01 23:16:45 +01:00
commit 91cfdaee72
200 changed files with 25589 additions and 0 deletions

View File

@@ -0,0 +1,3 @@
from .sse import SSEEmitter, EventType
__all__ = ["SSEEmitter", "EventType"]

Binary file not shown.

Binary file not shown.

186
src/streaming/sse.py Normal file
View File

@@ -0,0 +1,186 @@
"""SSE (Server-Sent Events) streaming system.
Supports structured event types, per-session event queues,
and Redis-backed persistence.
"""
from __future__ import annotations
import asyncio
import json
import logging
import time
from datetime import datetime, timezone
from enum import StrEnum
from typing import Any, AsyncIterator
from ..config import settings
logger = logging.getLogger(__name__)
class EventType(StrEnum):
    """Event types carried on the SSE stream.

    The string values are the wire-level ``event:`` names seen by clients
    (see ``SSEEvent.format_sse``).
    """
    SESSION_CREATED = "session.created"
    EXECUTION_STARTED = "execution.started"
    AGENT_DELTA = "agent.delta"
    TOOL_STARTED = "tool.started"
    TOOL_COMPLETED = "tool.completed"
    SUBAGENT_ASSIGNED = "subagent.assigned"
    EXECUTION_COMPLETED = "execution.completed"
    ERROR = "error"
    KEEPALIVE = "keepalive"  # idle heartbeat; not persisted to Redis (see SSEEmitter.emit)
class SSEEvent:
    """One event on a session's SSE stream.

    Captures its creation time (UTC, ISO-8601) on construction and can
    render itself either as an SSE wire-protocol string or as a plain
    dict (the form used for Redis persistence).
    """
    __slots__ = ("event_type", "data", "session_id", "timestamp", "event_id")

    def __init__(
        self,
        event_type: EventType,
        data: dict[str, Any],
        session_id: str,
        event_id: str = "",
    ) -> None:
        self.event_type = event_type
        self.data = data
        self.session_id = session_id
        # Stamped once at construction, always timezone-aware UTC.
        self.timestamp = datetime.now(timezone.utc).isoformat()
        # Caller-supplied id wins; otherwise fall back to a millisecond
        # epoch timestamp (string form).
        if event_id:
            self.event_id = event_id
        else:
            self.event_id = f"{int(time.time() * 1000)}"

    def format_sse(self) -> str:
        """Format as SSE wire protocol (``id``/``event``/``data`` fields)."""
        body = json.dumps(
            {
                "type": self.event_type.value,
                "data": self.data,
                "timestamp": self.timestamp,
            }
        )
        # Two trailing newlines terminate the event per the SSE spec.
        return (
            f"id: {self.event_id}\n"
            f"event: {self.event_type.value}\n"
            f"data: {body}\n"
            "\n"
        )

    def to_dict(self) -> dict[str, Any]:
        """Plain-dict view of the event (used for persistence/history)."""
        return dict(
            event_id=self.event_id,
            type=self.event_type.value,
            data=self.data,
            timestamp=self.timestamp,
        )
class SSEEmitter:
    """Manages per-session SSE event queues with Redis persistence.

    Each subscriber owns a bounded ``asyncio.Queue``; ``emit`` fans an
    event out to every queue registered for the session.  A capped
    in-memory history is kept per session as a fallback, and non-keepalive
    events are additionally persisted to Redis when a storage backend is
    attached.
    """

    def __init__(self, redis_storage=None) -> None:
        # session_id -> list of live subscriber queues.  ``None`` in a
        # queue is the shutdown sentinel pushed by cleanup_session().
        self._queues: dict[str, list[asyncio.Queue[SSEEvent | None]]] = {}
        # session_id -> capped list of recent events (in-memory fallback
        # when no Redis storage is configured or it is unreachable).
        self._history: dict[str, list[SSEEvent]] = {}
        self._max_history = 500
        self._storage = redis_storage

    def set_storage(self, redis_storage) -> None:
        """Set the Redis storage backend (called after startup)."""
        self._storage = redis_storage

    async def emit(
        self,
        event_type: EventType,
        data: dict[str, Any],
        session_id: str,
    ) -> None:
        """Emit an event to all listeners of a session.

        The event is appended to the capped in-memory history, persisted
        to Redis (keepalives excluded), then pushed to every subscriber
        queue.  A full subscriber queue drops the event with a warning
        rather than blocking the emitter; Redis failures are logged and
        do not prevent live fan-out.
        """
        event = SSEEvent(
            event_type=event_type,
            data=data,
            session_id=session_id,
        )
        # Store in memory, trimming to the most recent _max_history events.
        history = self._history.setdefault(session_id, [])
        history.append(event)
        if len(history) > self._max_history:
            self._history[session_id] = history[-self._max_history:]
        # Persist to Redis (skip keepalives — they're noise)
        if self._storage and event_type != EventType.KEEPALIVE:
            try:
                await self._storage.append_event(session_id, event.to_dict())
                logger.debug("Persisted event %s for session %s", event_type.value, session_id[:8])
            except Exception as e:
                logger.warning("Failed to persist event to Redis: %s — storage type: %s", e, type(self._storage).__name__)
        # Push to all active queues for this session
        for q in self._queues.get(session_id, []):
            try:
                q.put_nowait(event)
            except asyncio.QueueFull:
                logger.warning("SSE queue full for session %s", session_id)

    async def subscribe(
        self, session_id: str
    ) -> AsyncIterator[str]:
        """Subscribe to SSE events for a session.

        Yields formatted SSE strings.  Sends keepalives when no event
        arrives within ``settings.sse_keepalive_seconds`` so disconnected
        clients are detected; a ``None`` sentinel (pushed by
        ``cleanup_session``) ends the stream.
        """
        queue: asyncio.Queue[SSEEvent | None] = asyncio.Queue(maxsize=256)
        self._queues.setdefault(session_id, []).append(queue)
        try:
            while True:
                try:
                    event = await asyncio.wait_for(
                        queue.get(),
                        timeout=settings.sse_keepalive_seconds,
                    )
                    if event is None:
                        break  # shutdown sentinel from cleanup_session()
                    yield event.format_sse()
                except asyncio.TimeoutError:
                    # Idle window elapsed — send a keepalive so proxies
                    # and clients keep the connection open.
                    keepalive = SSEEvent(
                        event_type=EventType.KEEPALIVE,
                        data={},
                        session_id=session_id,
                    )
                    yield keepalive.format_sse()
        finally:
            # Deregister this subscriber.  Drop the session's entry
            # entirely once its last queue is gone — previously the empty
            # list was left behind, leaking one dict entry per session
            # that never reached cleanup_session().
            listeners = self._queues.get(session_id)
            if listeners is not None:
                try:
                    listeners.remove(queue)
                except ValueError:
                    pass  # already removed (e.g. by a concurrent cleanup)
                if not listeners:
                    self._queues.pop(session_id, None)

    async def get_history(self, session_id: str) -> list[dict[str, Any]]:
        """Return event history — from Redis if available, fallback to memory."""
        # Try Redis first (persistent)
        if self._storage:
            try:
                events = await self._storage.get_events(session_id)
                if events:
                    return events
            except Exception as e:
                logger.warning("Failed to read events from Redis: %s", e)
        # Fallback to in-memory
        events = self._history.get(session_id, [])
        return [e.to_dict() for e in events]

    def cleanup_session(self, session_id: str) -> None:
        """Remove all queues and in-memory history for a session.

        Pushes the ``None`` sentinel into every subscriber queue so live
        ``subscribe`` generators terminate; a full queue is skipped (its
        consumer will be deregistered by its own ``finally`` block).
        """
        for q in self._queues.get(session_id, []):
            try:
                q.put_nowait(None)
            except asyncio.QueueFull:
                pass
        self._queues.pop(session_id, None)
        self._history.pop(session_id, None)