Spaces:
Running
Running
| """MCP Streamable HTTP Server for Healthcare API MCP - HuggingFace Spaces deployment. | |
| Implements MCP Streamable HTTP transport (JSON-RPC 2.0). | |
| Specification: https://modelcontextprotocol.io/specification/2025-03-26/basic/transports | |
| """ | |
| import asyncio | |
| import json | |
| import logging | |
| import inspect | |
| import uuid | |
| from typing import AsyncGenerator, get_type_hints, get_origin, get_args, Optional, Dict, Any | |
| from datetime import datetime | |
| from fastapi import FastAPI, Request, Header | |
| from fastapi.responses import StreamingResponse, JSONResponse | |
| from starlette.middleware.cors import CORSMiddleware | |
| from config import MCPConfig | |
| from core import create_http_client | |
| from providers.openfda_provider import OpenFDAProvider | |
| from providers.clinical_guidelines_provider import ClinicalGuidelinesProvider | |
| from providers.cms_pricing_provider import CMSPricingProvider | |
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# NOTE(review): none of the handler coroutines in this module carry a visible
# @app.* decorator or add_api_route() call, and startup()/shutdown() are not
# visibly registered as lifecycle hooks. Confirm registration happens somewhere
# (it may have been lost in this copy); otherwise the app serves no routes and
# `providers` stays empty.

# Initialize FastAPI app
app = FastAPI(
    title="Healthcare API MCP",
    description="Model Context Protocol server for public healthcare data APIs",
    version="1.0.0",
)

# CORS: any origin may call the server. The MCP session id is returned in the
# Mcp-Session-Id response header, which browsers hide from cross-origin scripts
# unless it is listed in Access-Control-Expose-Headers.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True may
# make Starlette reflect arbitrary request Origins on credentialed requests;
# this server does not appear to use cookies, so consider disabling credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["Mcp-Session-Id"],
)

# Global state, populated by startup() and released by shutdown().
providers = {}  # provider key (e.g. "openfda") -> provider instance
http_client = None  # shared client created by create_http_client()
# Session ID -> session data, written by handle_initialize().
# NOTE(review): entries are added on every initialize and never removed.
sessions: Dict[str, Dict[str, Any]] = {}

# MCP protocol version the server reports in its initialize result.
# NOTE(review): the module docstring cites the 2025-03-26 transport spec while
# this advertises 2024-11-05 — confirm which revision is intended.
MCP_PROTOCOL_VERSION = "2024-11-05"
async def startup():
    """Create the shared HTTP client and bring up every enabled provider.

    Providers are considered in a fixed order (OpenFDA, clinical guidelines,
    CMS pricing). Each enabled one is constructed, stored in the module-level
    ``providers`` dict under its key, and its ``initialize()`` is awaited
    before the next provider is started.
    """
    global http_client

    config = MCPConfig.from_env()
    logger.info(f"Starting Healthcare API MCP with providers: {config.enabled_providers}")

    http_client = create_http_client()

    # (config key, label used in the log line, factory) in start-up order.
    # The factories read the module-level http_client when they are called.
    registry = (
        ("openfda", "OpenFDA",
         lambda: OpenFDAProvider(http_client, config.openfda_api_key)),
        ("clinical_guidelines", "Clinical Guidelines",
         lambda: ClinicalGuidelinesProvider(http_client)),
        ("cms_pricing", "CMS Pricing",
         lambda: CMSPricingProvider(http_client)),
    )

    for key, label, factory in registry:
        if key not in config.enabled_providers:
            continue
        providers[key] = factory()
        await providers[key].initialize()
        logger.info(f"✅ {label} provider initialized")

    logger.info(f"Healthcare API MCP ready with {len(providers)} provider(s)")
async def shutdown():
    """Release provider resources and close the shared HTTP client.

    Every provider's ``cleanup()`` is attempted even if an earlier one raises.
    The HTTP client is then always closed, so a single misbehaving provider
    can no longer leak the client's connections. Cleanup failures are logged,
    not re-raised, because shutdown is best-effort.
    """
    global http_client
    logger.info("Shutting down Healthcare API MCP...")

    for name, provider in providers.items():
        try:
            await provider.cleanup()
        except Exception:
            # Keep releasing the remaining providers and the HTTP client.
            logger.exception(f"Error cleaning up provider {name}")

    if http_client:
        try:
            await http_client.aclose()
        finally:
            # Drop the reference so a closed client is never reused.
            http_client = None
