File size: 7,572 Bytes
0d10048
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
"""Healthcare API MCP Server - Public APIs for FDA, Clinical Guidelines, and CMS."""

import asyncio
import inspect
import json
import logging
import types
import typing
from typing import List, Dict, Any

from mcp.server import Server
from mcp.server.stdio import stdio_server
from mcp.types import Tool, TextContent

from .config import MCPConfig
from .core import create_http_client
from .providers import get_provider_class
from .providers.openfda_provider import OpenFDAProvider
from .providers.clinical_guidelines_provider import ClinicalGuidelinesProvider
from .providers.cms_pricing_provider import CMSPricingProvider

# Root logging is configured once, at import time. logging.basicConfig sends
# records to stderr by default, which keeps log output off stdout (the channel
# the stdio transport uses for protocol messages).
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)

# The MCP server object; the handlers below register themselves on it through
# its decorator methods.
app = Server("healthcare-api-mcp")

# Process-wide state filled in by main(): initialised providers keyed by their
# configuration name, and the HTTP client shared between them.
providers: Dict[str, Any] = {}
http_client: Any = None


def _json_schema_type(annotation: Any) -> str:
    """Map a Python type annotation to a JSON-schema primitive type name.

    ``Optional[X]`` / ``X | None`` is described as ``X``. Anything that is
    not recognised (including unresolved string annotations) maps to
    ``"string"``.
    """
    origin = typing.get_origin(annotation)
    # typing.Union covers Optional[X]; types.UnionType covers ``X | None``
    # (Python 3.10+). On older versions the getattr falls back harmlessly.
    if origin in (typing.Union, getattr(types, "UnionType", typing.Union)):
        members = [arg for arg in typing.get_args(annotation) if arg is not type(None)]
        return _json_schema_type(members[0]) if len(members) == 1 else "string"

    # For parameterised generics (List[str], Dict[str, int]) look at the origin.
    base = annotation if origin is None else origin
    if base is bool:
        return "boolean"
    if base is int:
        return "integer"
    if base is float:
        return "number"
    if base is list:
        return "array"
    if base is dict:
        return "object"
    return "string"


def _parse_docstring_args(doc: str) -> List[tuple]:
    """Extract ``(name, description)`` pairs from a Google-style ``Args:`` section.

    The section runs until the next recognised section header (``Returns:``,
    ``Raises:``, ...) or the end of the docstring. Lines indented deeper than
    the entries are continuations of the previous description. A
    ``name (type): ...`` entry yields just ``name``. Entries whose name is not a
    valid Python identifier are skipped, because they could never be passed as
    keyword arguments.
    """
    section_headers = (
        "Returns:", "Return:", "Raises:", "Yields:", "Yield:",
        "Examples:", "Example:", "Note:", "Notes:", "Attributes:",
        "See Also:", "Todo:", "Warning:", "Warnings:",
    )
    entries: List[List[str]] = []
    entry_indent = 0
    in_args = False

    # cleandoc normalises the docstring's common indentation so that relative
    # indentation can distinguish entries from continuation lines.
    for raw_line in inspect.cleandoc(doc).splitlines():
        text = raw_line.strip()
        if not text:
            continue
        if not in_args:
            in_args = text.startswith("Args:")
            continue

        indent = len(raw_line) - len(raw_line.lstrip())
        if entries and indent > entry_indent:
            entries[-1][1] = f"{entries[-1][1]} {text}".strip()
            continue
        if text.startswith(section_headers):
            break
        if ":" not in text:
            continue

        head, desc = text.split(":", 1)
        name = head.split("(", 1)[0].strip()
        if not name.isidentifier():
            continue
        entry_indent = indent
        entries.append([name, desc.strip()])

    return [(name, desc) for name, desc in entries]


def create_tool_schema(func: callable, provider_name: str) -> Tool:
    """Build an MCP ``Tool`` definition from a provider tool function.

    The description is the first non-empty line of ``func``'s docstring, or
    ``func.__name__`` when there is no docstring. Input properties come from
    the docstring's ``Args:`` section. Each property's JSON-schema type is
    derived from the parameter's type annotation and defaults to ``"string"``.
    A parameter is listed as required when the signature gives it no default.

    Args:
        func: The provider's tool callable; its ``__name__`` becomes the tool name.
        provider_name: Name of the owning provider. It is not used in the
            schema and is kept for interface compatibility.

    Returns:
        A ``Tool`` whose ``inputSchema`` is a JSON-schema ``"object"``.
    """
    doc = func.__doc__ or ""
    lines = [line.strip() for line in doc.split("\n") if line.strip()]
    description = lines[0] if lines else func.__name__

    # Resolve annotations, including string / forward-reference ones, when
    # possible. If resolution fails, fall back to the raw mapping; unresolved
    # string entries then simply map to "string".
    try:
        hints = typing.get_type_hints(func)
    except Exception:
        hints = dict(getattr(func, "__annotations__", {}))

    try:
        signature_params = inspect.signature(func).parameters
    except (TypeError, ValueError):
        signature_params = {}

    # Only parameters that can be passed by keyword can be required tool inputs.
    keyword_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    properties: Dict[str, Dict[str, str]] = {}
    required: List[str] = []

    for param_name, param_desc in _parse_docstring_args(doc):
        properties[param_name] = {
            "type": _json_schema_type(hints.get(param_name, str)),
            "description": param_desc,
        }
        param = signature_params.get(param_name)
        if (
            param is not None
            and param_name != "self"
            and param.kind in keyword_kinds
            and param.default is inspect.Parameter.empty
            and param_name not in required
        ):
            required.append(param_name)

    return Tool(
        name=func.__name__,
        description=description,
        inputSchema={
            "type": "object",
            "properties": properties,
            "required": required,
        },
    )


