# agenticSystem/src/memory/store.py
"""Persistent memory store backed by Redis.
Supports three memory tiers:
1. Rules/documents — persistent, always loaded
2. Artifact summaries — per-session, loaded on demand
3. Optional embeddings — for semantic search
"""
from __future__ import annotations
import json
import logging
from typing import Any
import redis.asyncio as redis
from ..config import settings
from ..models.artifacts import ArtifactSummary
from ..models.context import MemoryDocument, MemoryType
logger = logging.getLogger(__name__)
class MemoryStore:
"""Async memory store with Redis backend."""
def __init__(self, redis_client: redis.Redis) -> None:
self._r = redis_client
self._prefix = settings.redis_key_prefix
# ------------------------------------------------------------------
# Key helpers
# ------------------------------------------------------------------
def _key(self, *parts: str) -> str:
return ":".join([self._prefix, *parts])
# ------------------------------------------------------------------
# Rules & documents (persistent memory)
# ------------------------------------------------------------------
async def store_document(self, doc: MemoryDocument) -> None:
key = self._key("memory", doc.namespace, doc.memory_id)
await self._r.set(key, doc.model_dump_json())
# Index by namespace
await self._r.sadd(self._key("memory", doc.namespace, "_index"), doc.memory_id)
# Index by type
await self._r.sadd(
self._key("memory", "_type", doc.memory_type.value), doc.memory_id
)
# Index by tags
for tag in doc.tags:
await self._r.sadd(self._key("memory", "_tag", tag), doc.memory_id)
async def get_document(
self, memory_id: str, namespace: str = "global"
) -> MemoryDocument | None:
key = self._key("memory", namespace, memory_id)
data = await self._r.get(key)
if data:
return MemoryDocument.model_validate_json(data)
return None
async def list_documents(
self,
namespace: str = "global",
memory_type: MemoryType | None = None,
tags: list[str] | None = None,
) -> list[MemoryDocument]:
"""List documents with optional type/tag filters."""
# Start with namespace index
ids = await self._r.smembers(self._key("memory", namespace, "_index"))
id_set = {mid.decode() if isinstance(mid, bytes) else mid for mid in ids}
# Intersect with type filter
if memory_type:
type_ids = await self._r.smembers(
self._key("memory", "_type", memory_type.value)
)
type_set = {mid.decode() if isinstance(mid, bytes) else mid for mid in type_ids}
id_set &= type_set
# Intersect with tag filter
if tags:
for tag in tags:
tag_ids = await self._r.smembers(self._key("memory", "_tag", tag))
tag_set = {
mid.decode() if isinstance(mid, bytes) else mid for mid in tag_ids
}
id_set &= tag_set
docs: list[MemoryDocument] = []
for mid in id_set:
doc = await self.get_document(mid, namespace)
if doc:
docs.append(doc)
return docs
async def delete_document(
self, memory_id: str, namespace: str = "global"
) -> bool:
key = self._key("memory", namespace, memory_id)
deleted = await self._r.delete(key)
await self._r.srem(self._key("memory", namespace, "_index"), memory_id)
return bool(deleted)
# ------------------------------------------------------------------
# Artifact summaries (per-session)
# ------------------------------------------------------------------
async def store_artifact(
self, session_id: str, artifact: ArtifactSummary
) -> None:
key = self._key("session", session_id, "artifacts")
await self._r.hset(key, artifact.artifact_id, artifact.model_dump_json())
await self._r.expire(key, settings.session_ttl_seconds)
async def get_artifact(
self, session_id: str, artifact_id: str
) -> ArtifactSummary | None:
key = self._key("session", session_id, "artifacts")
data = await self._r.hget(key, artifact_id)
if data:
return ArtifactSummary.model_validate_json(data)
return None
async def list_artifacts(self, session_id: str) -> list[ArtifactSummary]:
key = self._key("session", session_id, "artifacts")
all_data = await self._r.hgetall(key)
return [
ArtifactSummary.model_validate_json(v) for v in all_data.values()
]
# ------------------------------------------------------------------
# Optional embeddings
# ------------------------------------------------------------------
async def store_embedding(
self, memory_id: str, embedding: list[float], namespace: str = "global"
) -> None:
"""Store an embedding vector for a memory document."""
key = self._key("embeddings", namespace, memory_id)
await self._r.set(key, json.dumps(embedding))
async def get_embedding(
self, memory_id: str, namespace: str = "global"
) -> list[float] | None:
key = self._key("embeddings", namespace, memory_id)
data = await self._r.get(key)
if data:
return json.loads(data)
return None
async def search_by_similarity(
self,
query_embedding: list[float],
namespace: str = "global",
top_k: int = 5,
) -> list[tuple[str, float]]:
"""Brute-force cosine similarity search over stored embeddings.
For production, swap this with Redis Vector Search (RediSearch)
or a dedicated vector DB.
"""
pattern = self._key("embeddings", namespace, "*")
results: list[tuple[str, float]] = []
async for key in self._r.scan_iter(match=pattern, count=100):
key_str = key.decode() if isinstance(key, bytes) else key
memory_id = key_str.rsplit(":", 1)[-1]
data = await self._r.get(key)
if not data:
continue
stored = json.loads(data)
score = self._cosine_similarity(query_embedding, stored)
results.append((memory_id, score))
results.sort(key=lambda x: x[1], reverse=True)
return results[:top_k]
@staticmethod
def _cosine_similarity(a: list[float], b: list[float]) -> float:
if len(a) != len(b) or not a:
return 0.0
dot = sum(x * y for x, y in zip(a, b))
mag_a = sum(x * x for x in a) ** 0.5
mag_b = sum(x * x for x in b) ** 0.5
if mag_a == 0 or mag_b == 0:
return 0.0
return dot / (mag_a * mag_b)