def _json_schema_type(hint: Any) -> str:
    """Map a Python type annotation to a JSON Schema ``type`` name.

    ``Optional[X]`` / ``X | None`` is unwrapped to ``X``. Bare or parameterized
    ``list``/``tuple``/``set``/``frozenset`` map to ``"array"``, and ``dict`` to
    ``"object"``. Anything unrecognized, including multi-member unions and a
    missing hint (``None``), falls back to ``"string"``.
    """
    import types
    import typing

    origin = get_origin(hint)

    # typing.Union covers Optional[X]; types.UnionType covers `X | None` (3.10+).
    union_origins = tuple(
        u for u in (typing.Union, getattr(types, "UnionType", None)) if u is not None
    )
    if origin is not None and origin in union_origins:
        members = [arg for arg in get_args(hint) if arg is not type(None)]
        return _json_schema_type(members[0]) if len(members) == 1 else "string"

    for python_type, json_type in (
        (bool, "boolean"),
        (int, "integer"),
        (float, "number"),
        (str, "string"),
    ):
        if hint is python_type:
            return json_type

    # get_origin(list) is None, so bare containers must be checked directly.
    if any(hint is t or origin is t for t in (list, tuple, set, frozenset)):
        return "array"
    if hint is dict or origin is dict:
        return "object"
    return "string"


def build_tool_schema(tool_func) -> dict:
    """Build an MCP tool descriptor (name, description, inputSchema) for a callable.

    Everything is derived from the callable itself:

    - ``description``: the cleaned docstring text before ``Args:``, or the
      whole docstring when there is no ``Args:`` section.
    - ``inputSchema.properties``: one entry per parameter. The type comes from
      the annotation (see ``_json_schema_type``; default ``"string"``). The
      description comes from a matching ``name: text`` or ``name (type): text``
      line in the ``Args:`` section.
    - ``inputSchema.required``: parameters without a default value. The key is
      omitted when no parameter is required.

    ``self``, ``*args`` and ``**kwargs`` are excluded because they cannot be
    expressed as named JSON properties.

    Args:
        tool_func: The callable exposed as an MCP tool.

    Returns:
        A dict with ``name``, ``description`` and ``inputSchema`` keys.
    """
    sig = inspect.signature(tool_func)

    # Annotations may contain unresolvable forward references; in that case
    # every parameter falls back to "string".
    try:
        hints = get_type_hints(tool_func)
    except Exception:
        hints = {}

    # cleandoc strips the docstring's common indentation so multi-line
    # descriptions are not littered with leading spaces.
    doc = inspect.cleandoc(tool_func.__doc__ or "")
    description = doc.split("Args:")[0].strip() if "Args:" in doc else doc.strip()

    param_descriptions = {}
    if "Args:" in doc:
        args_section = doc.split("Args:", 1)[1].split("Returns:", 1)[0]
        for raw_line in args_section.splitlines():
            name_part, sep, desc = raw_line.strip().partition(":")
            if not sep:
                continue
            # Accept Google-style "name (type): text" as well as "name: text".
            name = name_part.split("(", 1)[0].strip()
            # Continuation lines that merely contain a colon are not parameters.
            if name.isidentifier():
                param_descriptions[name] = desc.strip()

    properties = {}
    required = []
    for param_name, param in sig.parameters.items():
        if param_name == "self":
            continue
        if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
            continue

        prop_schema = {"type": _json_schema_type(hints.get(param_name))}
        if param_name in param_descriptions:
            prop_schema["description"] = param_descriptions[param_name]
        properties[param_name] = prop_schema

        # Parameter.empty is a sentinel: compare by identity.
        if param.default is inspect.Parameter.empty:
            required.append(param_name)

    schema = {
        "name": tool_func.__name__,
        "description": description,
        "inputSchema": {
            "type": "object",
            "properties": properties,
        },
    }
    if required:
        schema["inputSchema"]["required"] = required
    return schema
