"""Context compaction: extract facts, remove redundancy, maintain constraints.
|
|
|
|
The compactor is responsible for keeping the context within token budget
|
|
while preserving the most important information.
|
|
"""

from __future__ import annotations

import hashlib
import json
import logging
import re
from typing import Any

from ..models.artifacts import ArtifactSummary
from ..models.context import ContextSection, ContextSectionType

logger = logging.getLogger(__name__)

# --- Token counting with tiktoken ---
try:
    import tiktoken
    _encoding = tiktoken.get_encoding("cl100k_base")  # GPT-4-family encoding; approximate for other models

    def estimate_tokens(text: str) -> int:
        """Accurate token count using tiktoken (cl100k_base encoding)."""
        if not text:
            return 0
        return len(_encoding.encode(text, disallowed_special=()))

    logger.info("Using tiktoken for accurate token counting")
except ImportError:
    def estimate_tokens(text: str) -> int:
        """Fallback: ~4 chars per token."""
        return max(1, len(text) // 4)

    logger.warning("tiktoken not installed — using approximate token counting")


class ContextCompactor:
    """Compacts context sections to fit within token budgets."""

    def __init__(self, max_tokens: int = 120_000) -> None:
        self.max_tokens = max_tokens
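
    # Illustrative usage sketch (not executed here; assumes `sections` and
    # `messages` are built by the caller, since their constructors live in
    # ..models.context and the chat layer respectively):
    #     compactor = ContextCompactor(max_tokens=120_000)
    #     kept_sections, meta = compactor.compact_sections(sections)
    #     kept_messages, conv_meta = compactor.compact_conversation(messages, max_tokens=8_000)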

    # ------------------------------------------------------------------
    # Public
    # ------------------------------------------------------------------

    def compact_sections(
        self,
        sections: list[ContextSection],
        max_tokens: int | None = None,
    ) -> tuple[list[ContextSection], dict[str, Any]]:
        """Remove redundancy and trim low-priority sections to fit budget."""
        budget = max_tokens if max_tokens is not None else self.max_tokens
        original_count = len(sections)

        # 1. Deduplicate identical content across sections
        sections = self._deduplicate(sections)
        duplicates_removed = original_count - len(sections)

        # 2. Estimate tokens per section
        for s in sections:
            s.token_estimate = estimate_tokens(s.content)

        total = sum(s.token_estimate for s in sections)
        meta = {
            "budget_tokens": budget,
            "input_tokens": total,
            "output_tokens": total,
            "sections_input": original_count,
            "sections_output": len(sections),
            "duplicates_removed": duplicates_removed,
            "sections_compacted": 0,
            "sections_removed": 0,
        }
        if total <= budget:
            return sections, meta

        # 3. Sort by priority (highest first) — immutable_rules never trimmed
        sections.sort(key=lambda s: s.priority, reverse=True)

        # 4. Progressively trim lowest-priority sections
        while total > budget and sections:
            lowest = sections[-1]
            if lowest.section_type == ContextSectionType.IMMUTABLE_RULES:
                break  # Never trim rules
            # Try to compact the section first
            compacted = self._compact_text(lowest.content)
            new_estimate = estimate_tokens(compacted)
            saved = lowest.token_estimate - new_estimate
            if saved > 0:
                lowest.content = compacted
                lowest.token_estimate = new_estimate
                total -= saved
                meta["sections_compacted"] += 1
            else:
                # Remove the section entirely
                total -= lowest.token_estimate
                sections.pop()
                meta["sections_removed"] += 1

        meta["output_tokens"] = total
        meta["sections_output"] = len(sections)
        return sections, meta

    def summarize_tool_output(
        self,
        tool_name: str,
        raw_output: str,
        session_id: str,
        task_id: str,
    ) -> ArtifactSummary:
        """Summarise raw tool output into an ArtifactSummary.

        The raw output is NEVER passed through to the model context.
        """
        facts = self._extract_facts(raw_output)
        summary = self._build_summary(tool_name, raw_output, facts)

        artifact_type = self._infer_artifact_type(tool_name)
        artifact_id = hashlib.sha256(
            f"{session_id}:{task_id}:{tool_name}:{raw_output[:200]}".encode()
        ).hexdigest()[:16]

        return ArtifactSummary(
            artifact_id=artifact_id,
            session_id=session_id,
            task_id=task_id,
            artifact_type=artifact_type,
            title=f"Output of {tool_name}",
            summary=summary,
            facts=facts,
            source_tool=tool_name,
            char_count=len(raw_output),
        )

    def compact_artifact_summaries(
        self, summaries: list[ArtifactSummary], max_chars: int = 2000
    ) -> str:
        """Merge multiple artifact summaries into a single compact block."""
        if not summaries:
            return ""

        lines: list[str] = ["## Artifacts"]
        budget = max_chars - 20
        for art in summaries:
            entry = f"- [{art.artifact_type}] {art.title}: {art.summary}"
            if art.facts:
                entry += " | Facts: " + "; ".join(art.facts[:3])
            if len(entry) > budget:
                entry = entry[:budget] + "…"
            lines.append(entry)
            budget -= len(entry)
            if budget <= 0:
                lines.append(f" … and {len(summaries) - len(lines) + 1} more artifacts")
                break
        return "\n".join(lines)

    def compact_conversation(
        self,
        messages: list[dict[str, Any]],
        max_tokens: int,
        recent_raw_limit: int = 2,
        raw_char_limit: int = 2000,
    ) -> tuple[list[dict[str, Any]], dict[str, Any]]:
        """Compact conversation history while preserving the latest user turn."""
        total = sum(self._estimate_message_tokens(m) for m in messages)
        meta = {
            "budget_tokens": max_tokens,
            "input_tokens": total,
            "output_tokens": total,
            "messages_input": len(messages),
            "messages_output": len(messages),
            "messages_compacted": 0,
            "tool_messages_compacted": 0,
            "assistant_messages_compacted": 0,
            "user_messages_compacted": 0,
            "raw_tool_results_kept": 0,
        }
        if total <= max_tokens:
            return messages, meta

        compacted = [dict(m) for m in messages]
        last_user_idx = max(
            (i for i, m in enumerate(compacted) if m.get("role") == "user"),
            default=-1,
        )
        tool_indexes = [i for i, m in enumerate(compacted) if m.get("role") == "tool"]
        keep_raw_tool_indexes = (
            set(tool_indexes[-recent_raw_limit:])
            if recent_raw_limit > 0
            else set()
        )

        for idx in keep_raw_tool_indexes:
            content = compacted[idx].get("content", "")
            if isinstance(content, str) and content:
                truncated = content[:raw_char_limit]
                if truncated != content:
                    compacted[idx]["content"] = truncated
                    meta["messages_compacted"] += 1
                    meta["tool_messages_compacted"] += 1
                meta["raw_tool_results_kept"] += 1

        total = sum(self._estimate_message_tokens(m) for m in compacted)
        if total > max_tokens:
            for idx in tool_indexes:
                if idx in keep_raw_tool_indexes:
                    continue
                content = compacted[idx].get("content", "")
                if not isinstance(content, str) or not content:
                    continue
                compacted[idx]["content"] = self._summarize_message_content(
                    content,
                    prefix="[TOOL RESULT COMPACTED]",
                    max_chars=max(180, raw_char_limit // 4),
                )
                meta["messages_compacted"] += 1
                meta["tool_messages_compacted"] += 1
                total = sum(self._estimate_message_tokens(m) for m in compacted)
                if total <= max_tokens:
                    break

        if total > max_tokens:
            for idx, message in enumerate(compacted):
                if idx == last_user_idx or message.get("role") != "assistant":
                    continue
                content = message.get("content", "")
                if not isinstance(content, str) or not content:
                    continue
                message["content"] = self._summarize_message_content(
                    content,
                    prefix="[ASSISTANT COMPACTED]",
                    max_chars=max(240, raw_char_limit // 3),
                )
                meta["messages_compacted"] += 1
                meta["assistant_messages_compacted"] += 1
                total = sum(self._estimate_message_tokens(m) for m in compacted)
                if total <= max_tokens:
                    break

        if total > max_tokens:
            for idx, message in enumerate(compacted):
                if idx == last_user_idx or message.get("role") != "user":
                    continue
                content = message.get("content", "")
                if not isinstance(content, str) or not content:
                    continue
                message["content"] = self._summarize_message_content(
                    content,
                    prefix="[USER CONTEXT COMPACTED]",
                    max_chars=max(220, raw_char_limit // 3),
                )
                meta["messages_compacted"] += 1
                meta["user_messages_compacted"] += 1
                total = sum(self._estimate_message_tokens(m) for m in compacted)
                if total <= max_tokens:
                    break

        if total > max_tokens:
            for idx in tool_indexes:
                if idx in keep_raw_tool_indexes:
                    content = compacted[idx].get("content", "")
                    if not isinstance(content, str) or not content:
                        continue
                    compacted[idx]["content"] = self._summarize_message_content(
                        content,
                        prefix="[TOOL RESULT COMPACTED]",
                        max_chars=max(180, raw_char_limit // 5),
                    )
                    total = sum(self._estimate_message_tokens(m) for m in compacted)
                    if total <= max_tokens:
                        break

        if total > max_tokens:
            for idx, message in enumerate(compacted):
                if idx == last_user_idx:
                    continue
                role = message.get("role", "")
                content = message.get("content", "")
                if not isinstance(content, str) or not content:
                    continue
                if role == "tool":
                    message["content"] = "[TOOL RESULT COMPACTED]"
                elif role == "assistant":
                    message["content"] = "[ASSISTANT COMPACTED]"
                elif role == "user":
                    message["content"] = "[USER CONTEXT COMPACTED]"
                total = sum(self._estimate_message_tokens(m) for m in compacted)
                if total <= max_tokens:
                    break

        meta["output_tokens"] = total
        return compacted, meta

    # ------------------------------------------------------------------
    # Internals
    # ------------------------------------------------------------------

    def _deduplicate(
        self, sections: list[ContextSection]
    ) -> list[ContextSection]:
        """Drop sections whose content duplicates an earlier section."""
        seen: set[str] = set()
        unique: list[ContextSection] = []
        for s in sections:
            h = hashlib.md5(s.content.encode()).hexdigest()
            if h not in seen:
                seen.add(h)
                unique.append(s)
        return unique

    def _compact_text(self, text: str) -> str:
        """Aggressively compact text: remove blank lines, collapse whitespace."""
        lines = text.splitlines()
        # Remove empty or whitespace-only lines
        lines = [l.rstrip() for l in lines if l.strip()]
        # Collapse consecutive duplicate lines
        compacted: list[str] = []
        for line in lines:
            if not compacted or line != compacted[-1]:
                compacted.append(line)
        return "\n".join(compacted)

    def _summarize_message_content(
        self,
        content: str,
        prefix: str,
        max_chars: int,
    ) -> str:
        """Reduce message content to a short, prefix-tagged summary."""
        stripped = content.strip()
        compacted = self._compact_text(content)
        if len(compacted) <= max_chars:
            if compacted != stripped:
                summary = f"{prefix} {compacted}".strip()
                if len(summary) > max_chars:
                    summary = summary[:max_chars].rstrip() + "…"
                return summary
            return compacted

        lines = [l.strip() for l in compacted.splitlines() if l.strip()]
        if not lines:
            return prefix
        if len(lines) == 1:
            return f"{prefix} {lines[0][:max_chars]}".strip()

        first = lines[0][: max_chars // 2]
        last = lines[-1][: max_chars // 3]
        summary = f"{prefix} First: {first}"
        if last and last != first:
            summary += f" | Last: {last}"
        if len(summary) > max_chars:
            summary = summary[:max_chars].rstrip() + "…"
        return summary

    @staticmethod
    def _estimate_message_tokens(message: dict[str, Any]) -> int:
        """Estimate tokens for a chat message, including serialized tool calls."""
        content = message.get("content", "")
        tokens = estimate_tokens(content if isinstance(content, str) else str(content))
        if message.get("tool_calls"):
            tokens += estimate_tokens(
                json.dumps(message.get("tool_calls", []), ensure_ascii=False)
            )
        return tokens

    def _extract_facts(self, raw_output: str) -> list[str]:
        """Extract short factual claims from tool output."""
        facts: list[str] = []
        lines = raw_output.strip().splitlines()

        for line in lines[:100]:  # Limit scan depth
            line = line.strip()
            if not line or len(line) < 10:
                continue
            # Lines that look like key-value facts
            if re.match(r"^[\w\s]+:\s+.+", line) and len(line) < 200:
                facts.append(line)
            # Lines starting with status indicators
            elif re.match(r"^(✓|✗|PASS|FAIL|ERROR|OK|INFO|WARNING)", line):
                facts.append(line)
            # Lines that contain file paths with results
            elif re.match(r"^[\w/\\.]+\s*[:\-]\s*.+", line) and len(line) < 200:
                facts.append(line)

        # Deduplicate and limit
        seen: set[str] = set()
        unique: list[str] = []
        for f in facts:
            if f not in seen:
                seen.add(f)
                unique.append(f)
        return unique[:15]

    def _build_summary(
        self, tool_name: str, raw_output: str, facts: list[str]
    ) -> str:
        """Build a concise summary from tool output."""
        lines = raw_output.strip().splitlines()
        total_lines = len(lines)
        char_count = len(raw_output)

        parts = [f"Tool '{tool_name}' returned {total_lines} lines ({char_count} chars)."]

        if facts:
            parts.append(f"Key findings: {'; '.join(facts[:5])}")

        # Include first and last meaningful lines as bookends
        meaningful = [l.strip() for l in lines if l.strip()]
        if meaningful:
            parts.append(f"First: {meaningful[0][:120]}")
            if len(meaningful) > 1:
                parts.append(f"Last: {meaningful[-1][:120]}")

        return " ".join(parts)

    def _infer_artifact_type(self, tool_name: str) -> str:
        """Map a tool name to a coarse artifact type via keyword matching."""
        tool_lower = tool_name.lower()
        if any(k in tool_lower for k in ("read", "file", "code", "write", "edit")):
            return "code"
        if any(k in tool_lower for k in ("test", "check", "lint", "validate")):
            return "test_result"
        if any(k in tool_lower for k in ("search", "find", "grep", "glob")):
            return "analysis"
        if any(k in tool_lower for k in ("plan", "design", "architect")):
            return "plan"
        return "general"