Initial commit
This commit is contained in:
197
src/adapters/openai_adapter.py
Normal file
197
src/adapters/openai_adapter.py
Normal file
@@ -0,0 +1,197 @@
|
||||
"""OpenAI model adapter with full streaming support."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, AsyncIterator
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from ..config import settings
|
||||
from .base import ModelAdapter, ModelConfig, ModelResponse, StreamChunk
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenAIAdapter(ModelAdapter):
    """Adapter for the OpenAI API (GPT-4o, o1, etc.)."""

    def __init__(self, api_key: str | None = None) -> None:
        """Create an adapter.

        Args:
            api_key: Explicit API key. Falls back to
                ``settings.openai_api_key`` when omitted.
        """
        self._client = AsyncOpenAI(
            api_key=api_key or settings.openai_api_key,
        )

    # ------------------------------------------------------------------
    # Request construction (shared by stream() and complete())
    # ------------------------------------------------------------------

    def _build_request_kwargs(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None,
        config: ModelConfig | None,
        *,
        stream: bool,
    ) -> dict[str, Any]:
        """Assemble keyword arguments for ``chat.completions.create``.

        Factors out the config-defaulting and kwargs assembly that was
        previously duplicated verbatim in :meth:`stream` and
        :meth:`complete`.

        Args:
            messages: Chat messages in OpenAI wire format.
            tools: Optional internal tool definitions (converted via
                :meth:`_format_tools`).
            config: Per-call overrides; ``None`` means use ``settings``
                defaults.
            stream: Whether to request a streaming response.

        Returns:
            The kwargs dict to splat into the client call.
        """
        config = config or ModelConfig(
            model_id=settings.default_model_id,
            max_tokens=settings.max_tokens,
            temperature=settings.temperature,
        )
        kwargs: dict[str, Any] = {
            "model": config.model_id or "gpt-4o",
            "max_tokens": config.max_tokens,
            "temperature": config.temperature,
            "messages": messages,
        }
        if stream:
            kwargs["stream"] = True
        if tools:
            kwargs["tools"] = self._format_tools(tools)
        return kwargs

    # ------------------------------------------------------------------
    # Streaming
    # ------------------------------------------------------------------

    async def stream(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        config: ModelConfig | None = None,
    ) -> AsyncIterator[StreamChunk]:
        """Stream a chat completion as :class:`StreamChunk` items.

        Yields text deltas as they arrive, incremental tool-call
        fragments, and a terminal chunk carrying ``finish_reason``.
        Partial tool calls are accumulated so the final chunks (on
        ``finish_reason == "tool_calls"``) carry the complete id, name
        and concatenated argument string for each call.

        Args:
            messages: Chat messages in OpenAI wire format.
            tools: Optional internal tool definitions.
            config: Per-call model overrides; defaults from ``settings``.
        """
        kwargs = self._build_request_kwargs(messages, tools, config, stream=True)

        # Renamed from ``stream`` — the original local shadowed the method name.
        response_stream = await self._client.chat.completions.create(**kwargs)

        # Partial tool calls accumulated by the provider-supplied index.
        tool_calls_acc: dict[int, dict[str, str]] = {}

        async for chunk in response_stream:
            choice = chunk.choices[0] if chunk.choices else None
            if not choice:
                continue

            delta = choice.delta

            # Text content.
            if delta and delta.content:
                yield StreamChunk(delta=delta.content)

            # Tool-call fragments: ids/names usually arrive once, argument
            # JSON arrives in pieces that must be concatenated.
            if delta and delta.tool_calls:
                for tc in delta.tool_calls:
                    acc = tool_calls_acc.setdefault(
                        tc.index, {"id": "", "name": "", "arguments": ""}
                    )
                    if tc.id:
                        acc["id"] = tc.id
                    if tc.function and tc.function.name:
                        acc["name"] = tc.function.name
                        yield StreamChunk(
                            tool_call_id=tc.id or acc["id"],
                            tool_name=tc.function.name,
                        )
                    if tc.function and tc.function.arguments:
                        acc["arguments"] += tc.function.arguments
                        yield StreamChunk(
                            tool_call_id=acc["id"],
                            tool_name=acc["name"],
                            tool_arguments=tc.function.arguments,
                        )

            # Terminal chunk for this choice.
            if choice.finish_reason:
                if choice.finish_reason == "tool_calls":
                    # Emit completed calls in index order; dict insertion
                    # order is not guaranteed to match the provider's
                    # tool-call indices if fragments interleave.
                    for _idx, acc in sorted(tool_calls_acc.items()):
                        yield StreamChunk(
                            tool_call_id=acc["id"],
                            tool_name=acc["name"],
                            tool_arguments=acc["arguments"],
                            finish_reason="tool_use",
                        )
                else:
                    # NOTE(review): ``chunk.usage`` is typically only
                    # populated when the request sets
                    # ``stream_options={"include_usage": True}`` — confirm
                    # callers expect 0 here otherwise.
                    yield StreamChunk(
                        finish_reason="end_turn"
                        if choice.finish_reason == "stop"
                        else choice.finish_reason,
                        usage={
                            "output_tokens": chunk.usage.completion_tokens
                            if chunk.usage
                            else 0
                        },
                    )

    # ------------------------------------------------------------------
    # Non-streaming
    # ------------------------------------------------------------------

    async def complete(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        config: ModelConfig | None = None,
    ) -> ModelResponse:
        """Run a single non-streaming chat completion.

        Args:
            messages: Chat messages in OpenAI wire format.
            tools: Optional internal tool definitions.
            config: Per-call model overrides; defaults from ``settings``.

        Returns:
            A :class:`ModelResponse` with text content, parsed tool calls,
            finish reason, token usage, and the raw SDK response.
        """
        kwargs = self._build_request_kwargs(messages, tools, config, stream=False)

        response = await self._client.chat.completions.create(**kwargs)
        choice = response.choices[0]

        content = choice.message.content or ""
        tool_calls: list[dict[str, Any]] = [
            {
                "id": tc.id,
                "name": tc.function.name,
                # Tolerant parse: models occasionally emit truncated or
                # invalid JSON; don't let that raise inside the adapter.
                "arguments": self._parse_arguments(tc.function.arguments),
            }
            for tc in (choice.message.tool_calls or [])
        ]

        return ModelResponse(
            content=content,
            tool_calls=tool_calls,
            finish_reason=choice.finish_reason or "",
            usage={
                "input_tokens": response.usage.prompt_tokens if response.usage else 0,
                "output_tokens": response.usage.completion_tokens if response.usage else 0,
            },
            raw=response,
        )

    # ------------------------------------------------------------------
    # Token counting
    # ------------------------------------------------------------------

    async def count_tokens(self, text: str) -> int:
        """Estimate the token count of *text*.

        Delegates to the project-wide heuristic estimator (local import
        to avoid a circular dependency at module load).
        """
        from ..context.compactor import estimate_tokens

        return estimate_tokens(text)

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _parse_arguments(raw: str | None) -> dict[str, Any]:
        """Parse a JSON tool-argument payload, tolerating malformed output.

        The original code called ``json.loads`` unguarded; a model that
        emits truncated JSON would raise ``JSONDecodeError`` deep inside
        the adapter. Returning an empty dict (and logging the payload)
        keeps the call flow alive.
        """
        if not raw:
            return {}
        try:
            return json.loads(raw)
        except json.JSONDecodeError:
            logger.warning("Malformed tool arguments from model: %r", raw)
            return {}

    @staticmethod
    def _format_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Convert internal tool definitions to OpenAI function-calling format.

        Accepts either an Anthropic-style ``input_schema`` or an
        OpenAI-style ``parameters`` key for the JSON schema, defaulting to
        an empty object schema.
        """
        return [
            {
                "type": "function",
                "function": {
                    "name": tool["name"],
                    "description": tool.get("description", ""),
                    "parameters": tool.get(
                        "input_schema", tool.get("parameters", {"type": "object"})
                    ),
                },
            }
            for tool in tools
        ]
|
||||
Reference in New Issue
Block a user