"""MCP Streamable HTTP Server for Healthcare API MCP - HuggingFace Spaces deployment.

Implements the MCP Streamable HTTP transport (JSON-RPC 2.0).
Specification: https://modelcontextprotocol.io/specification/2025-03-26/basic/transports
"""

import asyncio
import json
import logging
import inspect
import types
import uuid
from typing import AsyncGenerator, get_type_hints, get_origin, get_args, Optional, Dict, Any, Union
from datetime import datetime, timezone

from fastapi import FastAPI, Request, Header
from fastapi.responses import StreamingResponse, JSONResponse, Response
from starlette.middleware.cors import CORSMiddleware

from config import MCPConfig
from core import create_http_client
from providers.openfda_provider import OpenFDAProvider
from providers.clinical_guidelines_provider import ClinicalGuidelinesProvider
from providers.cms_pricing_provider import CMSPricingProvider

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(
    title="Healthcare API MCP",
    description="Model Context Protocol server for public healthcare data APIs",
    version="1.0.0",
)

# CORS middleware. Browsers only let scripts read response headers listed in
# expose_headers, so Mcp-Session-Id must be exposed; otherwise browser-based
# MCP clients never see the session ID returned by `initialize`.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["Mcp-Session-Id"],
)

# Global state (populated by startup()).
providers: Dict[str, Any] = {}  # provider name -> provider instance
http_client = None  # shared client from core.create_http_client(); closed on shutdown
sessions: Dict[str, Dict[str, Any]] = {}  # Session ID -> session data

# MCP protocol revision advertised in `initialize` responses.
# NOTE(review): the module docstring cites the 2025-03-26 transport spec, but
# this server advertises 2024-11-05 and does not handle JSON-RPC batches —
# confirm feature support before bumping this value.
MCP_PROTOCOL_VERSION = "2024-11-05"


@app.on_event("startup")
async def startup():
    """Create the shared HTTP client and initialize every enabled provider."""
    global http_client

    config = MCPConfig.from_env()
    logger.info("Starting Healthcare API MCP with providers: %s", config.enabled_providers)

    http_client = create_http_client()

    if "openfda" in config.enabled_providers:
        providers["openfda"] = OpenFDAProvider(http_client, config.openfda_api_key)
        await providers["openfda"].initialize()
        logger.info("✅ OpenFDA provider initialized")

    if "clinical_guidelines" in config.enabled_providers:
        providers["clinical_guidelines"] = ClinicalGuidelinesProvider(http_client)
        await providers["clinical_guidelines"].initialize()
        logger.info("✅ Clinical Guidelines provider initialized")

    if "cms_pricing" in config.enabled_providers:
        providers["cms_pricing"] = CMSPricingProvider(http_client)
        await providers["cms_pricing"].initialize()
        logger.info("✅ CMS Pricing provider initialized")

    logger.info("Healthcare API MCP ready with %d provider(s)", len(providers))


@app.on_event("shutdown")
async def shutdown():
    """Clean up every provider, then close the shared HTTP client.

    A failing provider cleanup is logged rather than propagated, so the
    remaining providers are still cleaned up and the HTTP client is always
    closed.
    """
    logger.info("Shutting down Healthcare API MCP...")
    try:
        for name, provider in providers.items():
            try:
                await provider.cleanup()
            except Exception:
                logger.exception("Error cleaning up provider %s", name)
    finally:
        if http_client:
            await http_client.aclose()


# ============================================================================
# Tool schema generation
# ============================================================================

# Python type -> JSON Schema type name used in tool inputSchemas.
_JSON_SCHEMA_TYPES: Dict[Any, str] = {
    str: "string",
    int: "integer",
    float: "number",
    bool: "boolean",
    list: "array",
    dict: "object",
}

# typing.Union plus, on Python 3.10+, the runtime type of `X | Y` unions.
_UNION_ORIGINS = (Union, getattr(types, "UnionType", Union))

# Google-style docstring section headers that terminate the "Args:" section.
_DOCSTRING_SECTIONS = {
    "Returns:", "Return:", "Raises:", "Yields:",
    "Examples:", "Example:", "Notes:", "Note:",
}


