"""Healthcare API MCP Server - Public APIs for FDA, Clinical Guidelines, and CMS."""

import asyncio
import inspect
import json
import logging
import types
from typing import List, Dict, Any, Callable, Optional, Tuple, Union
from typing import get_args, get_origin, get_type_hints

from mcp.server import Server
from mcp.server.stdio import stdio_server
from mcp.types import Tool, TextContent

from .config import MCPConfig
from .core import create_http_client
from .providers import get_provider_class
from .providers.openfda_provider import OpenFDAProvider
from .providers.clinical_guidelines_provider import ClinicalGuidelinesProvider
from .providers.cms_pricing_provider import CMSPricingProvider

# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Initialize MCP server
app = Server("healthcare-api-mcp")

# Global state
providers = {}
http_client = None

# Python annotation -> JSON Schema type name. Anything else maps to "string".
_JSON_TYPES = {
    int: "integer",
    bool: "boolean",
    float: "number",
    str: "string",
    list: "array",
    dict: "object",
}

# Origins that denote a union: typing.Union and, on 3.10+, the ``X | Y`` form.
_UNION_ORIGINS = (Union, getattr(types, "UnionType", Union))

# Docstring headers that terminate the "Args:" section.
_SECTION_END_HEADERS = ("Returns:", "Raises:", "Yields:")

_VARIADIC_KINDS = (
    inspect.Parameter.VAR_POSITIONAL,
    inspect.Parameter.VAR_KEYWORD,
)


def _json_type(annotation: Any) -> str:
    """Map a parameter annotation to a JSON Schema type name ("string" if unknown)."""
    origin = get_origin(annotation)
    if origin is not None and origin in _UNION_ORIGINS:
        # Optional[X] / X | None: describe X. Wider unions stay "string".
        members = [arg for arg in get_args(annotation) if arg is not type(None)]
        return _json_type(members[0]) if len(members) == 1 else "string"
    if origin is not None:
        # List[str] -> list, Dict[str, Any] -> dict, ...
        annotation = origin
    try:
        return _JSON_TYPES.get(annotation, "string")
    except TypeError:
        # Unhashable annotation object; nothing sensible to map it to.
        return "string"


def _resolve_annotations(func: Callable[..., Any]) -> Dict[str, Any]:
    """Return func's annotations, resolving string annotations when possible."""
    try:
        return get_type_hints(func)
    except Exception:
        # Unresolvable forward references etc.: fall back to the raw mapping.
        return dict(getattr(func, "__annotations__", None) or {})


def _documented_params(lines: List[str]) -> List[Tuple[str, str]]:
    """Extract (name, description) pairs from the "Args:" section of a docstring.

    ``lines`` are the stripped, non-empty docstring lines. The section runs
    from the line after "Args:" to the next section header, or to the end of
    the docstring when no further header exists.
    """
    start = None
    end = len(lines)
    for index, line in enumerate(lines):
        if start is None:
            if line.startswith("Args:"):
                start = index + 1
        elif line.startswith(_SECTION_END_HEADERS):
            end = index
            break
    if start is None:
        return []

    found = []
    for line in lines[start:end]:
        head, sep, desc = line.partition(":")
        if not sep:
            continue
        # Accept Google-style "name (type): description" as well as "name: ...".
        name = head.split("(", 1)[0].strip()
        # Continuation lines that merely contain a colon are not parameters.
        if name.isidentifier():
            found.append((name, desc.strip()))
    return found


def create_tool_schema(func: Callable[..., Any], provider_name: str) -> Tool:
    """Create MCP tool schema from provider function.

    The tool description is the first non-empty docstring line (or the
    function name when there is no docstring). Input properties come from the
    docstring's "Args:" section, typed from the function's annotations; a
    parameter is required when it has no default value.

    Args:
        func: Tool callable to describe.
        provider_name: Name of the owning provider (currently unused).

    Returns:
        A Tool whose inputSchema is a JSON Schema object.
    """
    doc = func.__doc__ or ""
    lines = [line.strip() for line in doc.split('\n') if line.strip()]

    # Extract description (first non-empty line)
    description = lines[0] if lines else func.__name__

    # Inspect the signature once instead of once per documented parameter.
    try:
        sig_params = inspect.signature(func).parameters
    except (TypeError, ValueError):
        sig_params = None  # signature unavailable: trust the docstring alone
    accepts_kwargs = sig_params is not None and any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in sig_params.values()
    )
    annotations = _resolve_annotations(func)

    # Build input schema
    properties = {}
    required = []
    for param_name, param_desc in _documented_params(lines):
        if param_name in properties:
            continue  # first mention wins; later ones are continuation text
        param = sig_params.get(param_name) if sig_params is not None else None
        if sig_params is not None and param is None and not accepts_kwargs:
            # Documented name the function cannot receive (e.g. a "Raises:"
            # style entry); advertising it would make calls fail.
            continue
        if param is not None and param.kind in _VARIADIC_KINDS:
            continue  # *args / **kwargs are not individually addressable

        param_type = "string"  # default
        if param_name in annotations:
            param_type = _json_type(annotations[param_name])
        properties[param_name] = {
            "type": param_type,
            "description": param_desc
        }

        # Check if parameter has no default value (is required)
        if (
            param is not None
            and param.default is inspect.Parameter.empty
            and param_name != 'self'
        ):
            required.append(param_name)

    # Create tool schema
    return Tool(
        name=func.__name__,
        description=description,
        inputSchema={
            "type": "object",
            "properties": properties,
            "required": required
        }
    )