# ============================================================================
# JSON-RPC 2.0 Message Handlers
# ============================================================================
def create_jsonrpc_response(result: Any, request_id: Any) -> dict:
    """Wrap *result* in a JSON-RPC 2.0 success envelope addressed to *request_id*."""
    return dict(jsonrpc="2.0", result=result, id=request_id)
def create_jsonrpc_error(code: int, message: str, request_id: Any, data: Any = None) -> dict:
    """Build a JSON-RPC 2.0 error envelope.

    The optional *data* member is included only when it is not ``None``.
    Falsy values such as ``0`` or ``""`` are still included.
    """
    payload = {"code": code, "message": message}
    payload.update({} if data is None else {"data": data})
    return dict(jsonrpc="2.0", error=payload, id=request_id)
async def handle_initialize(params: dict, session_id: str) -> dict:
    """Handle ``initialize``: record the session and advertise server capabilities.

    Args:
        params: JSON-RPC params. ``clientInfo`` and ``protocolVersion`` are
            read if present; a missing or non-object ``clientInfo`` is treated
            as empty.
        session_id: Key under which the session is stored in ``sessions``.

    Returns:
        The initialize result: the server's protocol version, its capabilities
        (tools, without list-change notifications), server info and usage
        instructions.
    """
    from datetime import timezone

    client_info = params.get("clientInfo")
    if not isinstance(client_info, dict):
        client_info = {}
    protocol_version = params.get("protocolVersion", MCP_PROTOCOL_VERSION)
    logger.info(f"Initializing session {session_id} with client: {client_info.get('name', 'unknown')}")

    sessions[session_id] = {
        "initialized": True,
        "client_info": client_info,
        # The version the client asked for; the reply below always carries
        # MCP_PROTOCOL_VERSION regardless.
        "protocol_version": protocol_version,
        # Timezone-aware UTC timestamp (datetime.utcnow() is deprecated and naive).
        "created_at": datetime.now(timezone.utc).isoformat(),
    }

    return {
        "protocolVersion": MCP_PROTOCOL_VERSION,
        "capabilities": {
            "tools": {
                "listChanged": False  # We don't support dynamic tool updates
            }
        },
        "serverInfo": {
            "name": "Healthcare API MCP",
            "version": "1.0.0"
        },
        "instructions": "This MCP provides access to public healthcare data APIs including OpenFDA (drug safety, labels, recalls), clinical guidelines (USPSTF, CDC), and CMS pricing data."
    }
async def handle_tools_list(params: dict) -> dict:
    """Answer ``tools/list`` with the schema of every tool from every provider.

    *params* is accepted to match the other handlers but is not read, so
    cursor-based pagination is not implemented.
    """
    return {
        "tools": [
            build_tool_schema(tool)
            for provider in providers.values()
            for tool in provider.get_tools()
        ]
    }
def _validate_tool_arguments(tool_func, tool_name: str, arguments: dict) -> None:
    """Raise ValueError if *arguments* cannot be bound to *tool_func*'s parameters."""
    try:
        signature = inspect.signature(tool_func)
    except (TypeError, ValueError):
        # Signature not introspectable; let the call itself report problems.
        return
    try:
        signature.bind(**arguments)
    except TypeError as exc:
        logger.error(f"Invalid arguments for {tool_name}: {exc}")
        raise ValueError(
            f"Invalid arguments. {exc}. Provided arguments: {list(arguments.keys())}"
        ) from exc