def _json_schema_type(hint: Any) -> str:
    """Map a Python type hint to a JSON Schema type name.

    Optional[X] / Union[X, None] / ``X | None`` map to X's type; generic
    aliases such as List[str] or Dict[str, Any] map by their origin. Anything
    unrecognized falls back to "string".
    """
    origin = get_origin(hint)
    if origin in _UNION_ORIGINS:
        non_none = [arg for arg in get_args(hint) if arg is not type(None)]
        return _json_schema_type(non_none[0]) if len(non_none) == 1 else "string"
    try:
        return _JSON_SCHEMA_TYPES.get(origin or hint, "string")
    except TypeError:  # unhashable hint object
        return "string"


def _parse_param_descriptions(doc: str, param_names: set) -> Dict[str, str]:
    """Extract per-parameter descriptions from a Google-style "Args:" section.

    Accepts both ``name: description`` and ``name (type): description`` lines.
    Lines that do not start a known parameter are treated as wrapped text and
    appended to the previous parameter's description. The section ends at a
    blank line (once a parameter was seen) or at another section header.
    """
    if "Args:" not in doc:
        return {}

    descriptions: Dict[str, str] = {}
    current = None
    for raw_line in doc.split("Args:", 1)[1].splitlines():
        line = raw_line.strip()
        if line in _DOCSTRING_SECTIONS:
            break
        if not line:
            if current is not None:
                break
            continue
        head, sep, tail = line.partition(":")
        name = head.split("(", 1)[0].strip()
        if sep and name in param_names:
            current = name
            descriptions[name] = tail.strip()
        elif current is not None:
            descriptions[current] = f"{descriptions[current]} {line}".strip()
    return descriptions


def build_tool_schema(tool_func) -> dict:
    """
    Build proper MCP tool schema with inputSchema from function signature.

    Extracts:
    - Function description from docstring (text before "Args:")
    - Parameter types from type hints (Optional[X] maps to X; unknown -> "string")
    - Required vs optional parameters (based on defaults)
    - Parameter descriptions from the docstring's "Args:" section

    ``self``, ``*args`` and ``**kwargs`` are not exposed as tool parameters.
    """
    sig = inspect.signature(tool_func)

    try:
        hints = get_type_hints(tool_func)
    except Exception:
        # e.g. unresolvable forward references: fall back to "string" types.
        logger.debug("Could not resolve type hints for %s", tool_func.__name__, exc_info=True)
        hints = {}

    doc = inspect.cleandoc(tool_func.__doc__ or "")
    description = doc.split("Args:", 1)[0].strip()

    variadic = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    params = [
        param
        for name, param in sig.parameters.items()
        if name != "self" and param.kind not in variadic
    ]
    param_descriptions = _parse_param_descriptions(doc, {param.name for param in params})

    properties = {}
    required = []
    for param in params:
        prop_schema = {
            "type": _json_schema_type(hints[param.name]) if param.name in hints else "string"
        }
        if param.name in param_descriptions:
            prop_schema["description"] = param_descriptions[param.name]
        properties[param.name] = prop_schema

        # Parameters without a default value are required.
        if param.default is inspect.Parameter.empty:
            required.append(param.name)

    schema = {
        "name": tool_func.__name__,
        "description": description,
        "inputSchema": {
            "type": "object",
            "properties": properties,
        },
    }
    if required:
        schema["inputSchema"]["required"] = required
    return schema


# ============================================================================
# Tool lookup helpers (shared by the MCP and legacy endpoints)
# ============================================================================

def _find_tool(tool_name: str) -> Optional[Any]:
    """Return the first provider tool whose function name is tool_name, or None."""
    for provider in providers.values():
        for tool_func in provider.get_tools():
            if tool_func.__name__ == tool_name:
                return tool_func
    return None


def _tool_result_text(result: Any) -> str:
    """Render a tool result as the string MCP requires for a "text" content item.

    Strings pass through unchanged; anything else is JSON-encoded. ``default=str``
    is a deliberate fallback so values the json module cannot encode still
    render as readable text instead of failing the whole call.
    """
    if isinstance(result, str):
        return result
    return json.dumps(result, ensure_ascii=False, default=str)


async def _read_json_object(request: Request) -> Optional[dict]:
    """Parse the request body as JSON; return it only if it is a JSON object."""
    try:
        body = await request.json()
    except ValueError:  # covers json.JSONDecodeError and UnicodeDecodeError
        return None
    return body if isinstance(body, dict) else None


