# healthcare-api-mcp / server_http.py
# Last change: "Update server_http.py" by visproj (commit e880bf8, verified); file size 17.7 kB.
"""MCP Streamable HTTP Server for Healthcare API MCP - HuggingFace Spaces deployment.
Implements MCP Streamable HTTP transport (JSON-RPC 2.0).
Specification: https://modelcontextprotocol.io/specification/2025-03-26/basic/transports
"""
import asyncio
import inspect
import json
import logging
import types
import uuid
from datetime import datetime, timezone
from typing import Any, AsyncGenerator, Dict, Optional, Union, get_args, get_origin, get_type_hints

from fastapi import FastAPI, Header, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
from starlette.middleware.cors import CORSMiddleware

from config import MCPConfig
from core import create_http_client
from providers.clinical_guidelines_provider import ClinicalGuidelinesProvider
from providers.cms_pricing_provider import CMSPricingProvider
from providers.openfda_provider import OpenFDAProvider
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ---- Application ------------------------------------------------------------
# One FastAPI app serves both the MCP JSON-RPC endpoint and the legacy REST routes.
app = FastAPI(
    version="1.0.0",
    title="Healthcare API MCP",
    description="Model Context Protocol server for public healthcare data APIs",
)

# Wide-open CORS: every origin, method and header, with credentials allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True is
# usually unsafe for a public deployment — confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# ---- Process-wide mutable state (filled in by startup()) --------------------
providers = {}  # provider key -> provider instance
http_client = None  # shared HTTP client, created in startup()
sessions: Dict[str, Dict[str, Any]] = {}  # Mcp-Session-Id -> session metadata

# Protocol revision advertised in initialize responses and on GET /.
MCP_PROTOCOL_VERSION = "2024-11-05"
@app.on_event("startup")
async def startup():
    """Create the shared HTTP client and bring up every enabled provider.

    Providers are considered in a fixed order (openfda, clinical_guidelines,
    cms_pricing); each is constructed, registered in ``providers`` and
    initialized only when its key appears in ``config.enabled_providers``.
    """
    global http_client, providers
    config = MCPConfig.from_env()
    logger.info(f"Starting Healthcare API MCP with providers: {config.enabled_providers}")
    http_client = create_http_client()

    # (config key, log label, factory). Factories are lambdas so constructor
    # arguments are only evaluated for providers that are actually enabled.
    provider_specs = (
        ("openfda", "OpenFDA", lambda: OpenFDAProvider(http_client, config.openfda_api_key)),
        ("clinical_guidelines", "Clinical Guidelines", lambda: ClinicalGuidelinesProvider(http_client)),
        ("cms_pricing", "CMS Pricing", lambda: CMSPricingProvider(http_client)),
    )
    for key, label, factory in provider_specs:
        if key not in config.enabled_providers:
            continue
        providers[key] = factory()
        await providers[key].initialize()
        logger.info(f"✅ {label} provider initialized")

    logger.info(f"Healthcare API MCP ready with {len(providers)} provider(s)")
@app.on_event("shutdown")
async def shutdown():
    """Release provider resources and close the shared HTTP client.

    Cleanup is best-effort: a failure in one provider's ``cleanup()`` is
    logged and does not stop the remaining providers from being cleaned up,
    and the HTTP client is always closed afterwards.
    """
    global http_client, providers
    logger.info("Shutting down Healthcare API MCP...")
    try:
        for name, provider in providers.items():
            try:
                await provider.cleanup()
            except Exception:
                # Keep going: one broken provider must not leak the others'
                # resources or the shared client.
                logger.exception(f"Error cleaning up provider {name}")
    finally:
        if http_client:
            await http_client.aclose()
# JSON Schema "type" for each supported Python type, also used for generic
# origins (e.g. List[str] -> list -> "array"). Anything unlisted -> "string".
_JSON_SCHEMA_TYPES = {
    str: "string",
    int: "integer",
    float: "number",
    bool: "boolean",
    list: "array",
    dict: "object",
}

# get_origin() results that denote a union: typing.Union (what Optional[X]
# expands to) and, on Python 3.10+, the ``X | Y`` syntax (types.UnionType).
_UNION_ORIGINS = (Union, getattr(types, "UnionType", Union))