def _find_tool(name: str) -> Optional[Callable[..., Any]]:
    """Return the first tool named ``name`` among enabled providers, else None."""
    for provider in providers.values():
        for tool_func in provider.get_tools():
            if tool_func.__name__ == name:
                return tool_func
    return None


@app.list_tools()
async def list_tools() -> List[Tool]:
    """List all available tools from enabled providers."""
    tools = [
        create_tool_schema(tool_func, provider_name)
        for provider_name, provider in providers.items()
        for tool_func in provider.get_tools()
    ]
    logger.info("Listed %s tools from %s providers", len(tools), len(providers))
    return tools


@app.call_tool()
async def call_tool(name: str, arguments: Dict[str, Any]) -> List[TextContent]:
    """Execute a tool from any enabled provider.

    Args:
        name: Name of the tool function to run.
        arguments: Keyword arguments passed to the tool.

    Returns:
        A single-element list holding the stringified result, or a JSON
        object with an "error" key when the tool is unknown or raises.
    """
    logger.info("Calling tool: %s with arguments: %s", name, arguments)

    tool_func = _find_tool(name)
    if tool_func is None:
        # Tool not found
        error_msg = json.dumps(
            {"error": f"Tool not found: {name}"}, ensure_ascii=False
        )
        return [TextContent(type="text", text=error_msg)]

    try:
        # Call the tool function
        result = tool_func(**(arguments or {}))
        if inspect.isawaitable(result):
            result = await result
        # Return as TextContent
        return [TextContent(type="text", text=str(result))]
    except Exception as e:
        logger.exception("Error calling tool %s", name)
        # json.dumps escapes quotes/backslashes/newlines in the message so the
        # payload is always valid JSON (string interpolation did not).
        error_result = json.dumps(
            {"error": str(e), "error_type": type(e).__name__},
            ensure_ascii=False,
        )
        return [TextContent(type="text", text=error_result)]


async def _shutdown() -> None:
    """Clean up every initialized provider, then close the shared HTTP client."""
    logger.info("Shutting down...")
    for provider_name, provider in providers.items():
        try:
            await provider.cleanup()
        except Exception:
            # Keep going: one failing provider must not block the others or
            # leave the HTTP client open.
            logger.exception("Error cleaning up provider %s", provider_name)
    if http_client is not None:
        await http_client.aclose()


async def main():
    """Main entry point for the MCP server.

    Loads configuration from the environment, initializes the enabled
    providers on a shared HTTP client and serves over stdio. Providers and
    the HTTP client are cleaned up on every exit path.
    """
    global providers, http_client

    # Load configuration
    config = MCPConfig.from_env()

    # Set logging level
    logging.getLogger().setLevel(config.log_level)

    logger.info("Starting Healthcare API MCP Server")
    logger.info("Enabled providers: %s", config.enabled_providers)

    # Create shared HTTP client
    http_client = create_http_client()

    # Everything after the client exists runs under try/finally so the client
    # is closed even when no provider comes up.
    try:
        # Initialize providers
        provider_classes = {
            "openfda": OpenFDAProvider,
            "clinical_guidelines": ClinicalGuidelinesProvider,
            "cms_pricing": CMSPricingProvider
        }

        for provider_name in config.enabled_providers:
            if provider_name not in provider_classes:
                logger.warning("Unknown provider: %s", provider_name)
                continue

            try:
                provider_class = provider_classes[provider_name]

                # Initialize provider with appropriate parameters
                if provider_name == "openfda":
                    provider = provider_class(
                        http_client, api_key=config.openfda_api_key
                    )
                else:
                    provider = provider_class(http_client)

                await provider.initialize()
                providers[provider_name] = provider
                logger.info("Initialized provider: %s", provider_name)
            except Exception as e:
                logger.error(
                    "Failed to initialize provider %s: %s", provider_name, e
                )

        if not providers:
            logger.error("No providers initialized. Exiting.")
            return

        logger.info("Successfully initialized %s provider(s)", len(providers))

        # Run the server with stdio transport
        async with stdio_server() as (read_stream, write_stream):
            logger.info("MCP server running on stdio")
            await app.run(
                read_stream,
                write_stream,
                app.create_initialization_options()
            )
    finally:
        # Cleanup
        await _shutdown()


if __name__ == "__main__":
    asyncio.run(main())