# ============================================================================
# JSON-RPC 2.0 Message Handlers
# ============================================================================

def create_jsonrpc_response(result: Any, request_id: Any) -> dict:
    """Create a JSON-RPC 2.0 success response."""
    return {
        "jsonrpc": "2.0",
        "result": result,
        "id": request_id,
    }


def create_jsonrpc_error(code: int, message: str, request_id: Any, data: Any = None) -> dict:
    """Create a JSON-RPC 2.0 error response (``data`` is included only if not None)."""
    error = {
        "code": code,
        "message": message,
    }
    if data is not None:
        error["data"] = data
    return {
        "jsonrpc": "2.0",
        "error": error,
        "id": request_id,
    }


async def handle_initialize(params: dict, session_id: str) -> dict:
    """Handle initialize method - capability negotiation.

    Records the client's info and requested protocol version in ``sessions``
    under ``session_id`` and returns the server's capabilities.
    """
    client_info = params.get("clientInfo") or {}
    if not isinstance(client_info, dict):
        client_info = {}
    protocol_version = params.get("protocolVersion", MCP_PROTOCOL_VERSION)

    logger.info(
        "Initializing session %s with client: %s",
        session_id,
        client_info.get("name", "unknown"),
    )

    sessions[session_id] = {
        "initialized": True,
        "client_info": client_info,
        "protocol_version": protocol_version,
        "created_at": datetime.now(timezone.utc).isoformat(),
    }

    return {
        "protocolVersion": MCP_PROTOCOL_VERSION,
        "capabilities": {
            "tools": {
                "listChanged": False  # We don't support dynamic tool updates
            }
        },
        "serverInfo": {
            "name": "Healthcare API MCP",
            "version": "1.0.0",
        },
        "instructions": "This MCP provides access to public healthcare data APIs including OpenFDA (drug safety, labels, recalls), clinical guidelines (USPSTF, CDC), and CMS pricing data.",
    }


async def handle_tools_list(params: dict) -> dict:
    """Handle tools/list method - return the schema of every provider tool."""
    return {
        "tools": [
            build_tool_schema(tool_func)
            for provider in providers.values()
            for tool_func in provider.get_tools()
        ]
    }


async def handle_tools_call(params: dict) -> dict:
    """Handle tools/call method - execute a tool.

    Returns an MCP tool result whose single text content item is the tool's
    return value (JSON-encoded when it is not already a string).

    Raises:
        ValueError: missing/invalid ``name`` or ``arguments``, unknown tool, or
            missing required tool arguments (mapped to -32602 by the caller).
    """
    tool_name = params.get("name")
    arguments = params.get("arguments")
    if arguments is None:
        arguments = {}

    if not tool_name:
        raise ValueError("Missing 'name' parameter in tools/call")
    if not isinstance(arguments, dict):
        raise ValueError("'arguments' must be a JSON object")

    logger.info("Calling tool: %s with args: %s", tool_name, arguments)

    tool_func = _find_tool(tool_name)
    if tool_func is None:
        raise ValueError(f"Tool not found: {tool_name}")

    try:
        result = await tool_func(**arguments)
    except TypeError as e:
        # Surface missing required arguments as invalid params, not internal errors.
        error_msg = str(e)
        if "missing" in error_msg and "required" in error_msg:
            logger.error("Missing required arguments for %s: %s", tool_name, error_msg)
            raise ValueError(
                f"Missing required arguments. {error_msg}. Provided arguments: {list(arguments.keys())}"
            ) from e
        raise

    # Return MCP tool response format ("text" must be a string).
    return {
        "content": [
            {
                "type": "text",
                "text": _tool_result_text(result),
            }
        ]
    }


