378 lines
16 KiB
Python
378 lines
16 KiB
Python
"""Claude/Anthropic model adapter with full streaming support."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
from typing import Any, AsyncIterator
|
|
|
|
import anthropic
|
|
|
|
from ..config import settings
|
|
from .base import ModelAdapter, ModelConfig, ModelResponse, StreamChunk
|
|
|
|
logger = logging.getLogger(__name__)


# Transient errors from the model proxy (MiniMax/Anthropic). We retry
# with exponential backoff: 1s, 3s, 9s. 529 is Anthropic's overloaded_error;
# 429 is rate-limit; 503 is service unavailable.
_TRANSIENT_STATUSES = {429, 503, 529}
_RETRY_DELAYS = (1.0, 3.0, 9.0)
|
|
|
|
|
|
def _is_transient(exc: Exception) -> bool:
    """Return True when ``exc`` is safe to retry (overload or transient network failure).

    Connection and timeout errors always qualify. An ``APIStatusError``
    qualifies when its status code is listed in ``_TRANSIENT_STATUSES`` or when
    its message mentions an overload condition. Anything else is not retried.
    """
    if isinstance(exc, (anthropic.APIConnectionError, anthropic.APITimeoutError)):
        return True
    if not isinstance(exc, anthropic.APIStatusError):
        return False
    if getattr(exc, "status_code", None) in _TRANSIENT_STATUSES:
        return True
    # Some overload responses are only recognisable by their message text.
    message = str(exc).lower()
    return any(marker in message for marker in ("overloaded", "high load"))
|
|
|
|
|
|
class ClaudeAdapter(ModelAdapter):
    """Adapter for the Anthropic Claude API.

    Accepts OpenAI-style chat messages, converts them to the Anthropic Messages
    format and offers a streaming (``stream``) and a non-streaming
    (``complete``) entry point. Both retry errors that ``_is_transient``
    classifies as retryable, waiting ``_RETRY_DELAYS`` between attempts.
    """

    def __init__(self, api_key: str | None = None, base_url: str | None = None) -> None:
        """Create the underlying ``anthropic.AsyncAnthropic`` client.

        Args:
            api_key: API key; falls back to ``settings.anthropic_api_key``.
            base_url: Optional endpoint override; falls back to
                ``settings.anthropic_base_url``. When both are empty the
                argument is not passed to the client at all.
        """
        kwargs: dict[str, Any] = {
            "api_key": api_key or settings.anthropic_api_key,
        }
        url = base_url or settings.anthropic_base_url
        if url:
            kwargs["base_url"] = url
        self._client = anthropic.AsyncAnthropic(**kwargs)

    # ------------------------------------------------------------------
    # Request building (shared by stream() and complete())
    # ------------------------------------------------------------------

    @staticmethod
    def _resolve_config(config: ModelConfig | None) -> ModelConfig:
        """Return ``config`` or a default ``ModelConfig`` built from ``settings``."""
        return config or ModelConfig(
            model_id=settings.default_model_id,
            max_tokens=settings.max_tokens,
            temperature=settings.temperature,
        )

    def _build_request(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None,
        config: ModelConfig,
    ) -> dict[str, Any]:
        """Build the keyword arguments common to both Messages API calls.

        System messages are removed from the conversation and sent through the
        top-level ``system`` parameter; if several are present the last one
        wins. The remaining messages are converted to Claude format.
        """
        system_content = ""
        chat_messages: list[dict[str, Any]] = []
        for m in messages:
            if m["role"] == "system":
                system_content = m["content"]
            else:
                chat_messages.append(m)

        kwargs: dict[str, Any] = {
            "model": config.model_id or settings.default_model_id,
            "max_tokens": config.max_tokens,
            "temperature": config.temperature,
            "messages": self._convert_messages(chat_messages),
        }
        if system_content:
            kwargs["system"] = system_content
        if tools:
            kwargs["tools"] = self._format_tools(tools)
        return kwargs

    # ------------------------------------------------------------------
    # Streaming
    # ------------------------------------------------------------------

    async def stream(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        config: ModelConfig | None = None,
    ) -> AsyncIterator[StreamChunk]:
        """Stream a response as ``StreamChunk`` objects.

        Chunks produced, in the order the API events arrive:

        - a chunk with ``tool_call_id``/``tool_name`` when a tool_use block starts;
        - a chunk with ``delta`` for every text delta;
        - a chunk with the partial ``tool_arguments`` for every JSON delta;
        - when a tool_use block that received arguments ends, a chunk with the
          full accumulated ``tool_arguments`` and ``finish_reason="tool_use"``;
        - on ``message_delta``, a chunk with the stop reason and token usage.

        Args:
            messages: OpenAI-style chat messages (system messages included).
            tools: Optional tool definitions, see ``_format_tools``.
            config: Model settings; defaults are taken from ``settings``.

        Raises:
            Exception: Whatever the SDK raised, once the error is not
                transient, the attempts are exhausted, or a chunk has already
                been delivered to the consumer.
        """
        kwargs = self._build_request(messages, tools, self._resolve_config(config))

        # Retry with backoff on transient errors. Once a chunk has been handed
        # to the consumer we can NOT retry: the partial content has already
        # travelled downstream and a second attempt would duplicate it.
        attempt = 0
        max_attempts = len(_RETRY_DELAYS) + 1
        while True:
            # Set right before each ``yield`` (not on every received event), so
            # a failure after ``message_start`` but before any output is still
            # retried: the consumer has seen nothing at that point.
            yielded_any = False
            try:
                async with self._client.messages.stream(**kwargs) as stream:
                    current_tool_id = ""
                    current_tool_name = ""
                    accumulated_args = ""
                    input_tokens = 0

                    async for event in stream:
                        if event.type == "message_start" and hasattr(event, "message"):
                            usage = getattr(event.message, "usage", None)
                            if usage:
                                input_tokens = getattr(usage, "input_tokens", 0)

                        if event.type == "content_block_start":
                            block = event.content_block
                            if block.type == "tool_use":
                                current_tool_id = block.id
                                current_tool_name = block.name
                                accumulated_args = ""
                                yielded_any = True
                                yield StreamChunk(
                                    tool_call_id=current_tool_id,
                                    tool_name=current_tool_name,
                                )
                            continue

                        if event.type == "content_block_delta":
                            delta = event.delta
                            if delta.type == "text_delta":
                                yielded_any = True
                                yield StreamChunk(delta=delta.text)
                            elif delta.type == "input_json_delta":
                                accumulated_args += delta.partial_json
                                yielded_any = True
                                yield StreamChunk(
                                    tool_call_id=current_tool_id,
                                    tool_name=current_tool_name,
                                    tool_arguments=delta.partial_json,
                                )
                            continue

                        if event.type == "content_block_stop":
                            if current_tool_id and accumulated_args:
                                yielded_any = True
                                yield StreamChunk(
                                    tool_call_id=current_tool_id,
                                    tool_name=current_tool_name,
                                    tool_arguments=accumulated_args,
                                    finish_reason="tool_use",
                                )
                            current_tool_id = ""
                            current_tool_name = ""
                            accumulated_args = ""
                            continue

                        if event.type == "message_delta":
                            output_tokens = getattr(event.usage, "output_tokens", 0) if event.usage else 0
                            yielded_any = True
                            yield StreamChunk(
                                finish_reason=event.delta.stop_reason or "",
                                usage={
                                    "input_tokens": input_tokens,
                                    "output_tokens": output_tokens,
                                },
                            )
                return  # stream fully consumed, leave the retry loop
            except Exception as e:
                if yielded_any or not _is_transient(e) or attempt >= max_attempts - 1:
                    raise
                wait = _RETRY_DELAYS[attempt]
                logger.warning(
                    "Claude stream() transient error (attempt %d/%d), retrying in %.1fs: %s",
                    attempt + 1, max_attempts, wait, str(e)[:200],
                )
                await asyncio.sleep(wait)
                attempt += 1

    # ------------------------------------------------------------------
    # Non-streaming
    # ------------------------------------------------------------------

    async def complete(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        config: ModelConfig | None = None,
    ) -> ModelResponse:
        """Run a single non-streaming completion.

        Args:
            messages: OpenAI-style chat messages (system messages included).
            tools: Optional tool definitions, see ``_format_tools``.
            config: Model settings; defaults are taken from ``settings``. When
                ``config.extra["force_tool"]`` is set, the request carries a
                ``tool_choice`` naming that tool.

        Returns:
            A ``ModelResponse`` with the concatenated text blocks, the tool
            calls as ``{"id", "name", "arguments"}`` dicts, the stop reason,
            token usage and the raw SDK response.

        Raises:
            Exception: Whatever the SDK raised, once the error is not
                transient or the attempts are exhausted.
        """
        config = self._resolve_config(config)
        kwargs = self._build_request(messages, tools, config)

        # Force the model to use one specific tool to guarantee schema-shaped
        # JSON (used by /completions with json_schema). See OpenAIAdapter for
        # the equivalent variant.
        force_tool = (config.extra or {}).get("force_tool")
        if force_tool:
            kwargs["tool_choice"] = {"type": "tool", "name": force_tool}

        # Retry with backoff on transient errors (429/503/529). The MiniMax
        # proxy returns 529 overloaded_error fairly often under load.
        max_attempts = len(_RETRY_DELAYS) + 1
        for attempt in range(max_attempts):
            try:
                response = await self._client.messages.create(**kwargs)
                break
            except Exception as e:
                # The last attempt always re-raises, so the loop can only be
                # left through ``break`` (success) or an exception.
                if not _is_transient(e) or attempt == max_attempts - 1:
                    raise
                wait = _RETRY_DELAYS[attempt]
                logger.warning(
                    "Claude complete() transient error (attempt %d/%d), retrying in %.1fs: %s",
                    attempt + 1, max_attempts, wait, str(e)[:200],
                )
                await asyncio.sleep(wait)

        text_parts: list[str] = []
        tool_calls: list[dict[str, Any]] = []
        for block in response.content:
            if block.type == "text":
                text_parts.append(block.text)
            elif block.type == "tool_use":
                tool_calls.append(
                    {
                        "id": block.id,
                        "name": block.name,
                        "arguments": block.input,
                    }
                )

        return ModelResponse(
            content="".join(text_parts),
            tool_calls=tool_calls,
            finish_reason=response.stop_reason or "",
            usage={
                "input_tokens": response.usage.input_tokens,
                "output_tokens": response.usage.output_tokens,
            },
            raw=response,
        )

    # ------------------------------------------------------------------
    # Token counting
    # ------------------------------------------------------------------

    async def count_tokens(self, text: str) -> int:
        """Return the token estimate for ``text``.

        Delegates to ``context.compactor.estimate_tokens``; the import is done
        at call time rather than at module import time.
        """
        from ..context.compactor import estimate_tokens

        return estimate_tokens(text)

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _convert_messages(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Convert OpenAI-format messages to Claude format.

        - role=tool → role=user with a tool_result content block
        - assistant with tool_calls → assistant with tool_use content blocks
        - Consecutive same-role messages get merged (Claude requires alternating)
        - Messages with any other role are passed through unchanged

        Two strings merge into one string joined by a newline; any other
        combination merges into a list of content blocks. Empty content adds
        nothing to a merge. The input messages and their content lists are
        never modified: list content is copied before it is merged into.
        """
        converted: list[dict[str, Any]] = []

        def as_blocks(content: Any) -> list[Any]:
            """Return ``content`` as a fresh list of content blocks."""
            if isinstance(content, list):
                return list(content)
            if isinstance(content, str) and content:
                return [{"type": "text", "text": content}]
            return []

        def add(role: str, content: Any) -> None:
            """Append ``content`` under ``role``, merging into the previous message if its role matches."""
            if not converted or converted[-1].get("role") != role:
                # Copy list content so later merges never touch the caller's
                # message history (it is typically reused on the next turn).
                owned = list(content) if isinstance(content, list) else content
                converted.append({"role": role, "content": owned})
                return
            if not content:
                return
            target = converted[-1]
            previous = target["content"]
            if isinstance(previous, str) and isinstance(content, str):
                target["content"] = previous + "\n" + content
            else:
                target["content"] = as_blocks(previous) + as_blocks(content)

        for m in messages:
            role = m.get("role", "")
            # ``tool_calls`` may be present but None or empty (plain text
            # reply); such messages are handled as ordinary assistant messages.
            tool_calls = m.get("tool_calls")

            if role == "tool":
                result: dict[str, Any] = {
                    "type": "tool_result",
                    "tool_use_id": m.get("tool_call_id", ""),
                    "content": m.get("content", ""),
                }
                if m.get("is_error"):
                    result["is_error"] = True
                add("user", [result])

            elif role == "assistant" and tool_calls:
                blocks = as_blocks(m.get("content"))
                for tc in tool_calls:
                    # OpenAI nests name/arguments under "function"; fall back
                    # to the flat {"id", "name", "arguments"} shape that
                    # complete() returns.
                    func = tc.get("function") or tc
                    raw_args = func.get("arguments", "{}")
                    try:
                        args = json.loads(raw_args) if isinstance(raw_args, str) else raw_args
                    except (json.JSONDecodeError, TypeError):
                        args = {}
                    if not isinstance(args, dict):
                        # tool_use input must be a JSON object.
                        args = {}
                    blocks.append({
                        "type": "tool_use",
                        "id": tc.get("id", ""),
                        "name": func.get("name", ""),
                        "input": args,
                    })
                add("assistant", blocks)

            elif role in ("assistant", "user"):
                add(role, m.get("content", ""))

            else:
                converted.append(m)

        return converted

    @staticmethod
    def _format_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Convert internal tool definitions to Anthropic tool format.

        Each tool must have a ``name``. ``description`` defaults to an empty
        string. The schema is read from ``input_schema``, then ``parameters``,
        and defaults to ``{"type": "object"}``.
        """
        return [
            {
                "name": tool["name"],
                "description": tool.get("description", ""),
                "input_schema": tool.get("input_schema", tool.get("parameters", {"type": "object"})),
            }
            for tool in tools
        ]
|