async def handle_tools_call(params: dict) -> dict:
    """Execute a tool for ``tools/call`` and wrap its output as MCP text content.

    Args:
        params: JSON-RPC params. ``name`` selects the tool. The optional
            ``arguments`` object (omitted or null means no arguments) is passed
            as keyword arguments.

    Returns:
        ``{"content": [{"type": "text", "text": ...}]}``. A non-string tool
        result is serialized with ``json.dumps`` because MCP text content must
        be a string.

    Raises:
        ValueError: missing ``name``, non-object ``arguments``, arguments that
            do not fit the tool's signature, or an unknown tool name.
        Exception: anything else raised by the tool propagates unchanged.
    """
    tool_name = params.get("name")
    arguments = params.get("arguments")
    if arguments is None:
        arguments = {}

    if not tool_name:
        raise ValueError("Missing 'name' parameter in tools/call")
    if not isinstance(arguments, dict):
        raise ValueError("'arguments' must be an object")

    logger.info(f"Calling tool: {tool_name} with args: {arguments}")

    for provider in providers.values():
        for tool_func in provider.get_tools():
            if tool_func.__name__ != tool_name:
                continue

            # Reject bad argument sets before running the tool, so a TypeError
            # raised inside the tool is not mistaken for a caller error.
            _validate_tool_arguments(tool_func, tool_name, arguments)
            try:
                result = await tool_func(**arguments)
            except TypeError as e:
                # Fallback for tools whose real signature is hidden behind a
                # generic (*args, **kwargs) wrapper.
                error_msg = str(e)
                if "missing" in error_msg and "required" in error_msg:
                    logger.error(f"Missing required arguments for {tool_name}: {error_msg}")
                    raise ValueError(f"Missing required arguments. {error_msg}. Provided arguments: {list(arguments.keys())}") from e
                raise

            text = result if isinstance(result, str) else json.dumps(result, ensure_ascii=False)
            return {
                "content": [
                    {
                        "type": "text",
                        "text": text
                    }
                ]
            }

    raise ValueError(f"Tool not found: {tool_name}")
async def handle_jsonrpc_request(request_data: dict, session_id: Optional[str] = None) -> dict:
    """Validate one JSON-RPC 2.0 request and dispatch it to its MCP handler.

    Args:
        request_data: The decoded request. Anything other than a JSON object
            is answered with an Invalid Request error.
        session_id: Value of the client's ``Mcp-Session-Id`` header, if any.

    Returns:
        A JSON-RPC response or error dict. For ``initialize`` the dict also
        carries a private ``_session_id`` key, which the HTTP layer must pop
        and send back as the ``Mcp-Session-Id`` header.
    """
    if not isinstance(request_data, dict):
        return create_jsonrpc_error(-32600, "Invalid Request: expected a JSON object", None)

    request_id = request_data.get("id")
    if request_data.get("jsonrpc") != "2.0":
        return create_jsonrpc_error(-32600, "Invalid Request: jsonrpc must be '2.0'", request_id)

    method = request_data.get("method")
    if not method:
        return create_jsonrpc_error(-32600, "Invalid Request: missing method", request_id)

    # "params" may be omitted or explicitly null; the handlers expect a dict.
    params = request_data.get("params")
    if params is None:
        params = {}
    if not isinstance(params, dict):
        return create_jsonrpc_error(-32602, "Invalid params: params must be an object", request_id)

    try:
        if method == "initialize":
            # Reuse the client's session id if it sent one, otherwise mint one.
            if not session_id:
                session_id = str(uuid.uuid4())
            result = await handle_initialize(params, session_id)
            response = create_jsonrpc_response(result, request_id)
            response["_session_id"] = session_id  # consumed by the HTTP layer
            return response
        if method == "ping":
            # MCP liveness check: the reply is an empty result object.
            return create_jsonrpc_response({}, request_id)
        if method == "tools/list":
            return create_jsonrpc_response(await handle_tools_list(params), request_id)
        if method == "tools/call":
            return create_jsonrpc_response(await handle_tools_call(params), request_id)
        return create_jsonrpc_error(-32601, f"Method not found: {method}", request_id)
    except ValueError as e:
        return create_jsonrpc_error(-32602, f"Invalid params: {str(e)}", request_id)
    except Exception as e:
        logger.exception(f"Error handling {method}")
        return create_jsonrpc_error(-32603, f"Internal error: {str(e)}", request_id)
