Spaces:
Running
Running
File size: 7,572 Bytes
0d10048 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 |
"""Healthcare API MCP Server - Public APIs for FDA, Clinical Guidelines, and CMS."""
import asyncio
import inspect
import json
import logging
import types
from typing import Any, Callable, Dict, List, Tuple, Union, get_args, get_origin

from mcp.server import Server
from mcp.server.stdio import stdio_server
from mcp.types import TextContent, Tool

from .config import MCPConfig
from .core import create_http_client
from .providers import get_provider_class
from .providers.clinical_guidelines_provider import ClinicalGuidelinesProvider
from .providers.cms_pricing_provider import CMSPricingProvider
from .providers.openfda_provider import OpenFDAProvider
# Root logging configuration, applied once at import time; loggers that
# propagate to the root logger share this format.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# The MCP server instance; the tool handlers below register on it through
# its decorators.
app = Server("healthcare-api-mcp")

# Module-level state filled in by main():
#   providers   - provider name -> initialized provider instance
#   http_client - the shared HTTP client handed to each provider
providers: Dict[str, Any] = {}
http_client = None
# Google-style docstring headers that open the parameter section.
_ARGS_SECTION_HEADERS = ("Args:", "Arguments:")
# Section headers that close the parameter section.
_SECTION_END_HEADERS = (
    "Returns:",
    "Return:",
    "Yields:",
    "Raises:",
    "Example:",
    "Examples:",
    "Note:",
    "Notes:",
)
# JSON Schema type names for plain Python types; anything unlisted maps to "string".
_JSON_SCHEMA_TYPES = {
    str: "string",
    int: "integer",
    float: "number",
    bool: "boolean",
    list: "array",
    tuple: "array",
    dict: "object",
}


def _json_schema_type(annotation: Any) -> str:
    """Return the JSON Schema type name for a parameter annotation."""
    origin = get_origin(annotation)
    # Optional[X], Union[X, None] and X | None are described by X itself.
    if origin is Union or (
        origin is not None and origin is getattr(types, "UnionType", None)
    ):
        members = [arg for arg in get_args(annotation) if arg is not type(None)]
        return _json_schema_type(members[0]) if len(members) == 1 else "string"
    # Parameterized generics such as List[str] or Dict[str, Any] map by their origin.
    base = annotation if origin is None else origin
    if isinstance(base, type):
        return _JSON_SCHEMA_TYPES.get(base, "string")
    return "string"


def _parse_args_section(lines: List[str]) -> List[Tuple[str, str]]:
    """Return (name, description) pairs from the Args section of stripped docstring lines."""
    entries: List[Tuple[str, str]] = []
    in_args = False
    for line in lines:
        if not in_args:
            in_args = line.startswith(_ARGS_SECTION_HEADERS)
            continue
        if line.startswith(_SECTION_END_HEADERS):
            break
        head, sep, desc = line.partition(":")
        if not sep:
            # No colon: a wrapped continuation of the previous description.
            continue
        # Accept both "name: desc" and "name (type): desc".
        name = head.split("(", 1)[0].strip()
        if name.isidentifier():
            entries.append((name, desc.strip()))
    return entries


def create_tool_schema(func: Callable[..., Any], provider_name: str) -> Tool:
    """Build an MCP ``Tool`` definition from a provider's tool function.

    The tool name is ``func.__name__``. The description is the first
    non-empty docstring line, or the function name when there is no
    docstring. Input properties come from the Google-style ``Args:`` section.
    That section ends at ``Returns:``, ``Raises:``, another section header,
    or the end of the docstring. Each property is typed from the function's
    annotations; unannotated or unmapped types become ``"string"``.

    Only names that ``func(**arguments)`` can accept are kept. Documented
    names missing from the signature are skipped unless ``func`` takes
    ``**kwargs``. Keyword-passable parameters without a default are marked
    required, and they are added even when the docstring omits them.

    Args:
        func: Tool callable exposed by a provider.
        provider_name: Name of the owning provider; accepted for the caller's
            convenience but not used in the schema.

    Returns:
        A ``Tool`` whose ``inputSchema`` is a JSON Schema ``object``.
    """
    doc = func.__doc__ or ""
    lines = [line.strip() for line in doc.split("\n") if line.strip()]
    description = lines[0] if lines else func.__name__

    signature = inspect.signature(func)
    # Positional-only and *args parameters cannot be filled by func(**arguments).
    keyword_params = {
        name: param
        for name, param in signature.parameters.items()
        if param.kind
        in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    }
    accepts_any_keyword = any(
        param.kind is inspect.Parameter.VAR_KEYWORD
        for param in signature.parameters.values()
    )
    annotations = getattr(func, "__annotations__", None) or {}

    entries = [
        (name, desc)
        for name, desc in _parse_args_section(lines)
        if name in keyword_params or accepts_any_keyword
    ]
    documented = {name for name, _ in entries}
    # A required parameter absent from the docstring must still be advertised;
    # otherwise clients never send it and calls cannot succeed.
    entries.extend(
        (name, "")
        for name, param in keyword_params.items()
        if name not in documented
        and name != "self"
        and param.default is inspect.Parameter.empty
    )

    properties: Dict[str, Dict[str, str]] = {}
    required: List[str] = []
    for name, desc in entries:
        properties[name] = {
            "type": _json_schema_type(annotations[name]) if name in annotations else "string",
            "description": desc,
        }
        param = keyword_params.get(name)
        if (
            param is not None
            and param.default is inspect.Parameter.empty
            and name != "self"
            and name not in required
        ):
            required.append(name)

    return Tool(
        name=func.__name__,
        description=description,
        inputSchema={
            "type": "object",
            "properties": properties,
            "required": required,
        },
    )
