Initial commit
This commit is contained in:
5
src/adapters/__init__.py
Normal file
5
src/adapters/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Adapter package public API.

Re-exports the provider-agnostic interface types and the concrete
provider adapters so callers can write ``from src.adapters import
ClaudeAdapter``.  ``ModelConfig`` is part of every adapter method
signature, so it is exported here as well.
"""

from .base import ModelAdapter, ModelConfig, ModelResponse, StreamChunk
from .claude_adapter import ClaudeAdapter
from .openai_adapter import OpenAIAdapter

__all__ = [
    "ModelAdapter",
    "ModelConfig",
    "ModelResponse",
    "StreamChunk",
    "ClaudeAdapter",
    "OpenAIAdapter",
]
|
||||
BIN
src/adapters/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
src/adapters/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
src/adapters/__pycache__/base.cpython-312.pyc
Normal file
BIN
src/adapters/__pycache__/base.cpython-312.pyc
Normal file
Binary file not shown.
BIN
src/adapters/__pycache__/claude_adapter.cpython-312.pyc
Normal file
BIN
src/adapters/__pycache__/claude_adapter.cpython-312.pyc
Normal file
Binary file not shown.
BIN
src/adapters/__pycache__/openai_adapter.cpython-312.pyc
Normal file
BIN
src/adapters/__pycache__/openai_adapter.cpython-312.pyc
Normal file
Binary file not shown.
73
src/adapters/base.py
Normal file
73
src/adapters/base.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""Model adapter interface — extensible for any LLM provider."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, AsyncIterator
|
||||
|
||||
|
||||
@dataclass
class StreamChunk:
    """A single chunk from a streaming model response.

    Adapters typically populate one aspect per chunk: text in ``delta``,
    tool-call data in the ``tool_*`` fields, or end-of-message data in
    ``finish_reason``/``usage``.
    """

    delta: str = ""  # incremental text content, if any
    tool_call_id: str = ""  # provider-assigned id of the tool call this chunk belongs to
    tool_name: str = ""  # name of the tool being invoked
    tool_arguments: str = ""  # partial (or, on finish, accumulated) JSON argument string
    finish_reason: str = ""  # e.g. "tool_use" or "end_turn"; empty while mid-stream
    usage: dict[str, int] = field(default_factory=dict)  # token counts, usually only on final chunks
|
||||
|
||||
|
||||
@dataclass
class ModelResponse:
    """Complete (non-streaming) model response."""

    # Concatenated text content of the response.
    content: str = ""
    # One dict per requested tool call: {"id": ..., "name": ..., "arguments": ...}.
    tool_calls: list[dict[str, Any]] = field(default_factory=list)
    # Provider-reported stop reason; "" when the provider supplied none.
    finish_reason: str = ""
    # Token counts, e.g. {"input_tokens": ..., "output_tokens": ...}.
    usage: dict[str, int] = field(default_factory=dict)
    # Raw provider response object, for callers needing provider-specific data.
    raw: Any = None
|
||||
|
||||
|
||||
@dataclass
class ModelConfig:
    """Per-call configuration."""

    # Provider model identifier; adapters fall back to a default when empty.
    model_id: str = ""
    # Maximum number of tokens to generate.
    max_tokens: int = 4096
    # Sampling temperature.
    temperature: float = 0.3
    # Sequences at which generation should stop.
    # NOTE(review): not consumed by the adapters in this package yet — confirm intended use.
    stop_sequences: list[str] = field(default_factory=list)
    # Provider-specific extras not covered by the fields above.
    # NOTE(review): also not consumed by the visible adapters yet.
    extra: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class ModelAdapter(ABC):
    """Abstract interface for LLM providers.

    Implementors must provide both streaming and non-streaming methods.
    """

    @abstractmethod
    async def stream(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        config: ModelConfig | None = None,
    ) -> AsyncIterator[StreamChunk]:
        """Stream model response chunks.

        Args:
            messages: Chat messages as ``{"role": ..., "content": ...}`` dicts.
            tools: Optional tool definitions in the internal format.
            config: Optional per-call overrides; implementations pick defaults
                when ``None``.

        Yields:
            StreamChunk: Incremental pieces of the model response.
        """
        ...

    @abstractmethod
    async def complete(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        config: ModelConfig | None = None,
    ) -> ModelResponse:
        """Get a complete model response (non-streaming).

        Args:
            messages: Chat messages as ``{"role": ..., "content": ...}`` dicts.
            tools: Optional tool definitions in the internal format.
            config: Optional per-call overrides; implementations pick defaults
                when ``None``.

        Returns:
            ModelResponse: The full response, including any tool calls.
        """
        ...

    @abstractmethod
    async def count_tokens(self, text: str) -> int:
        """Estimate token count for the given text."""
        ...
|
||||
201
src/adapters/claude_adapter.py
Normal file
201
src/adapters/claude_adapter.py
Normal file
@@ -0,0 +1,201 @@
|
||||
"""Claude/Anthropic model adapter with full streaming support."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, AsyncIterator
|
||||
|
||||
import anthropic
|
||||
|
||||
from ..config import settings
|
||||
from .base import ModelAdapter, ModelConfig, ModelResponse, StreamChunk
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ClaudeAdapter(ModelAdapter):
    """Adapter for the Anthropic Claude API.

    Converts the adapter-level message/tool structures into Anthropic
    ``messages`` API calls and normalizes the results into
    :class:`StreamChunk` / :class:`ModelResponse`.
    """

    def __init__(self, api_key: str | None = None) -> None:
        """Create the async Anthropic client.

        Args:
            api_key: Explicit API key; falls back to
                ``settings.anthropic_api_key``.
        """
        self._client = anthropic.AsyncAnthropic(
            api_key=api_key or settings.anthropic_api_key,
        )

    # ------------------------------------------------------------------
    # Streaming
    # ------------------------------------------------------------------

    async def stream(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        config: ModelConfig | None = None,
    ) -> AsyncIterator[StreamChunk]:
        """Stream model response chunks.

        Yields text deltas as they arrive.  For tool calls, yields one
        chunk announcing the call (id + name), chunks carrying partial
        JSON arguments, and a final chunk with the full accumulated
        arguments and ``finish_reason="tool_use"``.
        """
        kwargs = self._build_kwargs(messages, tools, config)

        async with self._client.messages.stream(**kwargs) as stream:
            # State for the tool-use content block currently streaming.
            current_tool_id = ""
            current_tool_name = ""
            accumulated_args = ""

            async for event in stream:
                if event.type == "content_block_start":
                    block = event.content_block
                    if block.type == "tool_use":
                        current_tool_id = block.id
                        current_tool_name = block.name
                        accumulated_args = ""
                        # Announce the tool call before any arguments arrive.
                        yield StreamChunk(
                            tool_call_id=current_tool_id,
                            tool_name=current_tool_name,
                        )
                    continue

                if event.type == "content_block_delta":
                    delta = event.delta
                    if delta.type == "text_delta":
                        yield StreamChunk(delta=delta.text)
                    elif delta.type == "input_json_delta":
                        accumulated_args += delta.partial_json
                        yield StreamChunk(
                            tool_call_id=current_tool_id,
                            tool_name=current_tool_name,
                            tool_arguments=delta.partial_json,
                        )
                    continue

                if event.type == "content_block_stop":
                    if current_tool_id and accumulated_args:
                        # Re-emit the complete argument string so consumers
                        # that ignore partial deltas still get the full call.
                        yield StreamChunk(
                            tool_call_id=current_tool_id,
                            tool_name=current_tool_name,
                            tool_arguments=accumulated_args,
                            finish_reason="tool_use",
                        )
                    current_tool_id = ""
                    current_tool_name = ""
                    accumulated_args = ""
                    continue

                if event.type == "message_delta":
                    yield StreamChunk(
                        finish_reason=event.delta.stop_reason or "",
                        usage={
                            "output_tokens": getattr(
                                event.usage, "output_tokens", 0
                            )
                        },
                    )

    # ------------------------------------------------------------------
    # Non-streaming
    # ------------------------------------------------------------------

    async def complete(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        config: ModelConfig | None = None,
    ) -> ModelResponse:
        """Get a complete model response (non-streaming)."""
        kwargs = self._build_kwargs(messages, tools, config)
        response = await self._client.messages.create(**kwargs)

        # Flatten the content blocks into text + tool calls.
        content = ""
        tool_calls: list[dict[str, Any]] = []
        for block in response.content:
            if block.type == "text":
                content += block.text
            elif block.type == "tool_use":
                tool_calls.append(
                    {
                        "id": block.id,
                        "name": block.name,
                        "arguments": block.input,
                    }
                )

        return ModelResponse(
            content=content,
            tool_calls=tool_calls,
            finish_reason=response.stop_reason or "",
            usage={
                "input_tokens": response.usage.input_tokens,
                "output_tokens": response.usage.output_tokens,
            },
            raw=response,
        )

    # ------------------------------------------------------------------
    # Token counting
    # ------------------------------------------------------------------

    async def count_tokens(self, text: str) -> int:
        """Estimate token count for the given text (heuristic, not exact)."""
        # Local import to avoid a module-level import cycle with context.
        from ..context.compactor import estimate_tokens

        return estimate_tokens(text)

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _build_kwargs(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None,
        config: ModelConfig | None,
    ) -> dict[str, Any]:
        """Build the request kwargs shared by ``stream`` and ``complete``."""
        config = config or ModelConfig(
            model_id=settings.default_model_id,
            max_tokens=settings.max_tokens,
            temperature=settings.temperature,
        )
        system_content, api_messages = self._split_system(messages)

        kwargs: dict[str, Any] = {
            "model": config.model_id or settings.default_model_id,
            "max_tokens": config.max_tokens,
            "temperature": config.temperature,
            "messages": api_messages,
        }
        if system_content:
            kwargs["system"] = system_content
        if tools:
            kwargs["tools"] = self._format_tools(tools)
        return kwargs

    @staticmethod
    def _split_system(
        messages: list[dict[str, Any]],
    ) -> tuple[str, list[dict[str, Any]]]:
        """Separate system messages from the conversation.

        Anthropic takes the system prompt as a top-level parameter.  If
        several system messages are present, they are joined with blank
        lines instead of silently keeping only the last one (bug fix:
        the previous code overwrote earlier system messages).
        """
        system_parts: list[str] = []
        api_messages: list[dict[str, Any]] = []
        for m in messages:
            if m["role"] == "system":
                system_parts.append(m["content"])
            else:
                api_messages.append(m)
        return "\n\n".join(system_parts), api_messages

    @staticmethod
    def _format_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Convert internal tool definitions to Anthropic tool format."""
        return [
            {
                "name": tool["name"],
                "description": tool.get("description", ""),
                "input_schema": tool.get(
                    "input_schema", tool.get("parameters", {"type": "object"})
                ),
            }
            for tool in tools
        ]
|
||||
197
src/adapters/openai_adapter.py
Normal file
197
src/adapters/openai_adapter.py
Normal file
@@ -0,0 +1,197 @@
|
||||
"""OpenAI model adapter with full streaming support."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, AsyncIterator
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from ..config import settings
|
||||
from .base import ModelAdapter, ModelConfig, ModelResponse, StreamChunk
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenAIAdapter(ModelAdapter):
    """Adapter for the OpenAI API (GPT-4o, o1, etc.).

    Normalizes chat-completions responses (streaming and one-shot) into
    the adapter-level :class:`StreamChunk` / :class:`ModelResponse`.
    """

    def __init__(self, api_key: str | None = None) -> None:
        """Create the async OpenAI client.

        Args:
            api_key: Explicit API key; falls back to
                ``settings.openai_api_key``.
        """
        self._client = AsyncOpenAI(
            api_key=api_key or settings.openai_api_key,
        )

    # ------------------------------------------------------------------
    # Streaming
    # ------------------------------------------------------------------

    async def stream(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        config: ModelConfig | None = None,
    ) -> AsyncIterator[StreamChunk]:
        """Stream model response chunks.

        Yields text deltas, incremental tool-call chunks, a finish chunk,
        and a trailing usage-only chunk carrying real token counts.
        """
        kwargs = self._build_kwargs(messages, tools, config)
        kwargs["stream"] = True
        # Bug fix: without stream_options the API sends no token usage for
        # streamed responses, so reported usage was always 0.
        kwargs["stream_options"] = {"include_usage": True}

        stream = await self._client.chat.completions.create(**kwargs)

        # OpenAI splits each tool call's id/name/arguments across many
        # chunks; accumulate fragments per tool-call index.
        tool_calls_acc: dict[int, dict[str, str]] = {}

        async for chunk in stream:
            choice = chunk.choices[0] if chunk.choices else None
            if not choice:
                # With include_usage, the final chunk has no choices and
                # carries only the usage totals.
                if chunk.usage:
                    yield StreamChunk(
                        usage={
                            "input_tokens": chunk.usage.prompt_tokens,
                            "output_tokens": chunk.usage.completion_tokens,
                        },
                    )
                continue

            delta = choice.delta

            # Text content
            if delta and delta.content:
                yield StreamChunk(delta=delta.content)

            # Tool calls
            if delta and delta.tool_calls:
                for tc in delta.tool_calls:
                    acc = tool_calls_acc.setdefault(
                        tc.index, {"id": "", "name": "", "arguments": ""}
                    )
                    if tc.id:
                        acc["id"] = tc.id
                    if tc.function and tc.function.name:
                        acc["name"] = tc.function.name
                        # Announce the call as soon as its name is known.
                        yield StreamChunk(
                            tool_call_id=acc["id"],
                            tool_name=tc.function.name,
                        )
                    if tc.function and tc.function.arguments:
                        acc["arguments"] += tc.function.arguments
                        yield StreamChunk(
                            tool_call_id=acc["id"],
                            tool_name=acc["name"],
                            tool_arguments=tc.function.arguments,
                        )

            # Finish
            if choice.finish_reason:
                if choice.finish_reason == "tool_calls":
                    # Re-emit each completed call with its full argument
                    # string, mirroring the Claude adapter's behavior.
                    for acc in tool_calls_acc.values():
                        yield StreamChunk(
                            tool_call_id=acc["id"],
                            tool_name=acc["name"],
                            tool_arguments=acc["arguments"],
                            finish_reason="tool_use",
                        )
                else:
                    # Normalize "stop" to the adapter-level "end_turn".
                    yield StreamChunk(
                        finish_reason="end_turn"
                        if choice.finish_reason == "stop"
                        else choice.finish_reason,
                        usage={
                            "output_tokens": chunk.usage.completion_tokens
                            if chunk.usage
                            else 0
                        },
                    )

    # ------------------------------------------------------------------
    # Non-streaming
    # ------------------------------------------------------------------

    async def complete(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        config: ModelConfig | None = None,
    ) -> ModelResponse:
        """Get a complete model response (non-streaming)."""
        kwargs = self._build_kwargs(messages, tools, config)
        response = await self._client.chat.completions.create(**kwargs)
        choice = response.choices[0]

        content = choice.message.content or ""
        tool_calls: list[dict[str, Any]] = []

        if choice.message.tool_calls:
            for tc in choice.message.tool_calls:
                tool_calls.append(
                    {
                        "id": tc.id,
                        "name": tc.function.name,
                        "arguments": self._parse_arguments(tc.function.arguments),
                    }
                )

        return ModelResponse(
            content=content,
            tool_calls=tool_calls,
            finish_reason=choice.finish_reason or "",
            usage={
                "input_tokens": response.usage.prompt_tokens if response.usage else 0,
                "output_tokens": response.usage.completion_tokens if response.usage else 0,
            },
            raw=response,
        )

    # ------------------------------------------------------------------
    # Token counting
    # ------------------------------------------------------------------

    async def count_tokens(self, text: str) -> int:
        """Estimate token count for the given text (heuristic, not exact)."""
        # Local import to avoid a module-level import cycle with context.
        from ..context.compactor import estimate_tokens

        return estimate_tokens(text)

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _build_kwargs(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None,
        config: ModelConfig | None,
    ) -> dict[str, Any]:
        """Build the request kwargs shared by ``stream`` and ``complete``."""
        config = config or ModelConfig(
            model_id=settings.default_model_id,
            max_tokens=settings.max_tokens,
            temperature=settings.temperature,
        )
        kwargs: dict[str, Any] = {
            "model": config.model_id or "gpt-4o",
            "max_tokens": config.max_tokens,
            "temperature": config.temperature,
            "messages": messages,
        }
        if tools:
            kwargs["tools"] = self._format_tools(tools)
        return kwargs

    @staticmethod
    def _parse_arguments(arguments: str | None) -> dict[str, Any]:
        """Decode a tool call's JSON argument string, tolerating bad JSON.

        Robustness: a truncated or malformed argument string from the API
        should not crash the whole completion; it is logged and treated as
        an empty argument dict instead.
        """
        if not arguments:
            return {}
        try:
            return json.loads(arguments)
        except json.JSONDecodeError:
            logger.warning("Malformed tool arguments: %r", arguments)
            return {}

    @staticmethod
    def _format_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Convert internal tool definitions to OpenAI function calling format."""
        return [
            {
                "type": "function",
                "function": {
                    "name": tool["name"],
                    "description": tool.get("description", ""),
                    "parameters": tool.get(
                        "input_schema", tool.get("parameters", {"type": "object"})
                    ),
                },
            }
            for tool in tools
        ]
|
||||
Reference in New Issue
Block a user