"""OpenAI model adapter with full streaming support.""" from __future__ import annotations import json import logging from typing import Any, AsyncIterator from openai import AsyncOpenAI from ..config import settings from .base import ModelAdapter, ModelConfig, ModelResponse, StreamChunk logger = logging.getLogger(__name__) class OpenAIAdapter(ModelAdapter): """Adapter for the OpenAI API (GPT-4o, o1, etc.).""" def __init__(self, api_key: str | None = None, base_url: str | None = None) -> None: kwargs: dict[str, Any] = { "api_key": api_key or settings.openai_api_key, } url = base_url or settings.openai_base_url if url: kwargs["base_url"] = url self._client = AsyncOpenAI(**kwargs) # ------------------------------------------------------------------ # Streaming # ------------------------------------------------------------------ async def stream( self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None, config: ModelConfig | None = None, ) -> AsyncIterator[StreamChunk]: config = config or ModelConfig( model_id=settings.default_model_id, max_tokens=settings.max_tokens, temperature=settings.temperature, ) kwargs: dict[str, Any] = { "model": config.model_id or "gpt-4o", "max_tokens": config.max_tokens, "temperature": config.temperature, "messages": messages, "stream": True, "stream_options": {"include_usage": True}, } if tools: kwargs["tools"] = self._format_tools(tools) stream = await self._client.chat.completions.create(**kwargs) tool_calls_acc: dict[int, dict[str, str]] = {} final_usage: dict[str, int] = {} async for chunk in stream: # With include_usage, the last chunk has usage but no choices if chunk.usage: final_usage = { "input_tokens": chunk.usage.prompt_tokens or 0, "output_tokens": chunk.usage.completion_tokens or 0, } choice = chunk.choices[0] if chunk.choices else None if not choice: # Usage-only chunk (last one with include_usage) — emit it if final_usage: yield StreamChunk(usage=final_usage) final_usage = {} # Only emit once continue delta = choice.delta # Text content if delta and delta.content: yield StreamChunk(delta=delta.content) # Tool calls if delta and delta.tool_calls: for tc in delta.tool_calls: idx = tc.index if idx not in tool_calls_acc: tool_calls_acc[idx] = { "id": tc.id or "", "name": "", "arguments": "", } if tc.id: tool_calls_acc[idx]["id"] = tc.id if tc.function and tc.function.name: tool_calls_acc[idx]["name"] = tc.function.name yield StreamChunk( tool_call_id=tc.id or tool_calls_acc[idx]["id"], tool_name=tc.function.name, ) if tc.function and tc.function.arguments: tool_calls_acc[idx]["arguments"] += tc.function.arguments yield StreamChunk( tool_call_id=tool_calls_acc[idx]["id"], tool_name=tool_calls_acc[idx]["name"], tool_arguments=tc.function.arguments, ) # Finish if choice.finish_reason: if choice.finish_reason == "tool_calls": for acc in tool_calls_acc.values(): yield StreamChunk( tool_call_id=acc["id"], tool_name=acc["name"], tool_arguments=acc["arguments"], finish_reason="tool_use", ) # Emit usage after tool_use chunks if final_usage: yield StreamChunk(usage=final_usage) else: yield StreamChunk( finish_reason="end_turn" if choice.finish_reason == "stop" else choice.finish_reason, usage=final_usage, ) # ------------------------------------------------------------------ # Non-streaming # ------------------------------------------------------------------ async def complete( self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None, config: ModelConfig | None = None, ) -> ModelResponse: config = config or ModelConfig( model_id=settings.default_model_id, max_tokens=settings.max_tokens, temperature=settings.temperature, ) kwargs: dict[str, Any] = { "model": config.model_id or "gpt-4o", "max_tokens": config.max_tokens, "temperature": config.temperature, "messages": messages, } if tools: kwargs["tools"] = self._format_tools(tools) response = await self._client.chat.completions.create(**kwargs) choice = response.choices[0] content = choice.message.content or "" tool_calls: list[dict[str, Any]] = [] if choice.message.tool_calls: for tc in choice.message.tool_calls: tool_calls.append( { "id": tc.id, "name": tc.function.name, "arguments": json.loads(tc.function.arguments) if tc.function.arguments else {}, } ) return ModelResponse( content=content, tool_calls=tool_calls, finish_reason=choice.finish_reason or "", usage={ "input_tokens": response.usage.prompt_tokens if response.usage else 0, "output_tokens": response.usage.completion_tokens if response.usage else 0, }, raw=response, ) # ------------------------------------------------------------------ # Token counting # ------------------------------------------------------------------ async def count_tokens(self, text: str) -> int: from ..context.compactor import estimate_tokens return estimate_tokens(text) # ------------------------------------------------------------------ # Helpers # ------------------------------------------------------------------ @staticmethod def _format_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]: """Convert internal tool definitions to OpenAI function calling format.""" formatted: list[dict[str, Any]] = [] for tool in tools: formatted.append( { "type": "function", "function": { "name": tool["name"], "description": tool.get("description", ""), "parameters": tool.get( "input_schema", tool.get("parameters", {"type": "object"}) ), }, } ) return formatted