def _json_schema_type(hint: Any) -> str:
    """Map a resolved type hint to a JSON Schema type name.

    ``Optional[X]`` / ``X | None`` is unwrapped to ``X``; bare or parametrized
    ``list``/``dict`` map to "array"/"object". Unions of several non-None
    types and unrecognized hints fall back to "string".
    """
    origin = get_origin(hint)
    if origin in _UNION_ORIGINS:
        members = [arg for arg in get_args(hint) if arg is not type(None)]
        # Only a plain Optional[X] has an unambiguous JSON type.
        return _json_schema_type(members[0]) if len(members) == 1 else "string"
    try:
        return _JSON_SCHEMA_TYPES.get(origin or hint, "string")
    except TypeError:  # unhashable hint object
        return "string"


def _parse_arg_descriptions(doc: str) -> Dict[str, str]:
    """Extract ``name: description`` pairs from a docstring's ``Args:`` section.

    The section ends at ``Returns:``. A Google-style ``name (type): desc``
    entry is keyed by ``name``. Lines without a colon are ignored, so wrapped
    continuation text is dropped.
    """
    if "Args:" not in doc:
        return {}
    args_section = doc.split("Args:")[1].split("Returns:")[0]
    descriptions: Dict[str, str] = {}
    for raw_line in args_section.splitlines():
        line = raw_line.strip()
        if ":" not in line or line.startswith("return"):
            continue
        name, desc = line.split(":", 1)
        # "limit (int): ..." -> "limit"
        descriptions[name.split("(", 1)[0].strip()] = desc.strip()
    return descriptions


def build_tool_schema(tool_func) -> dict:
    """
    Build an MCP tool schema (name, description, inputSchema) from a function.

    - ``description`` is the docstring text before any ``Args:`` section.
    - Parameter types come from type hints (``Optional`` is unwrapped;
      unknown or missing hints become "string").
    - Parameters without a default value are listed under ``required``.
    - Parameter descriptions come from the docstring's ``Args:`` section.
    - ``self`` and variadic ``*args`` / ``**kwargs`` parameters are skipped,
      since they cannot be expressed as named JSON properties.

    Args:
        tool_func: The tool callable (typically a bound provider method).

    Returns:
        A dict with ``name``, ``description`` and ``inputSchema`` keys;
        ``inputSchema["required"]`` is present only when non-empty.
    """
    sig = inspect.signature(tool_func)
    try:
        hints = get_type_hints(tool_func)
    except Exception:
        # Unresolvable annotations (e.g. bad forward refs): every param -> "string".
        hints = {}

    # cleandoc removes the docstring's common indentation so multi-line
    # descriptions don't carry source-code whitespace into the schema.
    doc = inspect.cleandoc(tool_func.__doc__ or "")
    description = doc.split("Args:")[0].strip()
    param_descriptions = _parse_arg_descriptions(doc)

    properties: Dict[str, Any] = {}
    required = []
    for param_name, param in sig.parameters.items():
        if param_name == "self" or param.kind in (
            inspect.Parameter.VAR_POSITIONAL,
            inspect.Parameter.VAR_KEYWORD,
        ):
            continue

        param_type = _json_schema_type(hints[param_name]) if param_name in hints else "string"
        prop_schema = {"type": param_type}
        if param_name in param_descriptions:
            prop_schema["description"] = param_descriptions[param_name]
        properties[param_name] = prop_schema

        if param.default is inspect.Parameter.empty:
            required.append(param_name)

    schema = {
        "name": tool_func.__name__,
        "description": description,
        "inputSchema": {
            "type": "object",
            "properties": properties,
        },
    }
    if required:
        schema["inputSchema"]["required"] = required
    return schema
# ============================================================================
# JSON-RPC 2.0 Message Handlers
# ============================================================================
def create_jsonrpc_response(result: Any, request_id: Any) -> dict:
    """Wrap ``result`` in a JSON-RPC 2.0 success envelope addressed to ``request_id``."""
    envelope = dict(jsonrpc="2.0")
    envelope["result"] = result
    envelope["id"] = request_id
    return envelope
def create_jsonrpc_error(code: int, message: str, request_id: Any, data: Any = None) -> dict:
    """Build a JSON-RPC 2.0 error envelope addressed to ``request_id``.

    The error object always has ``code`` and ``message``; ``data`` is added
    only when it is not None (falsy values such as 0 or "" are kept).
    """
    extra = {} if data is None else {"data": data}
    return {
        "jsonrpc": "2.0",
        "error": {"code": code, "message": message, **extra},
        "id": request_id,
    }
