# healthcare-api-mcp / server.py
# Commit 0d10048 ("initial commit", verified) by visproj; original size 7.57 kB.
"""Healthcare API MCP Server - Public APIs for FDA, Clinical Guidelines, and CMS."""
import asyncio
import inspect
import json
import logging
import types
from typing import Any, Dict, List, Union, get_args, get_origin

from mcp.server import Server
from mcp.server.stdio import stdio_server
from mcp.types import Tool, TextContent

from .config import MCPConfig
from .core import create_http_client
from .providers import get_provider_class
from .providers.openfda_provider import OpenFDAProvider
from .providers.clinical_guidelines_provider import ClinicalGuidelinesProvider
from .providers.cms_pricing_provider import CMSPricingProvider
# Root logging configuration; main() later overrides the level from MCPConfig.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# The MCP server instance; the tool handlers register on it via decorators.
app = Server("healthcare-api-mcp")

# Process-wide state, populated by main().
providers: Dict[str, Any] = {}  # provider name -> initialized provider instance
http_client: Any = None  # shared HTTP client created by main()
# Google-style docstring headings that open the parameter section.
_ARGS_HEADERS = ("Args:", "Arguments:")
# Headings that close the parameter section.
_SECTION_END_HEADERS = (
    "Returns:",
    "Return:",
    "Yields:",
    "Raises:",
    "Examples:",
    "Example:",
    "Notes:",
    "Note:",
)
# Origins that typing.get_origin reports for Union[...] and (Python 3.10+) X | Y.
_UNION_ORIGINS = tuple(
    origin for origin in (Union, getattr(types, "UnionType", None)) if origin is not None
)


def _json_schema_type(annotation: Any) -> str:
    """Map a Python type annotation to a JSON Schema primitive type name.

    Unknown or missing annotations map to ``"string"``.
    """
    origin = get_origin(annotation)
    if origin in _UNION_ORIGINS:
        members = [arg for arg in get_args(annotation) if arg is not type(None)]
        # Optional[X] / X | None travels over the wire as plain X.
        return _json_schema_type(members[0]) if len(members) == 1 else "string"
    # For parameterized generics (List[str], Dict[str, int]) use the bare origin.
    target = origin if origin is not None else annotation
    if target is bool:
        return "boolean"
    if target is int:
        return "integer"
    if target is float:
        return "number"
    if target is list or target is tuple:
        return "array"
    if target is dict:
        return "object"
    return "string"


def _parse_docstring_args(
    lines: List[str],
    named_params: Dict[str, inspect.Parameter],
    accepts_any_name: bool,
) -> Dict[str, str]:
    """Extract ``{parameter_name: description}`` from a Google-style Args section.

    Args:
        lines: Stripped, non-empty docstring lines.
        named_params: The function's named (non-variadic) parameters.
        accepts_any_name: True when the function takes ``**kwargs``, so any
            identifier may be documented as a parameter.

    Returns:
        Descriptions in docstring order. Lines that do not start a known
        parameter are appended to the previous parameter's description.
    """
    descriptions: Dict[str, str] = {}
    in_args = False
    current = None
    for line in lines:
        if not in_args:
            in_args = line.startswith(_ARGS_HEADERS)
            continue
        if line.startswith(_SECTION_END_HEADERS):
            break
        head, sep, tail = line.partition(":")
        # Accept both "name: desc" and "name (type): desc".
        name = head.split("(", 1)[0].strip()
        if sep and name.isidentifier() and (name in named_params or accepts_any_name):
            descriptions[name] = tail.strip()
            current = name
        elif current is not None:
            # Wrapped description text, e.g. "(default: 10)" on its own line.
            descriptions[current] = f"{descriptions[current]} {line}".strip()
    return descriptions


def create_tool_schema(func: callable, provider_name: str) -> Tool:
    """Build an MCP ``Tool`` definition from a provider's tool function.

    The tool name is ``func.__name__`` and the description is the first
    non-empty docstring line. Input properties come from the docstring's
    ``Args:`` section. Their JSON types are derived from the parameter
    annotations, and a property is required when its parameter has no
    default.

    Args:
        func: Provider tool function (as returned by ``provider.get_tools()``).
        provider_name: Name of the owning provider; currently unused, kept
            for interface compatibility.

    Returns:
        A ``Tool`` whose ``inputSchema`` is a JSON Schema ``object``.
    """
    doc = func.__doc__ or ""
    lines = [line.strip() for line in doc.splitlines() if line.strip()]
    description = lines[0] if lines else func.__name__

    variadic_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    named_params: Dict[str, inspect.Parameter] = {}
    accepts_kwargs = False
    try:
        signature = inspect.signature(func)
    except (TypeError, ValueError):
        # Not introspectable: expose the tool without parameters.
        signature = None
    if signature is not None:
        named_params = {
            name: param
            for name, param in signature.parameters.items()
            if param.kind not in variadic_kinds
        }
        accepts_kwargs = any(
            param.kind is inspect.Parameter.VAR_KEYWORD
            for param in signature.parameters.values()
        )

    properties: Dict[str, Dict[str, str]] = {}
    required: List[str] = []
    for name, param_desc in _parse_docstring_args(lines, named_params, accepts_kwargs).items():
        param = named_params.get(name)
        annotation = param.annotation if param is not None else inspect.Parameter.empty
        properties[name] = {
            "type": _json_schema_type(annotation),
            "description": param_desc,
        }
        if param is not None and param.default is inspect.Parameter.empty and name != "self":
            required.append(name)

    return Tool(
        name=func.__name__,
        description=description,
        inputSchema={
            "type": "object",
            "properties": properties,
            "required": required,
        },
    )