@app.list_tools()
async def list_tools() -> List[Tool]:
    """Advertise every tool exposed by the initialized providers.

    Each provider's tool callables are turned into MCP ``Tool`` schemas by
    ``create_tool_schema``. Providers are visited in ``providers`` order,
    and each provider's tools in the order ``get_tools()`` returns them.

    Returns:
        The combined list of ``Tool`` definitions.
    """
    collected = [
        create_tool_schema(tool_callable, owner_name)
        for owner_name, owner in providers.items()
        for tool_callable in owner.get_tools()
    ]
    logger.info(f"Listed {len(collected)} tools from {len(providers)} providers")
    return collected
@app.call_tool()
async def call_tool(name: str, arguments: Dict[str, Any]) -> List[TextContent]:
    """Run the tool called ``name`` on the first enabled provider that exposes it.

    Providers are searched in ``providers`` order. The first tool function
    whose ``__name__`` equals ``name`` is awaited with ``arguments`` as
    keyword arguments.

    Args:
        name: Tool name, matched against each tool function's ``__name__``.
        arguments: Keyword arguments for the tool; ``None`` means no arguments.

    Returns:
        A single-element list of ``TextContent``:

        - On success: ``str(result)``.
        - If the tool raises: a JSON object with ``error`` and ``error_type``.
        - If no tool with that name exists: a JSON object with ``error``.
    """
    logger.info(f"Calling tool: {name} with arguments: {arguments}")
    for provider in providers.values():
        for tool_func in provider.get_tools():
            if tool_func.__name__ != name:
                continue
            try:
                result = await tool_func(**(arguments or {}))
                return [TextContent(type="text", text=str(result))]
            except Exception as e:
                logger.exception(f"Error calling tool {name}")
                # json.dumps escapes quotes, backslashes and control characters
                # in the message so the payload is always valid JSON.
                error_result = json.dumps(
                    {"error": str(e), "error_type": type(e).__name__},
                    ensure_ascii=False,
                )
                return [TextContent(type="text", text=error_result)]
    error_msg = json.dumps({"error": f"Tool not found: {name}"}, ensure_ascii=False)
    return [TextContent(type="text", text=error_msg)]
async def _initialize_providers(config: MCPConfig, client: Any) -> None:
    """Instantiate and initialize each enabled provider into the module-level ``providers`` map."""
    provider_classes = {
        "openfda": OpenFDAProvider,
        "clinical_guidelines": ClinicalGuidelinesProvider,
        "cms_pricing": CMSPricingProvider,
    }
    for provider_name in config.enabled_providers:
        provider_class = provider_classes.get(provider_name)
        if provider_class is None:
            logger.warning(f"Unknown provider: {provider_name}")
            continue
        try:
            # Only the openFDA provider is given an API key.
            if provider_name == "openfda":
                provider = provider_class(client, api_key=config.openfda_api_key)
            else:
                provider = provider_class(client)
            await provider.initialize()
            providers[provider_name] = provider
            logger.info(f"Initialized provider: {provider_name}")
        except Exception as e:
            # A broken provider is skipped; the server starts with the rest.
            logger.error(f"Failed to initialize provider {provider_name}: {e}")


async def _shutdown(client: Any) -> None:
    """Clean up every initialized provider, then close the shared HTTP client."""
    logger.info("Shutting down...")
    for provider_name, provider in providers.items():
        try:
            await provider.cleanup()
        except Exception:
            # Keep going so one failing provider cannot leak the others or the client.
            logger.exception(f"Error cleaning up provider {provider_name}")
    if client is not None:
        await client.aclose()


async def main():
    """Run the Healthcare API MCP server over stdio.

    Steps:

    1. Load ``MCPConfig`` from the environment and apply its log level to
       the root logger.
    2. Create one shared HTTP client.
    3. Initialize each enabled provider.
    4. Serve MCP requests over the stdio transport.

    Providers and the HTTP client are always cleaned up on exit, including
    when no provider could be initialized.
    """
    global http_client

    config = MCPConfig.from_env()
    logging.getLogger().setLevel(config.log_level)
    logger.info("Starting Healthcare API MCP Server")
    logger.info(f"Enabled providers: {config.enabled_providers}")

    http_client = create_http_client()
    # Everything after the client exists runs inside try/finally so that the
    # early "no providers" exit also releases it.
    try:
        await _initialize_providers(config, http_client)
        if not providers:
            logger.error("No providers initialized. Exiting.")
            return
        logger.info(f"Successfully initialized {len(providers)} provider(s)")

        async with stdio_server() as (read_stream, write_stream):
            logger.info("MCP server running on stdio")
            await app.run(
                read_stream,
                write_stream,
                app.create_initialization_options(),
            )
    finally:
        await _shutdown(http_client)


if __name__ == "__main__":
    asyncio.run(main())
|