async def handle_initialize(params: dict, session_id: str) -> dict:
    """Handle the ``initialize`` method (capability negotiation).

    Records the client's info and requested protocol version in ``sessions``
    under ``session_id`` and returns this server's capabilities.

    Args:
        params: ``initialize`` params. ``clientInfo`` and ``protocolVersion``
            are optional.
        session_id: Identifier the session is stored under.

    Returns:
        The initialize result. ``protocolVersion`` is always this server's
        ``MCP_PROTOCOL_VERSION``, whatever the client requested.
    """
    client_info = params.get("clientInfo")
    # clientInfo may be absent, null or malformed; treat all of those as "no info"
    # instead of crashing on .get() below.
    if not isinstance(client_info, dict):
        client_info = {}
    protocol_version = params.get("protocolVersion", MCP_PROTOCOL_VERSION)
    logger.info(f"Initializing session {session_id} with client: {client_info.get('name', 'unknown')}")

    sessions[session_id] = {
        "initialized": True,
        "client_info": client_info,
        "protocol_version": protocol_version,
        # Timezone-aware UTC; datetime.utcnow() is deprecated (3.12+) and naive.
        "created_at": datetime.now(timezone.utc).isoformat(),
    }

    return {
        "protocolVersion": MCP_PROTOCOL_VERSION,
        "capabilities": {
            "tools": {
                "listChanged": False  # We don't support dynamic tool updates
            }
        },
        "serverInfo": {
            "name": "Healthcare API MCP",
            "version": "1.0.0"
        },
        "instructions": "This MCP provides access to public healthcare data APIs including OpenFDA (drug safety, labels, recalls), clinical guidelines (USPSTF, CDC), and CMS pricing data."
    }
async def handle_tools_list(params: dict) -> dict:
    """Handle the ``tools/list`` method.

    ``params`` is accepted for dispatch uniformity but not used. Returns the
    MCP schema of every tool of every registered provider, in provider order.
    """
    return {
        "tools": [
            build_tool_schema(fn)
            for prov in providers.values()
            for fn in prov.get_tools()
        ]
    }
async def handle_tools_call(params: dict) -> dict:
    """Handle the ``tools/call`` method by executing the named tool.

    Args:
        params: Must contain ``name``. ``arguments`` is an optional object
            (missing or null means no arguments).

    Returns:
        An MCP tool result ``{"content": [{"type": "text", "text": ...}]}``.
        A non-string tool result is JSON-encoded so ``text`` is always a str.

    Raises:
        ValueError: missing name, non-object arguments, arguments that do not
            match the tool's signature, or unknown tool.
            ``handle_jsonrpc_request`` maps these to -32602.
        Exception: anything the tool itself raises propagates unchanged.
    """
    tool_name = params.get("name")
    arguments = params.get("arguments")
    if arguments is None:
        arguments = {}

    if not tool_name:
        raise ValueError("Missing 'name' parameter in tools/call")
    if not isinstance(arguments, dict):
        raise ValueError(f"'arguments' must be an object, got {type(arguments).__name__}")

    logger.info(f"Calling tool: {tool_name} with args: {arguments}")

    # First tool with a matching name, searching providers in registration order.
    tool_func = next(
        (
            func
            for provider in providers.values()
            for func in provider.get_tools()
            if func.__name__ == tool_name
        ),
        None,
    )
    if tool_func is None:
        raise ValueError(f"Tool not found: {tool_name}")

    # Validate the arguments against the signature *before* calling, so that
    # missing or unexpected arguments become Invalid params. A TypeError raised
    # inside the tool's own code is not misreported as a client error.
    try:
        inspect.signature(tool_func).bind(**arguments)
    except TypeError as e:
        logger.error(f"Invalid arguments for {tool_name}: {e}")
        raise ValueError(
            f"Invalid arguments for {tool_name}: {e}. Provided arguments: {list(arguments.keys())}"
        ) from e

    result = await tool_func(**arguments)
    # MCP text content must be a string. default=str is a deliberate choice:
    # the output is display text, so unusual values (dates etc.) render via str().
    text = result if isinstance(result, str) else json.dumps(result, default=str)
    return {
        "content": [
            {
                "type": "text",
                "text": text,
            }
        ]
    }