@app.list_tools()
async def list_tools() -> List[Tool]:
    """Gather the MCP tool definitions advertised by every enabled provider.

    Returns:
        One ``Tool`` per function in each provider's ``get_tools()``, in
        provider order and then in tool order.
    """
    tools = [
        create_tool_schema(tool_func, name)
        for name, provider in providers.items()
        for tool_func in provider.get_tools()
    ]
    logger.info(f"Listed {len(tools)} tools from {len(providers)} providers")
    return tools
def _result_to_text(result: Any) -> str:
    """Render a tool result as text, preferring JSON for non-string values.

    Strings pass through unchanged. Other values are JSON-encoded when
    possible and fall back to ``str()`` when they are not JSON-serializable.
    """
    if isinstance(result, str):
        return result
    try:
        return json.dumps(result, ensure_ascii=False)
    except (TypeError, ValueError):
        return str(result)


@app.call_tool()
async def call_tool(name: str, arguments: Dict[str, Any]) -> List[TextContent]:
    """Dispatch an MCP tool call to the matching provider function.

    The first provider tool whose ``__name__`` equals ``name`` is awaited with
    ``arguments`` as keyword arguments. Failures are reported to the client
    as a JSON object instead of being raised.

    Args:
        name: Tool name, as advertised by ``list_tools``.
        arguments: Keyword arguments for the tool; ``None`` is treated as ``{}``.

    Returns:
        A single ``TextContent`` holding either the tool result or a JSON
        error object with an ``error`` key, plus ``error_type`` when the tool
        raised.
    """
    logger.info(f"Calling tool: {name} with arguments: {arguments}")
    # A missing argument object means "no arguments", not a crash on **None.
    kwargs = arguments or {}

    for provider in providers.values():
        for tool_func in provider.get_tools():
            if tool_func.__name__ != name:
                continue
            try:
                result = await tool_func(**kwargs)
                text = _result_to_text(result)
            except Exception as e:
                logger.exception(f"Error calling tool {name}")
                # json.dumps escapes quotes/newlines in the message so the
                # payload stays valid JSON.
                text = json.dumps(
                    {"error": str(e), "error_type": type(e).__name__},
                    ensure_ascii=False,
                )
            return [TextContent(type="text", text=text)]

    not_found = json.dumps({"error": f"Tool not found: {name}"}, ensure_ascii=False)
    return [TextContent(type="text", text=not_found)]
async def _initialize_providers(config: "MCPConfig", client: Any) -> None:
    """Construct and initialize each enabled provider into ``providers``.

    Unknown provider names are skipped with a warning. A provider whose
    construction or ``initialize()`` raises is logged with its traceback and
    left out.
    """
    provider_classes = {
        "openfda": OpenFDAProvider,
        "clinical_guidelines": ClinicalGuidelinesProvider,
        "cms_pricing": CMSPricingProvider,
    }
    for provider_name in config.enabled_providers:
        provider_class = provider_classes.get(provider_name)
        if provider_class is None:
            logger.warning(f"Unknown provider: {provider_name}")
            continue
        try:
            # Only the openfda provider is given an API key here.
            if provider_name == "openfda":
                provider = provider_class(client, api_key=config.openfda_api_key)
            else:
                provider = provider_class(client)
            await provider.initialize()
        except Exception as e:
            logger.exception(f"Failed to initialize provider {provider_name}: {e}")
            continue
        providers[provider_name] = provider
        logger.info(f"Initialized provider: {provider_name}")


async def _shutdown(client: Any) -> None:
    """Clean up every initialized provider, then close the shared HTTP client.

    A failing provider cleanup is logged rather than raised, so the remaining
    providers and the HTTP client are still released.
    """
    for provider_name, provider in providers.items():
        try:
            await provider.cleanup()
        except Exception:
            logger.exception(f"Error cleaning up provider {provider_name}")
    if client is not None:
        await client.aclose()


async def main():
    """Main entry point for the MCP server.

    Loads ``MCPConfig`` from the environment, creates one shared HTTP client,
    initializes the enabled providers and serves MCP over stdio. Providers and
    the HTTP client are released on every exit path, including the early exit
    when no provider could be initialized.
    """
    global http_client

    config = MCPConfig.from_env()
    logging.getLogger().setLevel(config.log_level)
    logger.info("Starting Healthcare API MCP Server")
    logger.info(f"Enabled providers: {config.enabled_providers}")

    # Shared by all providers; closed in the finally block below.
    http_client = create_http_client()
    try:
        await _initialize_providers(config, http_client)
        if not providers:
            logger.error("No providers initialized. Exiting.")
            return
        logger.info(f"Successfully initialized {len(providers)} provider(s)")

        async with stdio_server() as (read_stream, write_stream):
            logger.info("MCP server running on stdio")
            await app.run(
                read_stream,
                write_stream,
                app.create_initialization_options(),
            )
    finally:
        logger.info("Shutting down...")
        await _shutdown(http_client)


if __name__ == "__main__":
    asyncio.run(main())