"""Persistent memory store backed by Redis. Supports three memory tiers: 1. Rules/documents — persistent, always loaded 2. Artifact summaries — per-session, loaded on demand 3. Optional embeddings — for semantic search """ from __future__ import annotations import json import logging from typing import Any import redis.asyncio as redis from ..config import settings from ..models.artifacts import ArtifactSummary from ..models.context import MemoryDocument, MemoryType logger = logging.getLogger(__name__) class MemoryStore: """Async memory store with Redis backend.""" def __init__(self, redis_client: redis.Redis) -> None: self._r = redis_client self._prefix = settings.redis_key_prefix # ------------------------------------------------------------------ # Key helpers # ------------------------------------------------------------------ def _key(self, *parts: str) -> str: return ":".join([self._prefix, *parts]) # ------------------------------------------------------------------ # Rules & documents (persistent memory) # ------------------------------------------------------------------ async def store_document(self, doc: MemoryDocument) -> None: key = self._key("memory", doc.namespace, doc.memory_id) await self._r.set(key, doc.model_dump_json()) # Index by namespace await self._r.sadd(self._key("memory", doc.namespace, "_index"), doc.memory_id) # Index by type await self._r.sadd( self._key("memory", "_type", doc.memory_type.value), doc.memory_id ) # Index by tags for tag in doc.tags: await self._r.sadd(self._key("memory", "_tag", tag), doc.memory_id) async def get_document( self, memory_id: str, namespace: str = "global" ) -> MemoryDocument | None: key = self._key("memory", namespace, memory_id) data = await self._r.get(key) if data: return MemoryDocument.model_validate_json(data) return None async def list_documents( self, namespace: str = "global", memory_type: MemoryType | None = None, tags: list[str] | None = None, ) -> list[MemoryDocument]: """List documents with optional type/tag filters.""" # Start with namespace index ids = await self._r.smembers(self._key("memory", namespace, "_index")) id_set = {mid.decode() if isinstance(mid, bytes) else mid for mid in ids} # Intersect with type filter if memory_type: type_ids = await self._r.smembers( self._key("memory", "_type", memory_type.value) ) type_set = {mid.decode() if isinstance(mid, bytes) else mid for mid in type_ids} id_set &= type_set # Intersect with tag filter if tags: for tag in tags: tag_ids = await self._r.smembers(self._key("memory", "_tag", tag)) tag_set = { mid.decode() if isinstance(mid, bytes) else mid for mid in tag_ids } id_set &= tag_set docs: list[MemoryDocument] = [] for mid in id_set: doc = await self.get_document(mid, namespace) if doc: docs.append(doc) return docs async def delete_document( self, memory_id: str, namespace: str = "global" ) -> bool: key = self._key("memory", namespace, memory_id) deleted = await self._r.delete(key) await self._r.srem(self._key("memory", namespace, "_index"), memory_id) return bool(deleted) # ------------------------------------------------------------------ # Artifact summaries (per-session) # ------------------------------------------------------------------ async def store_artifact( self, session_id: str, artifact: ArtifactSummary ) -> None: key = self._key("session", session_id, "artifacts") await self._r.hset(key, artifact.artifact_id, artifact.model_dump_json()) await self._r.expire(key, settings.session_ttl_seconds) async def get_artifact( self, session_id: str, artifact_id: str ) -> ArtifactSummary | None: key = self._key("session", session_id, "artifacts") data = await self._r.hget(key, artifact_id) if data: return ArtifactSummary.model_validate_json(data) return None async def list_artifacts(self, session_id: str) -> list[ArtifactSummary]: key = self._key("session", session_id, "artifacts") all_data = await self._r.hgetall(key) return [ ArtifactSummary.model_validate_json(v) for v in all_data.values() ] # ------------------------------------------------------------------ # Optional embeddings # ------------------------------------------------------------------ async def store_embedding( self, memory_id: str, embedding: list[float], namespace: str = "global" ) -> None: """Store an embedding vector for a memory document.""" key = self._key("embeddings", namespace, memory_id) await self._r.set(key, json.dumps(embedding)) async def get_embedding( self, memory_id: str, namespace: str = "global" ) -> list[float] | None: key = self._key("embeddings", namespace, memory_id) data = await self._r.get(key) if data: return json.loads(data) return None async def search_by_similarity( self, query_embedding: list[float], namespace: str = "global", top_k: int = 5, ) -> list[tuple[str, float]]: """Brute-force cosine similarity search over stored embeddings. For production, swap this with Redis Vector Search (RediSearch) or a dedicated vector DB. """ pattern = self._key("embeddings", namespace, "*") results: list[tuple[str, float]] = [] async for key in self._r.scan_iter(match=pattern, count=100): key_str = key.decode() if isinstance(key, bytes) else key memory_id = key_str.rsplit(":", 1)[-1] data = await self._r.get(key) if not data: continue stored = json.loads(data) score = self._cosine_similarity(query_embedding, stored) results.append((memory_id, score)) results.sort(key=lambda x: x[1], reverse=True) return results[:top_k] @staticmethod def _cosine_similarity(a: list[float], b: list[float]) -> float: if len(a) != len(b) or not a: return 0.0 dot = sum(x * y for x, y in zip(a, b)) mag_a = sum(x * x for x in a) ** 0.5 mag_b = sum(x * x for x in b) ** 0.5 if mag_a == 0 or mag_b == 0: return 0.0 return dot / (mag_a * mag_b)