Initial commit
This commit is contained in:
190
src/memory/store.py
Normal file
190
src/memory/store.py
Normal file
@@ -0,0 +1,190 @@
|
||||
"""Persistent memory store backed by Redis.
|
||||
|
||||
Supports three memory tiers:
|
||||
1. Rules/documents — persistent, always loaded
|
||||
2. Artifact summaries — per-session, loaded on demand
|
||||
3. Optional embeddings — for semantic search
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import redis.asyncio as redis
|
||||
|
||||
from ..config import settings
|
||||
from ..models.artifacts import ArtifactSummary
|
||||
from ..models.context import MemoryDocument, MemoryType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MemoryStore:
    """Async memory store with a Redis backend.

    Keys are namespaced under ``settings.redis_key_prefix`` and organized as:

    * ``<prefix>:memory:<namespace>:<memory_id>`` — serialized document
    * ``<prefix>:memory:<namespace>:_index``     — set of ids in a namespace
    * ``<prefix>:memory:_type:<type>``           — set of ids per memory type
    * ``<prefix>:memory:_tag:<tag>``             — set of ids per tag
    * ``<prefix>:session:<sid>:artifacts``       — hash of artifact summaries
    * ``<prefix>:embeddings:<namespace>:<id>``   — JSON-encoded float vector
    """

    def __init__(self, redis_client: redis.Redis) -> None:
        self._r = redis_client
        self._prefix = settings.redis_key_prefix

    # ------------------------------------------------------------------
    # Key helpers
    # ------------------------------------------------------------------

    def _key(self, *parts: str) -> str:
        """Join *parts* under the configured prefix, colon-separated."""
        return ":".join([self._prefix, *parts])

    @staticmethod
    def _decode(value: bytes | str) -> str:
        """Normalize a Redis reply member to ``str``.

        redis-py returns ``bytes`` unless the client was created with
        ``decode_responses=True``; accept both so the store works either way.
        """
        return value.decode() if isinstance(value, bytes) else value

    # ------------------------------------------------------------------
    # Rules & documents (persistent memory)
    # ------------------------------------------------------------------

    async def store_document(self, doc: MemoryDocument) -> None:
        """Persist *doc* and register it in the namespace/type/tag indexes."""
        key = self._key("memory", doc.namespace, doc.memory_id)
        await self._r.set(key, doc.model_dump_json())
        # Index by namespace
        await self._r.sadd(self._key("memory", doc.namespace, "_index"), doc.memory_id)
        # Index by type
        await self._r.sadd(
            self._key("memory", "_type", doc.memory_type.value), doc.memory_id
        )
        # Index by tags
        for tag in doc.tags:
            await self._r.sadd(self._key("memory", "_tag", tag), doc.memory_id)

    async def get_document(
        self, memory_id: str, namespace: str = "global"
    ) -> MemoryDocument | None:
        """Fetch a single document, or ``None`` if the key is absent."""
        key = self._key("memory", namespace, memory_id)
        data = await self._r.get(key)
        if data:
            return MemoryDocument.model_validate_json(data)
        return None

    async def list_documents(
        self,
        namespace: str = "global",
        memory_type: MemoryType | None = None,
        tags: list[str] | None = None,
    ) -> list[MemoryDocument]:
        """List documents with optional type/tag filters.

        Filters are conjunctive: a document must be in *namespace*, match
        *memory_type* (if given), and carry every tag in *tags* (if given).
        """
        # Start with the namespace index, then intersect in the filters.
        ids = await self._r.smembers(self._key("memory", namespace, "_index"))
        id_set = {self._decode(mid) for mid in ids}

        # Intersect with type filter
        if memory_type:
            type_ids = await self._r.smembers(
                self._key("memory", "_type", memory_type.value)
            )
            id_set &= {self._decode(mid) for mid in type_ids}

        # Intersect with tag filter (every tag must match)
        if tags:
            for tag in tags:
                tag_ids = await self._r.smembers(self._key("memory", "_tag", tag))
                id_set &= {self._decode(mid) for mid in tag_ids}

        docs: list[MemoryDocument] = []
        for mid in id_set:
            doc = await self.get_document(mid, namespace)
            if doc:
                docs.append(doc)
        return docs

    async def delete_document(
        self, memory_id: str, namespace: str = "global"
    ) -> bool:
        """Delete a document and scrub it from ALL indexes.

        Returns ``True`` if the document key existed.

        BUG FIX: previously only the namespace ``_index`` was cleaned, so
        deleted ids accumulated forever in the ``_type`` and ``_tag`` sets.
        We load the document first to learn its type and tags, then remove
        the matching secondary-index entries.
        """
        doc = await self.get_document(memory_id, namespace)
        key = self._key("memory", namespace, memory_id)
        deleted = await self._r.delete(key)
        await self._r.srem(self._key("memory", namespace, "_index"), memory_id)
        if doc is not None:
            await self._r.srem(
                self._key("memory", "_type", doc.memory_type.value), memory_id
            )
            for tag in doc.tags:
                await self._r.srem(self._key("memory", "_tag", tag), memory_id)
        return bool(deleted)

    # ------------------------------------------------------------------
    # Artifact summaries (per-session)
    # ------------------------------------------------------------------

    async def store_artifact(
        self, session_id: str, artifact: ArtifactSummary
    ) -> None:
        """Store an artifact summary in the session hash; refresh its TTL."""
        key = self._key("session", session_id, "artifacts")
        await self._r.hset(key, artifact.artifact_id, artifact.model_dump_json())
        # Sliding expiry: every write extends the whole session's artifact hash.
        await self._r.expire(key, settings.session_ttl_seconds)

    async def get_artifact(
        self, session_id: str, artifact_id: str
    ) -> ArtifactSummary | None:
        """Fetch one artifact summary, or ``None`` if absent/expired."""
        key = self._key("session", session_id, "artifacts")
        data = await self._r.hget(key, artifact_id)
        if data:
            return ArtifactSummary.model_validate_json(data)
        return None

    async def list_artifacts(self, session_id: str) -> list[ArtifactSummary]:
        """Return all artifact summaries stored for *session_id*."""
        key = self._key("session", session_id, "artifacts")
        all_data = await self._r.hgetall(key)
        return [
            ArtifactSummary.model_validate_json(v) for v in all_data.values()
        ]

    # ------------------------------------------------------------------
    # Optional embeddings
    # ------------------------------------------------------------------

    async def store_embedding(
        self, memory_id: str, embedding: list[float], namespace: str = "global"
    ) -> None:
        """Store an embedding vector for a memory document."""
        key = self._key("embeddings", namespace, memory_id)
        await self._r.set(key, json.dumps(embedding))

    async def get_embedding(
        self, memory_id: str, namespace: str = "global"
    ) -> list[float] | None:
        """Fetch a stored embedding vector, or ``None`` if absent."""
        key = self._key("embeddings", namespace, memory_id)
        data = await self._r.get(key)
        if data:
            return json.loads(data)
        return None

    async def search_by_similarity(
        self,
        query_embedding: list[float],
        namespace: str = "global",
        top_k: int = 5,
    ) -> list[tuple[str, float]]:
        """Brute-force cosine similarity search over stored embeddings.

        Returns up to *top_k* ``(memory_id, score)`` pairs, best first.

        For production, swap this with Redis Vector Search (RediSearch)
        or a dedicated vector DB.
        """
        pattern = self._key("embeddings", namespace, "*")
        results: list[tuple[str, float]] = []

        # SCAN (not KEYS) so we never block the Redis server on large stores.
        async for key in self._r.scan_iter(match=pattern, count=100):
            key_str = self._decode(key)
            memory_id = key_str.rsplit(":", 1)[-1]
            data = await self._r.get(key)
            if not data:
                # Key may have expired between SCAN and GET; skip it.
                continue
            stored = json.loads(data)
            score = self._cosine_similarity(query_embedding, stored)
            results.append((memory_id, score))

        results.sort(key=lambda x: x[1], reverse=True)
        return results[:top_k]

    @staticmethod
    def _cosine_similarity(a: list[float], b: list[float]) -> float:
        """Cosine similarity of two equal-length vectors.

        Returns 0.0 for mismatched lengths, empty input, or a zero vector,
        so malformed stored embeddings simply rank last instead of raising.
        """
        if len(a) != len(b) or not a:
            return 0.0
        dot = sum(x * y for x, y in zip(a, b))
        mag_a = sum(x * x for x in a) ** 0.5
        mag_b = sum(x * x for x in b) ** 0.5
        if mag_a == 0 or mag_b == 0:
            return 0.0
        return dot / (mag_a * mag_b)
|
||||
Reference in New Issue
Block a user