healthcare-api-mcp / server_http.py
visproj's picture
initial commit
0d10048 verified
raw
history blame
8.23 kB
"""HTTP Streaming MCP Server for Healthcare API MCP - HuggingFace Spaces deployment.
Uses HTTP streaming with chunked transfer encoding instead of SSE.
"""
import asyncio
import json
import logging
from typing import AsyncGenerator
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from starlette.middleware.cors import CORSMiddleware
from config import MCPConfig
from core import create_http_client
from providers.openfda_provider import OpenFDAProvider
from providers.clinical_guidelines_provider import ClinicalGuidelinesProvider
from providers.cms_pricing_provider import CMSPricingProvider
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(
    title="Healthcare API MCP",
    description="Model Context Protocol server for public healthcare data APIs",
    version="1.0.0",
)

# Add CORS middleware.
# Any origin may call the API, but credentials stay disabled: the CORS spec
# forbids a wildcard origin together with credentials, and Starlette works
# around that by echoing back whatever Origin the caller sent, which would let
# any website make credentialed cross-origin requests. Nothing in this server
# reads cookies or session state, so credentials are not needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global state.
# Maps a provider key ("openfda", "clinical_guidelines", "cms_pricing") to its
# provider instance; filled in by the startup hook.
providers = {}
# Shared HTTP client passed to every provider; created on startup and closed
# on shutdown.
http_client = None
@app.on_event("startup")
async def startup():
    """Create the shared HTTP client and bring up each enabled provider.

    Configuration comes from ``MCPConfig.from_env()``. Providers are handled
    in a fixed order (OpenFDA, Clinical Guidelines, CMS Pricing). Each one
    whose key appears in ``config.enabled_providers`` is constructed with the
    shared client and stored in the global ``providers`` dict. Its
    ``initialize()`` coroutine is then awaited.
    """
    global http_client, providers
    config = MCPConfig.from_env()
    logger.info(f"Starting Healthcare API MCP with providers: {config.enabled_providers}")

    http_client = create_http_client()

    # (registry key, factory, success message), in initialization order.
    # Factories are lambdas so a provider is only constructed (and the OpenFDA
    # API key only read) when that provider is enabled.
    provider_specs = (
        (
            "openfda",
            lambda: OpenFDAProvider(http_client, config.openfda_api_key),
            "✅ OpenFDA provider initialized",
        ),
        (
            "clinical_guidelines",
            lambda: ClinicalGuidelinesProvider(http_client),
            "✅ Clinical Guidelines provider initialized",
        ),
        (
            "cms_pricing",
            lambda: CMSPricingProvider(http_client),
            "✅ CMS Pricing provider initialized",
        ),
    )

    for key, build_provider, ready_message in provider_specs:
        if key not in config.enabled_providers:
            continue
        # The provider is registered before initialize() is awaited, so one
        # whose initialization raises is still present in `providers`.
        providers[key] = build_provider()
        await providers[key].initialize()
        logger.info(ready_message)

    logger.info(f"Healthcare API MCP ready with {len(providers)} provider(s)")
@app.on_event("shutdown")
async def shutdown():
    """Release provider resources and close the shared HTTP client.

    Each provider's ``cleanup()`` is awaited independently. A failure in one
    provider is logged and does not prevent the remaining providers, or the
    shared HTTP client, from being cleaned up.
    """
    global http_client, providers
    logger.info("Shutting down Healthcare API MCP...")
    try:
        for name, provider in providers.items():
            try:
                await provider.cleanup()
            except Exception:
                # Best-effort teardown: log and keep going so one bad provider
                # cannot leak the others' resources.
                logger.exception("Error cleaning up provider %s", name)
    finally:
        # Always close the shared client, even if cleanup was interrupted.
        if http_client:
            await http_client.aclose()
@app.get("/")
async def root():
    """Describe the server.

    The response lists the server's name, version, protocol, the keys of the
    currently registered providers, and the paths of the MCP endpoints.
    """
    endpoint_paths = {
        "tools": "/mcp/tools",
        "call": "/mcp/call",
        "stream": "/mcp/stream",
    }
    return {
        "name": "Healthcare API MCP",
        "version": "1.0.0",
        "protocol": "MCP over HTTP Streaming",
        "providers": list(providers),
        "endpoints": endpoint_paths,
    }