async def handle_jsonrpc_request(request_data: Any, session_id: Optional[str] = None) -> Optional[dict]:
    """Handle a single JSON-RPC 2.0 message.

    Returns the JSON-RPC response dict, or None for MCP notifications
    (``notifications/*`` without an ``id``), which must not be answered. For
    ``initialize`` the response carries an extra ``_session_id`` key that the
    HTTP layer moves into the Mcp-Session-Id header.
    """
    if not isinstance(request_data, dict):
        # Batches (JSON arrays) and scalars are not supported.
        return create_jsonrpc_error(-32600, "Invalid Request: expected a single JSON-RPC object", None)

    if request_data.get("jsonrpc") != "2.0":
        return create_jsonrpc_error(-32600, "Invalid Request: jsonrpc must be '2.0'", request_data.get("id"))

    method = request_data.get("method")
    request_id = request_data.get("id")

    if not method:
        return create_jsonrpc_error(-32600, "Invalid Request: missing method", request_id)

    # Notifications (e.g. notifications/initialized) get no JSON-RPC response.
    if "id" not in request_data and isinstance(method, str) and method.startswith("notifications/"):
        logger.debug("Received notification: %s", method)
        return None

    params = request_data.get("params")
    if params is None:
        params = {}
    if not isinstance(params, dict):
        return create_jsonrpc_error(-32602, "Invalid params: params must be an object", request_id)

    try:
        if method == "initialize":
            # Create session ID if not exists
            if not session_id:
                session_id = str(uuid.uuid4())
            result = await handle_initialize(params, session_id)
            response = create_jsonrpc_response(result, request_id)
            response["_session_id"] = session_id  # moved into the response header by mcp_endpoint
            return response

        if method == "ping":
            return create_jsonrpc_response({}, request_id)

        if method == "tools/list":
            return create_jsonrpc_response(await handle_tools_list(params), request_id)

        if method == "tools/call":
            return create_jsonrpc_response(await handle_tools_call(params), request_id)

        return create_jsonrpc_error(-32601, f"Method not found: {method}", request_id)

    except ValueError as e:
        return create_jsonrpc_error(-32602, f"Invalid params: {str(e)}", request_id)
    except Exception as e:
        logger.exception(f"Error handling {method}")
        return create_jsonrpc_error(-32603, f"Internal error: {str(e)}", request_id)


# ============================================================================
# MCP Streamable HTTP Endpoint
# ============================================================================

@app.post("/mcp")
async def mcp_endpoint(
    request: Request,
    mcp_session_id: Optional[str] = Header(None, alias="Mcp-Session-Id"),
):
    """
    MCP Streamable HTTP endpoint (JSON-RPC 2.0).

    Handles:
    - initialize: Capability negotiation
    - ping: Liveness check
    - tools/list: List available tools
    - tools/call: Execute a tool
    - notifications/*: Acknowledged with HTTP 202 and no body
    """
    try:
        body = await request.json()
    except ValueError as e:  # covers json.JSONDecodeError and UnicodeDecodeError
        return JSONResponse(
            content=create_jsonrpc_error(-32700, f"Parse error: {str(e)}", None),
            status_code=400,
        )

    response = await handle_jsonrpc_request(body, mcp_session_id)
    if response is None:
        # Streamable HTTP: accepted notifications get 202 with no body.
        return Response(status_code=202)

    # Extract session ID if present (from initialize) and send it as a header.
    session_id = response.pop("_session_id", None)
    headers = {"Mcp-Session-Id": session_id} if session_id else {}
    return JSONResponse(content=response, headers=headers)


@app.delete("/mcp")
@app.delete("/")
async def terminate_session(
    mcp_session_id: Optional[str] = Header(None, alias="Mcp-Session-Id"),
):
    """Terminate an MCP session (Streamable HTTP session management).

    Returns 204 when the session was removed, 404 for an unknown session ID,
    and 400 when the Mcp-Session-Id header is missing.
    """
    if not mcp_session_id:
        return JSONResponse(content={"error": "Missing Mcp-Session-Id header"}, status_code=400)
    if sessions.pop(mcp_session_id, None) is None:
        return Response(status_code=404)
    return Response(status_code=204)


# ============================================================================
# Legacy REST Endpoints (for backward compatibility)
# ============================================================================

@app.get("/")
async def root_info():
    """Root endpoint with server info (GET only)."""
    return {
        "name": "Healthcare API MCP",
        "version": "1.0.0",
        "protocol": "MCP Streamable HTTP (JSON-RPC 2.0)",
        "protocolVersion": MCP_PROTOCOL_VERSION,
        "providers": list(providers.keys()),
        "endpoints": {
            "mcp": "/ or /mcp (POST) - MCP Streamable HTTP endpoint",
            "legacy": {
                "tools": "/mcp/tools (GET) - List tools (legacy)",
                "call": "/mcp/call (POST) - Execute tool (legacy)",
                "stream": "/mcp/stream (POST) - Stream results (legacy)",
            },
        },
    }