# ============================================================================
# MCP Streamable HTTP Endpoint
# ============================================================================
async def mcp_endpoint(
    request: Request,
    mcp_session_id: Optional[str] = Header(None, alias="Mcp-Session-Id")
):
    """
    MCP Streamable HTTP endpoint (JSON-RPC 2.0).

    Handles:
    - initialize: Capability negotiation (reply carries the Mcp-Session-Id header)
    - tools/list: List available tools
    - tools/call: Execute a tool

    Notifications (a "method" but no "id", e.g. notifications/initialized) and
    client-sent JSON-RPC responses get no JSON-RPC reply. They are acknowledged
    with HTTP 202 and an empty body.

    An unparsable body, or a body that is not a single JSON object (batches
    are not supported), yields HTTP 400 with a JSON-RPC error.
    """
    from fastapi.responses import Response

    try:
        body = await request.json()
    except Exception as e:
        return JSONResponse(
            content=create_jsonrpc_error(-32700, f"Parse error: {str(e)}", None),
            status_code=400
        )

    if not isinstance(body, dict):
        return JSONResponse(
            content=create_jsonrpc_error(-32600, "Invalid Request: expected a single JSON object", None),
            status_code=400
        )

    # JSON-RPC forbids replying to notifications; responses from the client
    # need no reply either.
    is_notification = "method" in body and "id" not in body
    is_client_response = "method" not in body and ("result" in body or "error" in body)
    if is_notification or is_client_response:
        return Response(status_code=202)

    response = await handle_jsonrpc_request(body, mcp_session_id)

    # handle_jsonrpc_request smuggles the (possibly new) session id out of
    # initialize via a private key; move it into the response header.
    session_id = response.pop("_session_id", None)
    headers = {"Mcp-Session-Id": session_id} if session_id else {}
    return JSONResponse(content=response, headers=headers)
# ============================================================================
# Legacy REST Endpoints (for backward compatibility)
# ============================================================================
async def root_info():
    """Describe the server for GET requests on the root path.

    Reports name, version, transport, protocol version, the active provider
    keys, and a map of the MCP and legacy endpoints.
    """
    legacy_endpoints = {
        "tools": "/mcp/tools (GET) - List tools (legacy)",
        "call": "/mcp/call (POST) - Execute tool (legacy)",
        "stream": "/mcp/stream (POST) - Stream results (legacy)",
    }
    endpoint_map = {
        "mcp": "/ or /mcp (POST) - MCP Streamable HTTP endpoint",
        "legacy": legacy_endpoints,
    }
    return {
        "name": "Healthcare API MCP",
        "version": "1.0.0",
        "protocol": "MCP Streamable HTTP (JSON-RPC 2.0)",
        "protocolVersion": MCP_PROTOCOL_VERSION,
        "providers": list(providers),
        "endpoints": endpoint_map,
    }
async def root_mcp(
    request: Request,
    mcp_session_id: Optional[str] = Header(None, alias="Mcp-Session-Id")
):
    """Serve MCP traffic on the root path by delegating to ``mcp_endpoint``.

    Exists so clients configured with the bare server URL behave exactly like
    those using the ``/mcp`` path.
    """
    delegated = await mcp_endpoint(request, mcp_session_id)
    return delegated
async def health_check():
    """Report liveness and mark every registered provider as ``"active"``."""
    provider_status = dict.fromkeys(providers, "active")
    return {"status": "healthy", "providers": provider_status}
async def list_tools():
    """Legacy tool listing: each tool's MCP schema tagged with its provider key.

    Returns ``{"tools": [...], "total": <count>}``. The ``provider`` key is
    appended after the schema's own keys.
    """
    catalog = [
        {**build_tool_schema(fn), "provider": owner}
        for owner, provider in providers.items()
        for fn in provider.get_tools()
    ]
    return {"tools": catalog, "total": len(catalog)}