@app.get("/health")
async def health_check():
    """Report liveness.

    Every registered provider is listed with the status ``"active"``.
    """
    provider_status = dict.fromkeys(providers, "active")
    return {"status": "healthy", "providers": provider_status}
# A line consisting solely of one of these opens a new Google-style docstring
# section.
_DOCSTRING_SECTION_HEADERS = frozenset({
    "Args:", "Arguments:",
    "Returns:", "Return:", "Yields:", "Raises:",
    "Example:", "Examples:", "Note:", "Notes:",
})
# Section headers whose entries describe parameters.
_ARGS_SECTION_HEADERS = frozenset({"Args:", "Arguments:"})


def _parse_tool_docstring(doc):
    """Split a Google-style docstring into a summary and per-parameter docs.

    Args:
        doc: The raw docstring (may be empty).

    Returns:
        A ``(description, params)`` tuple.

        - ``description`` is the text before the first section header.
        - ``params`` maps each entry listed under ``Args:`` to its description.
        - A ``name (type): desc`` entry is keyed by ``name``.
        - Non-empty lines that do not start a new entry are appended to the
          previous parameter's description.
        - Parameter parsing stops at the next section header (``Returns:``,
          ``Raises:``, ...).
    """
    summary_lines = []
    params = {}
    section = None  # None while still reading the summary text
    current = None  # parameter that continuation lines are appended to
    for raw_line in doc.splitlines():
        line = raw_line.strip()
        if line in _DOCSTRING_SECTION_HEADERS:
            section = line
            current = None
            continue
        if section is None:
            summary_lines.append(raw_line)
        elif section in _ARGS_SECTION_HEADERS and line:
            head, sep, tail = line.partition(":")
            # Drop an optional "(type)" annotation: "limit (int)" -> "limit".
            name = head.split("(", 1)[0].strip()
            # lstrip("*") keeps "*args" / "**kwargs" entries recognized.
            if sep and name.lstrip("*").isidentifier():
                params[name] = tail.strip()
                current = name
            elif current is not None:
                # Wrapped description text, even if it contains a colon.
                params[current] = f"{params[current]} {line}".strip()
    return "\n".join(summary_lines).strip(), params


@app.get("/mcp/tools")
async def list_tools():
    """List all available MCP tools across the registered providers.

    For each tool function returned by ``provider.get_tools()``:

    - ``name`` is the function's ``__name__``.
    - ``provider`` is the registry key of the provider that exposes it.
    - ``description`` and ``parameters`` are parsed from the function's
      docstring by ``_parse_tool_docstring``.
    """
    tools = []
    for provider_name, provider in providers.items():
        for tool_func in provider.get_tools():
            description, params = _parse_tool_docstring(tool_func.__doc__ or "")
            tools.append({
                "name": tool_func.__name__,
                "provider": provider_name,
                "description": description,
                "parameters": params,
            })
    return {"tools": tools, "total": len(tools)}
@app.post("/mcp/call")
async def call_tool(request: Request):
    """Run one MCP tool and return its result in a single JSON response.

    Expects a JSON object body ``{"tool": <name>, "arguments": {...}}``.
    ``arguments`` may be omitted or ``null``, both meaning "no arguments".

    Returns:
        - ``{"success": True, "tool", "result"}`` when the tool ran.
        - ``{"success": False, "tool", "error", "error_type"}`` when it raised.
        - ``{"success": False, "tool", "error"}`` when no provider exposes a
          tool with that name.
        - ``{"error": ...}`` when the body is not a JSON object or ``tool``
          is missing.
    """
    try:
        body = await request.json()
    except ValueError:
        # Covers json.JSONDecodeError and UnicodeDecodeError (both ValueError
        # subclasses); previously these surfaced as an unhandled 500.
        return {"error": "Request body must be valid JSON"}
    if not isinstance(body, dict):
        # e.g. a JSON array or string; .get() below would raise AttributeError.
        return {"error": "Request body must be a JSON object"}

    tool_name = body.get("tool")
    # Treat an explicit JSON null the same as an omitted "arguments" field,
    # instead of failing later with "argument after ** must be a mapping".
    arguments = body.get("arguments")
    if arguments is None:
        arguments = {}

    if not tool_name:
        return {"error": "Missing 'tool' parameter"}

    logger.info("Calling tool: %s with args: %s", tool_name, arguments)

    # Providers are scanned in registration order; the first match wins.
    for provider in providers.values():
        for tool_func in provider.get_tools():
            if tool_func.__name__ != tool_name:
                continue
            try:
                result = await tool_func(**arguments)
            except Exception as e:
                logger.exception("Error calling tool %s", tool_name)
                return {
                    "success": False,
                    "tool": tool_name,
                    "error": str(e),
                    "error_type": type(e).__name__,
                }
            return {
                "success": True,
                "tool": tool_name,
                "result": result,
            }

    return {
        "success": False,
        "tool": tool_name,
        "error": f"Tool not found: {tool_name}",
    }