@app.post("/")
async def root_mcp(
    request: Request,
    mcp_session_id: Optional[str] = Header(None, alias="Mcp-Session-Id"),
):
    """Root MCP endpoint - same as /mcp for compatibility."""
    return await mcp_endpoint(request, mcp_session_id)


@app.get("/health")
async def health_check():
    """Health check endpoint."""
    return {
        "status": "healthy",
        "providers": {name: "active" for name in providers},
    }


@app.get("/mcp/tools")
async def list_tools():
    """List all available MCP tools with proper JSON Schema, tagged by provider."""
    tools = [
        {**build_tool_schema(tool_func), "provider": provider_name}
        for provider_name, provider in providers.items()
        for tool_func in provider.get_tools()
    ]
    return {"tools": tools, "total": len(tools)}


@app.post("/mcp/call")
async def call_tool(request: Request):
    """Call an MCP tool and return result (legacy, non-JSON-RPC).

    Tool failures are reported in the body with ``success: False``; a body
    that is not a JSON object yields HTTP 400.
    """
    body = await _read_json_object(request)
    if body is None:
        return JSONResponse(content={"error": "Request body must be a JSON object"}, status_code=400)

    tool_name = body.get("tool")
    arguments = body.get("arguments")
    if arguments is None:
        arguments = {}

    if not tool_name:
        return {"error": "Missing 'tool' parameter"}

    logger.info("Calling tool: %s with args: %s", tool_name, arguments)

    tool_func = _find_tool(tool_name)
    if tool_func is None:
        return {
            "success": False,
            "tool": tool_name,
            "error": f"Tool not found: {tool_name}",
        }

    try:
        result = await tool_func(**arguments)
    except Exception as e:
        logger.exception(f"Error calling tool {tool_name}")
        return {
            "success": False,
            "tool": tool_name,
            "error": str(e),
            "error_type": type(e).__name__,
        }

    return {
        "success": True,
        "tool": tool_name,
        "result": result,
    }


async def stream_tool_result(tool_name: str, arguments: dict) -> AsyncGenerator[str, None]:
    """Stream tool execution results as newline-delimited JSON chunks.

    Emits a "start" line, then either a "result" or an "error" line, and
    always a final "complete" line.
    """
    yield json.dumps({
        "type": "start",
        "tool": tool_name,
        "arguments": arguments,
    }) + "\n"

    tool_func = _find_tool(tool_name)
    if tool_func is None:
        yield json.dumps({
            "type": "error",
            "success": False,
            "error": f"Tool not found: {tool_name}",
        }) + "\n"
    else:
        # Build the chunk inside try, yield outside it, so serialization errors
        # are reported while client disconnects at the yield are not swallowed.
        try:
            result = await tool_func(**arguments)
            chunk = json.dumps({
                "type": "result",
                "success": True,
                "data": result,
            }) + "\n"
        except Exception as e:
            logger.exception(f"Error executing tool {tool_name}")
            chunk = json.dumps({
                "type": "error",
                "success": False,
                "error": str(e),
                "error_type": type(e).__name__,
            }) + "\n"
        yield chunk

    yield json.dumps({"type": "complete"}) + "\n"


@app.post("/mcp/stream")
async def stream_tool_call(request: Request):
    """Stream tool execution results using HTTP chunked transfer encoding (NDJSON)."""
    body = await _read_json_object(request)
    if body is None:
        return JSONResponse(content={"error": "Request body must be a JSON object"}, status_code=400)

    tool_name = body.get("tool")
    arguments = body.get("arguments")
    if arguments is None:
        arguments = {}

    if not tool_name:
        return {"error": "Missing 'tool' parameter"}

    logger.info("Streaming tool: %s with args: %s", tool_name, arguments)

    return StreamingResponse(
        stream_tool_result(tool_name, arguments),
        media_type="application/x-ndjson",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",  # disable proxy buffering so chunks flush
        },
    )


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
        log_level="info",
    )