@app.list_tools()
async def list_tools() -> List[Tool]:
    """Advertise every tool exposed by the currently initialised providers.

    Providers are visited in insertion order, and each provider's tools in the
    order returned by its ``get_tools()``.
    """
    schemas = [
        create_tool_schema(fn, name)
        for name, prov in providers.items()
        for fn in prov.get_tools()
    ]
    logger.info(f"Listed {len(schemas)} tools from {len(providers)} providers")
    return schemas


def _format_tool_result(result: Any) -> str:
    """Render a tool result as text.

    Strings are returned verbatim. Other values are JSON-encoded so that
    success and error payloads share one format. If the value is not JSON
    serialisable, ``str()`` is used instead.
    """
    if isinstance(result, str):
        return result
    try:
        return json.dumps(result, ensure_ascii=False)
    except (TypeError, ValueError):
        return str(result)


@app.call_tool()
async def call_tool(name: str, arguments: Dict[str, Any]) -> List[TextContent]:
    """Run the tool called ``name`` on the first provider that exposes it.

    Args:
        name: Tool name as advertised by ``list_tools`` (the function's ``__name__``).
        arguments: Keyword arguments for the tool. ``None`` is treated as no arguments.

    Returns:
        A one-element list holding a ``TextContent``. On success the text is the
        tool's result. On failure, or for an unknown tool, the text is a JSON
        object with an ``"error"`` key; tool failures also carry ``"error_type"``.
    """
    logger.info(f"Calling tool: {name} with arguments: {arguments}")
    kwargs = arguments or {}

    for provider in providers.values():
        tool_func = next(
            (func for func in provider.get_tools() if func.__name__ == name),
            None,
        )
        if tool_func is None:
            continue
        try:
            result = await tool_func(**kwargs)
            text = _format_tool_result(result)
        except Exception as e:
            logger.exception(f"Error calling tool {name}")
            # json.dumps escapes quotes/newlines in the message, so the payload
            # is always valid JSON (the previous f-string was not).
            text = json.dumps(
                {"error": str(e), "error_type": type(e).__name__},
                ensure_ascii=False,
            )
        return [TextContent(type="text", text=text)]

    logger.warning(f"Tool not found: {name}")
    error_msg = json.dumps({"error": f"Tool not found: {name}"}, ensure_ascii=False)
    return [TextContent(type="text", text=error_msg)]


async def main():
    """Load configuration, initialise providers and serve MCP over stdio.

    Settings are read with ``MCPConfig.from_env()``. One HTTP client is created
    and shared by all providers. Every enabled provider listed in the local
    registry is initialised; a failing provider is logged and skipped. The
    server then runs on the stdio transport. Providers and the HTTP client are
    cleaned up on every exit path, including when no provider could be
    initialised. A failing provider cleanup is logged and does not prevent the
    remaining cleanups.
    """
    global providers, http_client

    config = MCPConfig.from_env()
    logging.getLogger().setLevel(config.log_level)

    logger.info("Starting Healthcare API MCP Server")
    logger.info(f"Enabled providers: {config.enabled_providers}")

    # A single HTTP client is shared by all providers.
    http_client = create_http_client()

    provider_classes = {
        "openfda": OpenFDAProvider,
        "clinical_guidelines": ClinicalGuidelinesProvider,
        "cms_pricing": CMSPricingProvider,
    }

    # Everything after the client is created sits inside try/finally so the
    # client is closed even on the early "no providers" return.
    try:
        for provider_name in config.enabled_providers:
            provider_class = provider_classes.get(provider_name)
            if provider_class is None:
                logger.warning(f"Unknown provider: {provider_name}")
                continue

            try:
                # Only the OpenFDA provider takes an API key.
                if provider_name == "openfda":
                    provider = provider_class(http_client, api_key=config.openfda_api_key)
                else:
                    provider = provider_class(http_client)
                await provider.initialize()
            except Exception as e:
                logger.exception(f"Failed to initialize provider {provider_name}: {e}")
                continue

            providers[provider_name] = provider
            logger.info(f"Initialized provider: {provider_name}")

        if not providers:
            logger.error("No providers initialized. Exiting.")
            return

        logger.info(f"Successfully initialized {len(providers)} provider(s)")

        async with stdio_server() as (read_stream, write_stream):
            logger.info("MCP server running on stdio")
            await app.run(
                read_stream,
                write_stream,
                app.create_initialization_options()
            )
    finally:
        logger.info("Shutting down...")
        for provider_name, provider in list(providers.items()):
            try:
                await provider.cleanup()
            except Exception:
                logger.exception(f"Error cleaning up provider {provider_name}")
        if http_client is not None:
            await http_client.aclose()


if __name__ == "__main__":
    # The module uses relative imports, so it must be run as part of its
    # package (e.g. ``python -m <package>.<module>``), not as a bare script.
    asyncio.run(main())