async def handle_jsonrpc_request(request_data: dict, session_id: Optional[str] = None) -> dict:
    """Validate one JSON-RPC 2.0 request and dispatch it to its MCP handler.

    Supported methods: ``initialize``, ``ping``, ``tools/list``, ``tools/call``.

    Args:
        request_data: The decoded request body. Anything other than a JSON
            object yields an Invalid Request (-32600) error.
        session_id: The ``Mcp-Session-Id`` header value, if the client sent one.

    Returns:
        A JSON-RPC response dict. Handler exceptions are converted into error
        responses: ``ValueError`` -> -32602 (Invalid params), anything else ->
        -32603 (Internal error). For ``initialize``, the dict also carries a
        private ``_session_id`` key that the HTTP layer pops and sends as the
        ``Mcp-Session-Id`` header.
    """
    # A non-object body (batch array, bare scalar) has no .get(); reject it as
    # a JSON-RPC Invalid Request instead of raising AttributeError.
    if not isinstance(request_data, dict):
        return create_jsonrpc_error(-32600, "Invalid Request: expected a JSON object", None)

    if request_data.get("jsonrpc") != "2.0":
        return create_jsonrpc_error(-32600, "Invalid Request: jsonrpc must be '2.0'", request_data.get("id"))

    method = request_data.get("method")
    request_id = request_data.get("id")
    # "params" may be omitted or explicitly null; both mean "no params".
    params = request_data.get("params")
    if params is None:
        params = {}

    if not method:
        return create_jsonrpc_error(-32600, "Invalid Request: missing method", request_id)
    if not isinstance(params, dict):
        # Every handler here takes named params; by-position arrays are unsupported.
        return create_jsonrpc_error(-32602, "Invalid params: params must be an object", request_id)

    try:
        if method == "initialize":
            # Reuse the client's session id if it sent one, otherwise mint a new one.
            if not session_id:
                session_id = str(uuid.uuid4())
            result = await handle_initialize(params, session_id)
            response = create_jsonrpc_response(result, request_id)
            response["_session_id"] = session_id  # popped by the HTTP layer for the header
            return response
        elif method == "ping":
            # MCP `ping` utility: answer with an empty result.
            return create_jsonrpc_response({}, request_id)
        elif method == "tools/list":
            result = await handle_tools_list(params)
            return create_jsonrpc_response(result, request_id)
        elif method == "tools/call":
            result = await handle_tools_call(params)
            return create_jsonrpc_response(result, request_id)
        else:
            return create_jsonrpc_error(-32601, f"Method not found: {method}", request_id)
    except ValueError as e:
        return create_jsonrpc_error(-32602, f"Invalid params: {str(e)}", request_id)
    except Exception as e:
        logger.exception(f"Error handling {method}")
        return create_jsonrpc_error(-32603, f"Internal error: {str(e)}", request_id)
# ============================================================================
# MCP Streamable HTTP Endpoint
# ============================================================================
@app.post("/mcp")
async def mcp_endpoint(
    request: Request,
    mcp_session_id: Optional[str] = Header(None, alias="Mcp-Session-Id"),
):
    """
    MCP Streamable HTTP endpoint (JSON-RPC 2.0).

    Handles:
    - initialize: Capability negotiation (the response carries ``Mcp-Session-Id``)
    - tools/list: List available tools
    - tools/call: Execute a tool
    - notifications (``jsonrpc: "2.0"`` messages with a method but no ``id``):
      processed, then answered with ``202 Accepted`` and no body

    Malformed JSON yields a -32700 Parse error. A non-object body (e.g. a
    batch array) yields a -32600 Invalid Request. Both are returned with HTTP 400.
    """
    try:
        body = await request.json()
    except Exception as e:
        return JSONResponse(
            content=create_jsonrpc_error(-32700, f"Parse error: {str(e)}", None),
            status_code=400,
        )

    # Batches and bare scalars are not supported; reject them explicitly
    # instead of failing later on attribute access.
    if not isinstance(body, dict):
        return JSONResponse(
            content=create_jsonrpc_error(-32600, "Invalid Request: expected a single JSON-RPC object", None),
            status_code=400,
        )

    response = await handle_jsonrpc_request(body, mcp_session_id)

    # A message with no "id" member is a notification (e.g. notifications/initialized).
    # JSON-RPC forbids replying to it, and the Streamable HTTP transport
    # acknowledges such POSTs with 202 Accepted and an empty body.
    if "id" not in body and body.get("jsonrpc") == "2.0" and body.get("method"):
        return Response(status_code=202)

    # initialize smuggles its session id out via a private key; expose it as a header.
    session_id = response.pop("_session_id", None)
    headers = {"Mcp-Session-Id": session_id} if session_id else {}
    return JSONResponse(content=response, headers=headers)
# ============================================================================
# Legacy REST Endpoints (for backward compatibility)
# ============================================================================
@app.get("/")
async def root_info():
    """Describe the server on GET /: identity, protocol, active providers and routes."""
    legacy_routes = {
        "tools": "/mcp/tools (GET) - List tools (legacy)",
        "call": "/mcp/call (POST) - Execute tool (legacy)",
        "stream": "/mcp/stream (POST) - Stream results (legacy)",
    }
    routes = {
        "mcp": "/ or /mcp (POST) - MCP Streamable HTTP endpoint",
        "legacy": legacy_routes,
    }
    return {
        "name": "Healthcare API MCP",
        "version": "1.0.0",
        "protocol": "MCP Streamable HTTP (JSON-RPC 2.0)",
        "protocolVersion": MCP_PROTOCOL_VERSION,
        "providers": [*providers],
        "endpoints": routes,
    }
