191 lines
6.8 KiB
Python
191 lines
6.8 KiB
Python
"""Persistent memory store backed by Redis.
|
|
|
|
Supports three memory tiers:
|
|
1. Rules/documents — persistent, always loaded
|
|
2. Artifact summaries — per-session, loaded on demand
|
|
3. Optional embeddings — for semantic search
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
from typing import Any
|
|
|
|
import redis.asyncio as redis
|
|
|
|
from ..config import settings
|
|
from ..models.artifacts import ArtifactSummary
|
|
from ..models.context import MemoryDocument, MemoryType
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class MemoryStore:
    """Async memory store with Redis backend.

    Provides three storage tiers:

    * **Documents** (rules, reference material) — persistent, indexed by
      namespace, memory type, and tags.
    * **Artifact summaries** — per-session hash entries that expire with
      the session TTL.
    * **Embeddings** — optional vectors used for brute-force cosine
      similarity search.
    """

    def __init__(self, redis_client: redis.Redis) -> None:
        self._r = redis_client
        # Every key is namespaced under this prefix so multiple
        # deployments can share one Redis instance without collisions.
        self._prefix = settings.redis_key_prefix

    # ------------------------------------------------------------------
    # Key helpers
    # ------------------------------------------------------------------

    def _key(self, *parts: str) -> str:
        """Build a colon-separated Redis key under the configured prefix."""
        return ":".join([self._prefix, *parts])

    @staticmethod
    def _as_str_set(members: Any) -> set[str]:
        """Normalize a Redis set reply (bytes or str members) to a str set.

        redis-py returns bytes unless decode_responses is enabled on the
        client; handle both so the store works either way.
        """
        return {m.decode() if isinstance(m, bytes) else m for m in members}

    # ------------------------------------------------------------------
    # Rules & documents (persistent memory)
    # ------------------------------------------------------------------

    async def store_document(self, doc: MemoryDocument) -> None:
        """Persist *doc* and register it in the namespace/type/tag indexes."""
        key = self._key("memory", doc.namespace, doc.memory_id)
        await self._r.set(key, doc.model_dump_json())
        # Secondary indexes: membership sets consumed by list_documents().
        await self._r.sadd(self._key("memory", doc.namespace, "_index"), doc.memory_id)
        await self._r.sadd(
            self._key("memory", "_type", doc.memory_type.value), doc.memory_id
        )
        for tag in doc.tags:
            await self._r.sadd(self._key("memory", "_tag", tag), doc.memory_id)

    async def get_document(
        self, memory_id: str, namespace: str = "global"
    ) -> MemoryDocument | None:
        """Fetch a single document, or None when it does not exist."""
        data = await self._r.get(self._key("memory", namespace, memory_id))
        return MemoryDocument.model_validate_json(data) if data else None

    async def list_documents(
        self,
        namespace: str = "global",
        memory_type: MemoryType | None = None,
        tags: list[str] | None = None,
    ) -> list[MemoryDocument]:
        """List documents in *namespace* with optional type/tag filters.

        Tag filtering is conjunctive: a document must carry every tag in
        *tags* to be returned.
        """
        # Start from the namespace index, then intersect with each filter.
        ids = await self._r.smembers(self._key("memory", namespace, "_index"))
        id_set = self._as_str_set(ids)

        if memory_type:
            type_ids = await self._r.smembers(
                self._key("memory", "_type", memory_type.value)
            )
            id_set &= self._as_str_set(type_ids)

        if tags:
            for tag in tags:
                tag_ids = await self._r.smembers(self._key("memory", "_tag", tag))
                id_set &= self._as_str_set(tag_ids)

        docs: list[MemoryDocument] = []
        for mid in id_set:
            doc = await self.get_document(mid, namespace)
            if doc:  # skip ids whose payload was deleted out-of-band
                docs.append(doc)
        return docs

    async def delete_document(
        self, memory_id: str, namespace: str = "global"
    ) -> bool:
        """Delete a document and scrub it from every secondary index.

        Returns True when a stored document was actually removed.
        """
        # Load the document first: its stored copy is the only record of
        # which type/tag index sets reference it. Deleting blindly (as
        # before) left stale ids in those sets, so list_documents() could
        # keep intersecting against phantom entries.
        doc = await self.get_document(memory_id, namespace)
        deleted = await self._r.delete(self._key("memory", namespace, memory_id))
        await self._r.srem(self._key("memory", namespace, "_index"), memory_id)
        if doc is not None:
            await self._r.srem(
                self._key("memory", "_type", doc.memory_type.value), memory_id
            )
            for tag in doc.tags:
                await self._r.srem(self._key("memory", "_tag", tag), memory_id)
        return bool(deleted)

    # ------------------------------------------------------------------
    # Artifact summaries (per-session)
    # ------------------------------------------------------------------

    async def store_artifact(
        self, session_id: str, artifact: ArtifactSummary
    ) -> None:
        """Store an artifact summary in the session hash and refresh its TTL."""
        key = self._key("session", session_id, "artifacts")
        await self._r.hset(key, artifact.artifact_id, artifact.model_dump_json())
        # Refresh expiry on every write so the hash lives as long as the
        # session stays active.
        await self._r.expire(key, settings.session_ttl_seconds)

    async def get_artifact(
        self, session_id: str, artifact_id: str
    ) -> ArtifactSummary | None:
        """Fetch one artifact summary from a session, or None if absent."""
        key = self._key("session", session_id, "artifacts")
        data = await self._r.hget(key, artifact_id)
        return ArtifactSummary.model_validate_json(data) if data else None

    async def list_artifacts(self, session_id: str) -> list[ArtifactSummary]:
        """Return every artifact summary stored for *session_id*."""
        key = self._key("session", session_id, "artifacts")
        all_data = await self._r.hgetall(key)
        return [ArtifactSummary.model_validate_json(v) for v in all_data.values()]

    # ------------------------------------------------------------------
    # Optional embeddings
    # ------------------------------------------------------------------

    async def store_embedding(
        self, memory_id: str, embedding: list[float], namespace: str = "global"
    ) -> None:
        """Store an embedding vector (JSON-encoded) for a memory document."""
        key = self._key("embeddings", namespace, memory_id)
        await self._r.set(key, json.dumps(embedding))

    async def get_embedding(
        self, memory_id: str, namespace: str = "global"
    ) -> list[float] | None:
        """Fetch a stored embedding vector, or None when missing."""
        data = await self._r.get(self._key("embeddings", namespace, memory_id))
        return json.loads(data) if data else None

    async def search_by_similarity(
        self,
        query_embedding: list[float],
        namespace: str = "global",
        top_k: int = 5,
    ) -> list[tuple[str, float]]:
        """Brute-force cosine similarity search over stored embeddings.

        Returns up to *top_k* ``(memory_id, score)`` pairs, best first.

        For production, swap this with Redis Vector Search (RediSearch)
        or a dedicated vector DB.
        """
        # Known key prefix for this namespace; also used to recover the
        # memory id from each scanned key.
        base = self._key("embeddings", namespace) + ":"
        results: list[tuple[str, float]] = []

        async for key in self._r.scan_iter(match=base + "*", count=100):
            key_str = key.decode() if isinstance(key, bytes) else key
            # Strip the known prefix rather than rsplit-ing on ":" so
            # memory ids that themselves contain a colon survive intact.
            memory_id = key_str[len(base):]
            data = await self._r.get(key)
            if not data:  # key expired/deleted between SCAN and GET
                continue
            score = self._cosine_similarity(query_embedding, json.loads(data))
            results.append((memory_id, score))

        results.sort(key=lambda x: x[1], reverse=True)
        return results[:top_k]

    @staticmethod
    def _cosine_similarity(a: list[float], b: list[float]) -> float:
        """Cosine similarity of two equal-length vectors.

        Returns 0.0 for empty, length-mismatched, or zero-magnitude
        inputs instead of raising, so one malformed stored embedding
        cannot break a whole search.
        """
        if len(a) != len(b) or not a:
            return 0.0
        dot = sum(x * y for x, y in zip(a, b))
        mag_a = sum(x * x for x in a) ** 0.5
        mag_b = sum(x * x for x in b) ** 0.5
        if mag_a == 0 or mag_b == 0:
            return 0.0
        return dot / (mag_a * mag_b)
|