async def stream_tool_result(tool_name: str, arguments: dict) -> AsyncGenerator[str, None]:
    """Yield newline-delimited JSON events describing one tool execution.

    The stream always opens with a ``start`` event and ends with a
    ``complete`` event. In between comes exactly one of:

    - a ``result`` event carrying the tool's return value;
    - an ``error`` event with the exception message and type, if the tool
      raised;
    - an ``error`` event saying the tool was not found, if no registered
      provider exposes a tool named ``tool_name``.
    """

    def encode(event: dict) -> str:
        # One JSON document per line (NDJSON framing).
        return json.dumps(event) + "\n"

    yield encode({"type": "start", "tool": tool_name, "arguments": arguments})

    # Scan providers in registration order; the first matching tool wins.
    matches = (
        candidate
        for provider in providers.values()
        for candidate in provider.get_tools()
        if candidate.__name__ == tool_name
    )
    tool_func = next(matches, None)

    if tool_func is None:
        yield encode({
            "type": "error",
            "success": False,
            "error": f"Tool not found: {tool_name}",
        })
    else:
        try:
            data = await tool_func(**arguments)
            # Serialization stays inside the try so a non-JSON-serializable
            # result is reported as an error event.
            yield encode({"type": "result", "success": True, "data": data})
        except Exception as exc:
            logger.exception(f"Error executing tool {tool_name}")
            yield encode({
                "type": "error",
                "success": False,
                "error": str(exc),
                "error_type": type(exc).__name__,
            })

    yield encode({"type": "complete"})
@app.post("/mcp/stream")
async def stream_tool_call(request: Request):
    """Run one MCP tool and stream its events as newline-delimited JSON.

    Accepts the same JSON body as ``/mcp/call``:
    ``{"tool": <name>, "arguments": {...}}``, where ``arguments`` may be
    omitted or ``null``. On success the response is an
    ``application/x-ndjson`` stream produced by ``stream_tool_result``.

    Request-validation failures are returned as a plain (non-streamed) JSON
    object with an ``error`` key. These are: body not a JSON object, or
    ``tool`` missing.
    """
    try:
        body = await request.json()
    except ValueError:
        # Covers json.JSONDecodeError and UnicodeDecodeError (both ValueError
        # subclasses); previously these surfaced as an unhandled 500.
        return {"error": "Request body must be valid JSON"}
    if not isinstance(body, dict):
        # e.g. a JSON array or string; .get() below would raise AttributeError.
        return {"error": "Request body must be a JSON object"}

    tool_name = body.get("tool")
    # Treat an explicit JSON null the same as an omitted "arguments" field.
    arguments = body.get("arguments")
    if arguments is None:
        arguments = {}

    if not tool_name:
        return {"error": "Missing 'tool' parameter"}

    logger.info("Streaming tool: %s with args: %s", tool_name, arguments)

    return StreamingResponse(
        stream_tool_result(tool_name, arguments),
        media_type="application/x-ndjson",
        headers={
            "Cache-Control": "no-cache",
            # Asks an nginx-style reverse proxy not to buffer the stream.
            "X-Accel-Buffering": "no",
        },
    )
if __name__ == "__main__":
    # Entry point when this module is executed directly as a script.
    import uvicorn

    # Listen on all interfaces. Port 7860 is presumably what the HuggingFace
    # Spaces deployment expects; confirm against the Space configuration.
    server_settings = {"host": "0.0.0.0", "port": 7860, "log_level": "info"}
    uvicorn.run(app, **server_settings)