@app.post("/")
async def root_mcp(
    request: Request,
    mcp_session_id: Optional[str] = Header(None, alias="Mcp-Session-Id"),
):
    """Accept MCP JSON-RPC POSTs at the root path by delegating to ``mcp_endpoint``."""
    reply = await mcp_endpoint(request=request, mcp_session_id=mcp_session_id)
    return reply
@app.get("/health")
async def health_check():
    """Report liveness plus every registered provider, each marked "active"."""
    provider_status = dict.fromkeys(providers, "active")
    return {"status": "healthy", "providers": provider_status}
@app.get("/mcp/tools")
async def list_tools():
    """Legacy REST listing: each tool's MCP schema tagged with its provider, plus a count."""
    catalog = [
        {**build_tool_schema(fn), "provider": prov_name}
        for prov_name, prov in providers.items()
        for fn in prov.get_tools()
    ]
    return {"tools": catalog, "total": len(catalog)}
@app.post("/mcp/call")
async def call_tool(request: Request):
    """Legacy REST endpoint: run a tool from ``{"tool": ..., "arguments": {...}}``.

    Answers with a JSON object in every case. A failing or unknown tool is
    reported with ``success: False`` instead of an HTTP error status.
    """
    payload = await request.json()
    name = payload.get("tool")
    args = payload.get("arguments", {})

    if not name:
        return {"error": "Missing 'tool' parameter"}

    logger.info(f"Calling tool: {name} with args: {args}")

    target = next(
        (fn for prov in providers.values() for fn in prov.get_tools() if fn.__name__ == name),
        None,
    )
    if target is None:
        return {
            "success": False,
            "tool": name,
            "error": f"Tool not found: {name}",
        }

    try:
        outcome = await target(**args)
    except Exception as exc:
        logger.exception(f"Error calling tool {name}")
        return {
            "success": False,
            "tool": name,
            "error": str(exc),
            "error_type": type(exc).__name__,
        }
    return {"success": True, "tool": name, "result": outcome}
async def stream_tool_result(tool_name: str, arguments: dict) -> AsyncGenerator[str, None]:
    """Yield newline-terminated JSON events describing one tool execution.

    Event sequence: ``start``, then exactly one ``result`` or ``error``, then
    ``complete``. If serializing the result fails, that failure is reported as
    the ``error`` event.
    """
    def _line(event: dict) -> str:
        return json.dumps(event) + "\n"

    yield _line({"type": "start", "tool": tool_name, "arguments": arguments})

    target = next(
        (fn for prov in providers.values() for fn in prov.get_tools() if fn.__name__ == tool_name),
        None,
    )

    if target is None:
        yield _line({
            "type": "error",
            "success": False,
            "error": f"Tool not found: {tool_name}",
        })
    else:
        try:
            data = await target(**arguments)
            # Serialized inside the try on purpose: a non-JSON-able result
            # becomes an "error" event rather than aborting the stream.
            yield _line({"type": "result", "success": True, "data": data})
        except Exception as exc:
            logger.exception(f"Error executing tool {tool_name}")
            yield _line({
                "type": "error",
                "success": False,
                "error": str(exc),
                "error_type": type(exc).__name__,
            })

    yield _line({"type": "complete"})
@app.post("/mcp/stream")
async def stream_tool_call(request: Request):
    """Legacy endpoint: stream a tool's execution as NDJSON over chunked transfer."""
    payload = await request.json()
    name = payload.get("tool")
    args = payload.get("arguments", {})

    if not name:
        return {"error": "Missing 'tool' parameter"}

    logger.info(f"Streaming tool: {name} with args: {args}")

    # Disable client caching and proxy buffering so events arrive as they are produced.
    no_buffering_headers = {
        "Cache-Control": "no-cache",
        "X-Accel-Buffering": "no",
    }
    return StreamingResponse(
        stream_tool_result(name, args),
        media_type="application/x-ndjson",
        headers=no_buffering_headers,
    )
if __name__ == "__main__":
    # Imported lazily: uvicorn is only needed when this module is run directly.
    import uvicorn

    # Listen on all interfaces. NOTE(review): 7860 presumably matches the
    # HuggingFace Spaces app port; confirm against the Space configuration.
    bind_host, bind_port = "0.0.0.0", 7860
    uvicorn.run(app, host=bind_host, port=bind_port, log_level="info")