Initial commit
This commit is contained in:
201
src/adapters/claude_adapter.py
Normal file
201
src/adapters/claude_adapter.py
Normal file
@@ -0,0 +1,201 @@
|
||||
"""Claude/Anthropic model adapter with full streaming support."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, AsyncIterator
|
||||
|
||||
import anthropic
|
||||
|
||||
from ..config import settings
|
||||
from .base import ModelAdapter, ModelConfig, ModelResponse, StreamChunk
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ClaudeAdapter(ModelAdapter):
    """Adapter for the Anthropic Claude API.

    Wraps ``anthropic.AsyncAnthropic`` and translates between the
    project's ``ModelAdapter`` contract (``StreamChunk`` /
    ``ModelResponse``) and the Anthropic Messages API, in both
    streaming and non-streaming forms.
    """

    def __init__(self, api_key: str | None = None) -> None:
        # Fall back to the configured key when none is supplied explicitly.
        self._client = anthropic.AsyncAnthropic(
            api_key=api_key or settings.anthropic_api_key,
        )

    # ------------------------------------------------------------------
    # Request preparation
    # ------------------------------------------------------------------

    def _build_request_kwargs(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None,
        config: ModelConfig | None,
    ) -> dict[str, Any]:
        """Build the Messages-API kwargs shared by ``stream`` and ``complete``.

        Applies configured defaults when *config* is ``None``, hoists any
        ``system`` message into Anthropic's top-level ``system`` parameter
        (the API rejects ``role == "system"`` chat messages), and formats
        *tools* when present.  If several system messages are present the
        last one wins, matching the original inline logic.
        """
        config = config or ModelConfig(
            model_id=settings.default_model_id,
            max_tokens=settings.max_tokens,
            temperature=settings.temperature,
        )

        system_content = ""
        api_messages: list[dict[str, Any]] = []
        for m in messages:
            if m["role"] == "system":
                system_content = m["content"]
            else:
                api_messages.append(m)

        kwargs: dict[str, Any] = {
            "model": config.model_id or settings.default_model_id,
            "max_tokens": config.max_tokens,
            "temperature": config.temperature,
            "messages": api_messages,
        }
        if system_content:
            kwargs["system"] = system_content
        if tools:
            kwargs["tools"] = self._format_tools(tools)
        return kwargs

    # ------------------------------------------------------------------
    # Streaming
    # ------------------------------------------------------------------

    async def stream(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        config: ModelConfig | None = None,
    ) -> AsyncIterator[StreamChunk]:
        """Yield ``StreamChunk``s from a streaming Anthropic request.

        Text deltas are forwarded as ``delta`` chunks.  Each tool-use
        block produces a start chunk (id + name), incremental
        ``tool_arguments`` chunks as JSON fragments arrive, and a final
        chunk carrying the accumulated arguments with
        ``finish_reason="tool_use"``.  A trailing ``message_delta``
        event yields the stop reason and output-token usage.
        """
        kwargs = self._build_request_kwargs(messages, tools, config)

        async with self._client.messages.stream(**kwargs) as stream:
            # State for the tool-use block currently being streamed.
            current_tool_id = ""
            current_tool_name = ""
            accumulated_args = ""

            async for event in stream:
                if event.type == "content_block_start":
                    block = event.content_block
                    if block.type == "tool_use":
                        current_tool_id = block.id
                        current_tool_name = block.name
                        accumulated_args = ""
                        yield StreamChunk(
                            tool_call_id=current_tool_id,
                            tool_name=current_tool_name,
                        )
                    continue

                if event.type == "content_block_delta":
                    delta = event.delta
                    if delta.type == "text_delta":
                        yield StreamChunk(delta=delta.text)
                    elif delta.type == "input_json_delta":
                        accumulated_args += delta.partial_json
                        yield StreamChunk(
                            tool_call_id=current_tool_id,
                            tool_name=current_tool_name,
                            tool_arguments=delta.partial_json,
                        )
                    continue

                if event.type == "content_block_stop":
                    # BUGFIX: the original guard also required non-empty
                    # accumulated_args, so a tool that takes no input
                    # never emitted its `finish_reason="tool_use"`
                    # completion chunk.  Guard on the id alone; reset the
                    # per-tool state only when a tool block actually ends.
                    if current_tool_id:
                        yield StreamChunk(
                            tool_call_id=current_tool_id,
                            tool_name=current_tool_name,
                            tool_arguments=accumulated_args,
                            finish_reason="tool_use",
                        )
                        current_tool_id = ""
                        current_tool_name = ""
                        accumulated_args = ""
                    continue

                if event.type == "message_delta":
                    yield StreamChunk(
                        finish_reason=event.delta.stop_reason or "",
                        usage={
                            # `usage` may not expose output_tokens on every
                            # delta; default to 0 rather than raising.
                            "output_tokens": getattr(
                                event.usage, "output_tokens", 0
                            )
                        },
                    )

    # ------------------------------------------------------------------
    # Non-streaming
    # ------------------------------------------------------------------

    async def complete(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        config: ModelConfig | None = None,
    ) -> ModelResponse:
        """Run a single non-streaming request and return a ``ModelResponse``.

        Text blocks are concatenated into ``content``; ``tool_use``
        blocks become entries in ``tool_calls`` with their id, name and
        already-parsed ``input`` dict as ``arguments``.  The raw SDK
        response is attached as ``raw`` for callers that need more.
        """
        kwargs = self._build_request_kwargs(messages, tools, config)

        response = await self._client.messages.create(**kwargs)

        content = ""
        tool_calls: list[dict[str, Any]] = []
        for block in response.content:
            if block.type == "text":
                content += block.text
            elif block.type == "tool_use":
                tool_calls.append(
                    {
                        "id": block.id,
                        "name": block.name,
                        "arguments": block.input,
                    }
                )

        return ModelResponse(
            content=content,
            tool_calls=tool_calls,
            finish_reason=response.stop_reason or "",
            usage={
                "input_tokens": response.usage.input_tokens,
                "output_tokens": response.usage.output_tokens,
            },
            raw=response,
        )

    # ------------------------------------------------------------------
    # Token counting
    # ------------------------------------------------------------------

    async def count_tokens(self, text: str) -> int:
        """Return an estimated token count for *text*.

        Delegates to the project's heuristic estimator rather than the
        API's counting endpoint; the import is local to avoid a module
        cycle with the context package.
        """
        from ..context.compactor import estimate_tokens

        return estimate_tokens(text)

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _format_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Convert internal tool definitions to Anthropic tool format.

        Accepts either an ``input_schema`` (already Anthropic-shaped) or
        an OpenAI-style ``parameters`` key, falling back to a permissive
        ``{"type": "object"}`` schema when neither is present.
        """
        formatted: list[dict[str, Any]] = []
        for tool in tools:
            formatted.append(
                {
                    "name": tool["name"],
                    "description": tool.get("description", ""),
                    "input_schema": tool.get(
                        "input_schema",
                        tool.get("parameters", {"type": "object"}),
                    ),
                }
            )
        return formatted
|
||||
Reference in New Issue
Block a user