Initial commit
This commit is contained in:
3
src/streaming/__init__.py
Normal file
3
src/streaming/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .sse import SSEEmitter, EventType
|
||||
|
||||
__all__ = ["SSEEmitter", "EventType"]
|
||||
BIN
src/streaming/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
src/streaming/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
src/streaming/__pycache__/sse.cpython-312.pyc
Normal file
BIN
src/streaming/__pycache__/sse.cpython-312.pyc
Normal file
Binary file not shown.
186
src/streaming/sse.py
Normal file
186
src/streaming/sse.py
Normal file
@@ -0,0 +1,186 @@
|
||||
"""SSE (Server-Sent Events) streaming system.
|
||||
|
||||
Supports structured event types, per-session event queues,
|
||||
and Redis-backed persistence.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from enum import StrEnum
|
||||
from typing import Any, AsyncIterator
|
||||
|
||||
from ..config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EventType(StrEnum):
|
||||
SESSION_CREATED = "session.created"
|
||||
EXECUTION_STARTED = "execution.started"
|
||||
AGENT_DELTA = "agent.delta"
|
||||
TOOL_STARTED = "tool.started"
|
||||
TOOL_COMPLETED = "tool.completed"
|
||||
SUBAGENT_ASSIGNED = "subagent.assigned"
|
||||
EXECUTION_COMPLETED = "execution.completed"
|
||||
ERROR = "error"
|
||||
KEEPALIVE = "keepalive"
|
||||
|
||||
|
||||
class SSEEvent:
|
||||
"""A single SSE event."""
|
||||
|
||||
__slots__ = ("event_type", "data", "session_id", "timestamp", "event_id")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
event_type: EventType,
|
||||
data: dict[str, Any],
|
||||
session_id: str,
|
||||
event_id: str = "",
|
||||
) -> None:
|
||||
self.event_type = event_type
|
||||
self.data = data
|
||||
self.session_id = session_id
|
||||
self.timestamp = datetime.now(timezone.utc).isoformat()
|
||||
self.event_id = event_id or f"{int(time.time() * 1000)}"
|
||||
|
||||
def format_sse(self) -> str:
|
||||
"""Format as SSE wire protocol."""
|
||||
payload = {
|
||||
"type": self.event_type.value,
|
||||
"data": self.data,
|
||||
"timestamp": self.timestamp,
|
||||
}
|
||||
lines = [
|
||||
f"id: {self.event_id}",
|
||||
f"event: {self.event_type.value}",
|
||||
f"data: {json.dumps(payload)}",
|
||||
"",
|
||||
"",
|
||||
]
|
||||
return "\n".join(lines)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"event_id": self.event_id,
|
||||
"type": self.event_type.value,
|
||||
"data": self.data,
|
||||
"timestamp": self.timestamp,
|
||||
}
|
||||
|
||||
|
||||
class SSEEmitter:
|
||||
"""Manages per-session SSE event queues with Redis persistence."""
|
||||
|
||||
def __init__(self, redis_storage=None) -> None:
|
||||
self._queues: dict[str, list[asyncio.Queue[SSEEvent | None]]] = {}
|
||||
self._history: dict[str, list[SSEEvent]] = {}
|
||||
self._max_history = 500
|
||||
self._storage = redis_storage
|
||||
|
||||
def set_storage(self, redis_storage) -> None:
|
||||
"""Set the Redis storage backend (called after startup)."""
|
||||
self._storage = redis_storage
|
||||
|
||||
async def emit(
|
||||
self,
|
||||
event_type: EventType,
|
||||
data: dict[str, Any],
|
||||
session_id: str,
|
||||
) -> None:
|
||||
"""Emit an event to all listeners of a session."""
|
||||
event = SSEEvent(
|
||||
event_type=event_type,
|
||||
data=data,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Store in memory
|
||||
if session_id not in self._history:
|
||||
self._history[session_id] = []
|
||||
history = self._history[session_id]
|
||||
history.append(event)
|
||||
if len(history) > self._max_history:
|
||||
self._history[session_id] = history[-self._max_history:]
|
||||
|
||||
# Persist to Redis (skip keepalives — they're noise)
|
||||
if self._storage and event_type != EventType.KEEPALIVE:
|
||||
try:
|
||||
await self._storage.append_event(session_id, event.to_dict())
|
||||
logger.debug("Persisted event %s for session %s", event_type.value, session_id[:8])
|
||||
except Exception as e:
|
||||
logger.warning("Failed to persist event to Redis: %s — storage type: %s", e, type(self._storage).__name__)
|
||||
|
||||
# Push to all active queues for this session
|
||||
queues = self._queues.get(session_id, [])
|
||||
for q in queues:
|
||||
try:
|
||||
q.put_nowait(event)
|
||||
except asyncio.QueueFull:
|
||||
logger.warning("SSE queue full for session %s", session_id)
|
||||
|
||||
async def subscribe(
|
||||
self, session_id: str
|
||||
) -> AsyncIterator[str]:
|
||||
"""Subscribe to SSE events for a session.
|
||||
|
||||
Yields formatted SSE strings. Sends keepalives to detect
|
||||
disconnected clients.
|
||||
"""
|
||||
queue: asyncio.Queue[SSEEvent | None] = asyncio.Queue(maxsize=256)
|
||||
|
||||
if session_id not in self._queues:
|
||||
self._queues[session_id] = []
|
||||
self._queues[session_id].append(queue)
|
||||
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
event = await asyncio.wait_for(
|
||||
queue.get(),
|
||||
timeout=settings.sse_keepalive_seconds,
|
||||
)
|
||||
if event is None:
|
||||
break
|
||||
yield event.format_sse()
|
||||
except asyncio.TimeoutError:
|
||||
# Send keepalive
|
||||
keepalive = SSEEvent(
|
||||
event_type=EventType.KEEPALIVE,
|
||||
data={},
|
||||
session_id=session_id,
|
||||
)
|
||||
yield keepalive.format_sse()
|
||||
finally:
|
||||
if queue in self._queues.get(session_id, []):
|
||||
self._queues[session_id].remove(queue)
|
||||
|
||||
async def get_history(self, session_id: str) -> list[dict[str, Any]]:
|
||||
"""Return event history — from Redis if available, fallback to memory."""
|
||||
# Try Redis first (persistent)
|
||||
if self._storage:
|
||||
try:
|
||||
events = await self._storage.get_events(session_id)
|
||||
if events:
|
||||
return events
|
||||
except Exception as e:
|
||||
logger.warning("Failed to read events from Redis: %s", e)
|
||||
|
||||
# Fallback to in-memory
|
||||
events = self._history.get(session_id, [])
|
||||
return [e.to_dict() for e in events]
|
||||
|
||||
def cleanup_session(self, session_id: str) -> None:
|
||||
"""Remove all queues and in-memory history for a session."""
|
||||
for q in self._queues.get(session_id, []):
|
||||
try:
|
||||
q.put_nowait(None)
|
||||
except asyncio.QueueFull:
|
||||
pass
|
||||
self._queues.pop(session_id, None)
|
||||
self._history.pop(session_id, None)
|
||||
Reference in New Issue
Block a user