async def call_tool(request: Request):
    """Legacy endpoint: execute a tool and report the outcome as plain JSON.

    Expects a body like ``{"tool": "<name>", "arguments": {...}}``. An omitted
    or null ``arguments`` means no arguments.

    Returns:
        - ``{"success": True, "tool": ..., "result": ...}`` on success.
        - ``{"success": False, "tool": ..., "error": ..., "error_type": ...}``
          if the tool raised.
        - ``{"success": False, "tool": ..., "error": "Tool not found: ..."}``
          for an unknown tool.
        - ``{"error": "Missing 'tool' parameter"}`` without a tool name.
        - HTTP 400 with ``{"error": ...}`` if the body is not valid JSON or not
          a JSON object.
    """
    try:
        body = await request.json()
    except Exception as e:
        return JSONResponse(content={"error": f"Invalid JSON body: {e}"}, status_code=400)
    if not isinstance(body, dict):
        return JSONResponse(content={"error": "Request body must be a JSON object"}, status_code=400)

    tool_name = body.get("tool")
    arguments = body.get("arguments")
    if arguments is None:
        arguments = {}

    if not tool_name:
        return {"error": "Missing 'tool' parameter"}

    logger.info(f"Calling tool: {tool_name} with args: {arguments}")

    for provider in providers.values():
        for tool_func in provider.get_tools():
            if tool_func.__name__ != tool_name:
                continue
            try:
                result = await tool_func(**arguments)
            except Exception as e:
                # Report tool failures in-band instead of as an HTTP error.
                logger.exception(f"Error calling tool {tool_name}")
                return {
                    "success": False,
                    "tool": tool_name,
                    "error": str(e),
                    "error_type": type(e).__name__
                }
            return {
                "success": True,
                "tool": tool_name,
                "result": result
            }

    return {
        "success": False,
        "tool": tool_name,
        "error": f"Tool not found: {tool_name}"
    }
async def stream_tool_result(tool_name: str, arguments: dict) -> AsyncGenerator[str, None]:
    """Run a tool and yield its progress as newline-delimited JSON events.

    The events are emitted in this order:

    1. ``start``, echoing the tool name and arguments.
    2. Exactly one outcome event:
       - ``result`` with the tool's return value, or
       - ``error`` if the tool raised, or
       - ``error`` if no provider exposes *tool_name*.
    3. ``complete``.

    Each event is one JSON object followed by a newline.
    """
    def as_line(event: dict) -> str:
        return json.dumps(event) + "\n"

    yield as_line({"type": "start", "tool": tool_name, "arguments": arguments})

    # First provider tool whose function name matches, if any.
    target = None
    for provider in providers.values():
        target = next(
            (fn for fn in provider.get_tools() if fn.__name__ == tool_name),
            None,
        )
        if target is not None:
            break

    if target is None:
        yield as_line({
            "type": "error",
            "success": False,
            "error": f"Tool not found: {tool_name}",
        })
    else:
        try:
            outcome = await target(**arguments)
            yield as_line({"type": "result", "success": True, "data": outcome})
        except Exception as e:
            logger.exception(f"Error executing tool {tool_name}")
            yield as_line({
                "type": "error",
                "success": False,
                "error": str(e),
                "error_type": type(e).__name__,
            })

    yield as_line({"type": "complete"})
async def stream_tool_call(request: Request):
    """Legacy endpoint: stream a tool's execution as NDJSON over chunked HTTP.

    Expects a body like ``{"tool": "<name>", "arguments": {...}}``. An omitted
    or null ``arguments`` means no arguments. The response body is produced by
    ``stream_tool_result``.

    Returns:
        - A streaming ``application/x-ndjson`` response.
        - ``{"error": "Missing 'tool' parameter"}`` without a tool name.
        - HTTP 400 with ``{"error": ...}`` if the body is not valid JSON or not
          a JSON object.
    """
    try:
        body = await request.json()
    except Exception as e:
        return JSONResponse(content={"error": f"Invalid JSON body: {e}"}, status_code=400)
    if not isinstance(body, dict):
        return JSONResponse(content={"error": "Request body must be a JSON object"}, status_code=400)

    tool_name = body.get("tool")
    arguments = body.get("arguments")
    if arguments is None:
        arguments = {}

    if not tool_name:
        return {"error": "Missing 'tool' parameter"}

    logger.info(f"Streaming tool: {tool_name} with args: {arguments}")

    return StreamingResponse(
        stream_tool_result(tool_name, arguments),
        media_type="application/x-ndjson",
        headers={
            "Cache-Control": "no-cache",
            # Ask reverse proxies (nginx) not to buffer the stream.
            "X-Accel-Buffering": "no",
        }
    )
if __name__ == "__main__":
    # Direct-run entry point. Port 7860 presumably matches the HuggingFace
    # Spaces deployment named in the module docstring — confirm before changing.
    import uvicorn

    server_options = dict(host="0.0.0.0", port=7860, log_level="info")
    uvicorn.run(app, **server_options)