Initial commit
This commit is contained in:
291
src/mcp/client.py
Normal file
291
src/mcp/client.py
Normal file
@@ -0,0 +1,291 @@
|
||||
"""MCP (Model Context Protocol) client — stdio transport.
|
||||
|
||||
Manages subprocess lifecycle, JSON-RPC request/response, timeouts,
|
||||
and a tool registry populated from the server's capabilities.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from ..config import settings
|
||||
from ..models.tools import ToolDefinition
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MCPClientError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class MCPClient:
|
||||
"""Stdio-based MCP client with full lifecycle management."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
command: str | None = None,
|
||||
args: list[str] | None = None,
|
||||
timeout: float | None = None,
|
||||
startup_timeout: float | None = None,
|
||||
env: dict[str, str] | None = None,
|
||||
) -> None:
|
||||
self._command = command or settings.mcp_server_command
|
||||
self._args = args if args is not None else list(settings.mcp_server_args)
|
||||
self._timeout = timeout or settings.mcp_timeout_seconds
|
||||
self._startup_timeout = startup_timeout or settings.mcp_startup_timeout_seconds
|
||||
# Inherit current env + any overrides (passes ACAI_* vars to MCP server)
|
||||
self._env = {**os.environ, **(env or {})}
|
||||
self._process: asyncio.subprocess.Process | None = None
|
||||
self._tools: dict[str, ToolDefinition] = {}
|
||||
self._pending: dict[str, asyncio.Future[dict[str, Any]]] = {}
|
||||
self._reader_task: asyncio.Task[None] | None = None
|
||||
self._running = False
|
||||
|
||||
@property
|
||||
def tools(self) -> dict[str, ToolDefinition]:
|
||||
return dict(self._tools)
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
return self._running and self._process is not None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Lifecycle
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the MCP server subprocess and discover tools."""
|
||||
if not self._command:
|
||||
logger.warning("No MCP server command configured — skipping start")
|
||||
return
|
||||
|
||||
logger.info("Starting MCP server: %s %s", self._command, self._args)
|
||||
self._process = await asyncio.create_subprocess_exec(
|
||||
self._command,
|
||||
*self._args,
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
env=self._env,
|
||||
)
|
||||
self._running = True
|
||||
self._reader_task = asyncio.create_task(self._read_loop())
|
||||
|
||||
# Initialize
|
||||
try:
|
||||
init_result = await asyncio.wait_for(
|
||||
self._send_request("initialize", {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {},
|
||||
"clientInfo": {"name": "agentic-microservice", "version": "1.0.0"},
|
||||
}),
|
||||
timeout=self._startup_timeout,
|
||||
)
|
||||
logger.info("MCP initialized: %s", init_result)
|
||||
|
||||
# Send initialized notification
|
||||
await self._send_notification("notifications/initialized", {})
|
||||
|
||||
# Discover tools
|
||||
tools_result = await asyncio.wait_for(
|
||||
self._send_request("tools/list", {}),
|
||||
timeout=self._startup_timeout,
|
||||
)
|
||||
self._register_tools(tools_result)
|
||||
logger.info("Discovered %d MCP tools", len(self._tools))
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("MCP server startup timed out")
|
||||
await self.stop()
|
||||
raise MCPClientError("MCP server startup timed out")
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Gracefully stop the MCP server."""
|
||||
self._running = False
|
||||
if self._reader_task:
|
||||
self._reader_task.cancel()
|
||||
try:
|
||||
await self._reader_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
if self._process:
|
||||
try:
|
||||
if self._process.stdin:
|
||||
self._process.stdin.close()
|
||||
self._process.terminate()
|
||||
await asyncio.wait_for(self._process.wait(), timeout=5.0)
|
||||
except (asyncio.TimeoutError, ProcessLookupError):
|
||||
self._process.kill()
|
||||
self._process = None
|
||||
|
||||
# Cancel any pending requests
|
||||
for fut in self._pending.values():
|
||||
if not fut.done():
|
||||
fut.cancel()
|
||||
self._pending.clear()
|
||||
self._tools.clear()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Tool execution
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def call_tool(
|
||||
self, tool_name: str, arguments: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""Call a tool on the MCP server with timeout."""
|
||||
if not self.is_running:
|
||||
raise MCPClientError("MCP client is not running")
|
||||
|
||||
if tool_name not in self._tools:
|
||||
raise MCPClientError(f"Unknown tool: {tool_name}")
|
||||
|
||||
try:
|
||||
result = await asyncio.wait_for(
|
||||
self._send_request("tools/call", {
|
||||
"name": tool_name,
|
||||
"arguments": arguments,
|
||||
}),
|
||||
timeout=self._timeout,
|
||||
)
|
||||
return result
|
||||
except asyncio.TimeoutError:
|
||||
raise MCPClientError(
|
||||
f"Tool '{tool_name}' timed out after {self._timeout}s"
|
||||
)
|
||||
|
||||
def get_tool_definitions(self) -> list[dict[str, Any]]:
|
||||
"""Return tool definitions in a format suitable for model adapters."""
|
||||
definitions: list[dict[str, Any]] = []
|
||||
for tool in self._tools.values():
|
||||
definitions.append({
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"input_schema": tool.input_schema,
|
||||
})
|
||||
return definitions
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# JSON-RPC transport
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _send_request(
|
||||
self, method: str, params: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""Send a JSON-RPC request and await the response."""
|
||||
request_id = uuid.uuid4().hex[:12]
|
||||
message = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
"method": method,
|
||||
"params": params,
|
||||
}
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
future: asyncio.Future[dict[str, Any]] = loop.create_future()
|
||||
self._pending[request_id] = future
|
||||
|
||||
await self._write_message(message)
|
||||
|
||||
try:
|
||||
return await future
|
||||
finally:
|
||||
self._pending.pop(request_id, None)
|
||||
|
||||
async def _send_notification(
|
||||
self, method: str, params: dict[str, Any]
|
||||
) -> None:
|
||||
"""Send a JSON-RPC notification (no response expected)."""
|
||||
message = {
|
||||
"jsonrpc": "2.0",
|
||||
"method": method,
|
||||
"params": params,
|
||||
}
|
||||
await self._write_message(message)
|
||||
|
||||
async def _write_message(self, message: dict[str, Any]) -> None:
|
||||
"""Write a JSON-RPC message to the server's stdin."""
|
||||
if not self._process or not self._process.stdin:
|
||||
raise MCPClientError("MCP process stdin not available")
|
||||
|
||||
data = json.dumps(message) + "\n"
|
||||
self._process.stdin.write(data.encode())
|
||||
await self._process.stdin.drain()
|
||||
|
||||
async def _read_loop(self) -> None:
|
||||
"""Continuously read JSON-RPC responses from stdout."""
|
||||
if not self._process or not self._process.stdout:
|
||||
return
|
||||
|
||||
try:
|
||||
while self._running:
|
||||
line = await self._process.stdout.readline()
|
||||
if not line:
|
||||
logger.warning("MCP server stdout closed")
|
||||
break
|
||||
|
||||
line_str = line.decode().strip()
|
||||
if not line_str:
|
||||
continue
|
||||
|
||||
try:
|
||||
message = json.loads(line_str)
|
||||
except json.JSONDecodeError:
|
||||
logger.debug("Non-JSON MCP output: %s", line_str[:200])
|
||||
continue
|
||||
|
||||
self._handle_message(message)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception:
|
||||
logger.exception("MCP read loop error")
|
||||
finally:
|
||||
self._running = False
|
||||
|
||||
def _handle_message(self, message: dict[str, Any]) -> None:
|
||||
"""Route an incoming JSON-RPC message."""
|
||||
msg_id = message.get("id")
|
||||
|
||||
if msg_id and msg_id in self._pending:
|
||||
future = self._pending[msg_id]
|
||||
if future.done():
|
||||
return
|
||||
|
||||
if "error" in message:
|
||||
future.set_exception(
|
||||
MCPClientError(
|
||||
f"MCP error {message['error'].get('code')}: "
|
||||
f"{message['error'].get('message')}"
|
||||
)
|
||||
)
|
||||
else:
|
||||
future.set_result(message.get("result", {}))
|
||||
elif "method" in message:
|
||||
# Server-initiated notification — log it
|
||||
logger.debug(
|
||||
"MCP notification: %s", message.get("method")
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Tool registry
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _register_tools(self, tools_result: dict[str, Any]) -> None:
|
||||
"""Parse tools/list response and populate the registry."""
|
||||
raw_tools = tools_result.get("tools", [])
|
||||
for t in raw_tools:
|
||||
name = t.get("name", "")
|
||||
if not name:
|
||||
continue
|
||||
self._tools[name] = ToolDefinition(
|
||||
name=name,
|
||||
description=t.get("description", ""),
|
||||
input_schema=t.get("inputSchema", {}),
|
||||
server_name="mcp",
|
||||
)
|
||||
Reference in New Issue
Block a user