File size: 17,720 Bytes
536b5af
51d4fe4
536b5af
 
51d4fe4
 
 
 
 
82f8d33
536b5af
 
 
51d4fe4
536b5af
 
51d4fe4
 
cbcd58d
 
 
 
 
51d4fe4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
536b5af
 
 
 
51d4fe4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82f8d33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
536b5af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51d4fe4
e880bf8
 
51d4fe4
 
 
536b5af
 
51d4fe4
 
e880bf8
536b5af
 
 
 
 
51d4fe4
 
 
 
e880bf8
 
 
 
 
 
 
 
 
51d4fe4
 
 
 
 
 
 
 
 
 
 
82f8d33
51d4fe4
 
 
 
 
82f8d33
 
 
 
51d4fe4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e880bf8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
"""MCP Streamable HTTP Server for Healthcare API MCP - HuggingFace Spaces deployment.

Implements MCP Streamable HTTP transport (JSON-RPC 2.0).
Specification: https://modelcontextprotocol.io/specification/2025-03-26/basic/transports
"""

import asyncio
import inspect
import json
import logging
import types
import uuid
from datetime import datetime, timezone
from typing import Any, AsyncGenerator, Dict, Optional, Union, get_args, get_origin, get_type_hints

from fastapi import FastAPI, Header, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
from starlette.middleware.cors import CORSMiddleware

from config import MCPConfig
from core import create_http_client
from providers.openfda_provider import OpenFDAProvider
from providers.clinical_guidelines_provider import ClinicalGuidelinesProvider
from providers.cms_pricing_provider import CMSPricingProvider

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(
    title="Healthcare API MCP",
    description="Model Context Protocol server for public healthcare data APIs",
    version="1.0.0"
)

# CORS: accept requests from any browser origin.
# `expose_headers` is needed so browser-based clients can read the
# Mcp-Session-Id header that the MCP endpoint returns from `initialize`.
# CORS hides non-safelisted response headers from JavaScript unless they are
# listed here.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["Mcp-Session-Id"],
)

# Global state, filled in by the startup hook and the initialize handler.
providers = {}  # provider key -> provider instance (insertion order = tool order)
http_client = None  # shared client created at startup, closed at shutdown
sessions: Dict[str, Dict[str, Any]] = {}  # Session ID -> session data

# MCP protocol version advertised in `initialize` responses.
# NOTE(review): the module docstring cites the 2025-03-26 transport spec, but
# this advertises 2024-11-05. Confirm which version clients expect.
MCP_PROTOCOL_VERSION = "2024-11-05"


@app.on_event("startup")
async def startup():
    """Create the shared HTTP client and bring up every enabled provider.

    Which providers run is decided by ``MCPConfig.from_env().enabled_providers``.
    Each enabled provider is:
    1. built around the shared client;
    2. registered in the module-level ``providers`` dict;
    3. awaited through its ``initialize()`` coroutine.
    """
    global http_client, providers

    config = MCPConfig.from_env()
    enabled = config.enabled_providers
    logger.info(f"Starting Healthcare API MCP with providers: {enabled}")

    client = create_http_client()
    http_client = client

    # (registry key, log label, factory). The tuple order fixes the order in
    # which providers are registered, which is also the order tools are listed.
    provider_specs = (
        ("openfda", "OpenFDA", lambda: OpenFDAProvider(client, config.openfda_api_key)),
        ("clinical_guidelines", "Clinical Guidelines", lambda: ClinicalGuidelinesProvider(client)),
        ("cms_pricing", "CMS Pricing", lambda: CMSPricingProvider(client)),
    )

    for key, label, factory in provider_specs:
        if key not in enabled:
            continue
        instance = factory()
        providers[key] = instance
        await instance.initialize()
        logger.info(f"✅ {label} provider initialized")

    logger.info(f"Healthcare API MCP ready with {len(providers)} provider(s)")


@app.on_event("shutdown")
async def shutdown():
    """Release provider resources and close the shared HTTP client.

    Every provider's ``cleanup()`` is attempted even if an earlier one raises.
    The shared HTTP client is always closed afterwards, so a single failing
    provider cannot prevent the rest of the teardown.
    """
    logger.info("Shutting down Healthcare API MCP...")

    try:
        for name, provider in providers.items():
            try:
                await provider.cleanup()
            except Exception:
                # Best-effort teardown: record the failure and keep going so
                # the remaining providers and the HTTP client are released.
                logger.exception(f"Error cleaning up provider {name}")
    finally:
        if http_client:
            await http_client.aclose()


# Python type -> JSON Schema "type", matched by identity against the
# (Optional-unwrapped, generic-stripped) annotation.
_JSON_SCHEMA_TYPES = (
    (str, "string"),
    (int, "integer"),
    (float, "number"),
    (bool, "boolean"),
    (list, "array"),
    (dict, "object"),
)

# Origins reported by `typing.get_origin` for `Union[...]`/`Optional[...]`
# and, on Python 3.10+, for PEP 604 `X | Y` unions.
_UNION_ORIGINS = tuple(
    origin for origin in (Union, getattr(types, "UnionType", None)) if origin is not None
)


def _json_schema_type(hint: Any) -> str:
    """Map a Python type hint to a JSON Schema type name ("string" if unknown)."""
    if get_origin(hint) in _UNION_ORIGINS:
        # Optional[X] is Union[X, None]: describe X. Any other union is
        # ambiguous, so use the default.
        members = [arg for arg in get_args(hint) if arg is not type(None)]
        if len(members) != 1:
            return "string"
        hint = members[0]
    # Strip generic parameters: List[str] -> list, Dict[str, int] -> dict.
    # A bare `list`/`dict` has no origin and is used as-is.
    base = get_origin(hint) or hint
    for py_type, json_type in _JSON_SCHEMA_TYPES:
        if base is py_type:
            return json_type
    return "string"


def _parse_arg_descriptions(doc: str) -> Dict[str, str]:
    """Extract ``name: description`` entries from a Google-style ``Args:`` section."""
    if "Args:" not in doc:
        return {}
    section = doc.split("Args:", 1)[1].split("Returns:", 1)[0]

    descriptions: Dict[str, str] = {}
    current = None
    for raw_line in section.splitlines():
        line = raw_line.strip()
        if not line:
            # A blank line ends the current entry, so trailing prose is not
            # glued onto the last parameter's description.
            current = None
            continue
        head, sep, tail = line.partition(":")
        # Accept both "name: desc" and "name (type): desc".
        name = head.split("(", 1)[0].strip()
        if sep and name.isidentifier() and not line.startswith("return"):
            current = name
            descriptions[name] = tail.strip()
        elif current is not None:
            # Indented continuation of the previous parameter's description.
            descriptions[current] = f"{descriptions[current]} {line}".strip()
    return descriptions


def build_tool_schema(tool_func) -> dict:
    """
    Build an MCP tool schema (name, description, inputSchema) for a callable.

    Extracts:
    - Function description: the docstring text before ``Args:``, dedented
    - Parameter types from type hints (``Optional[X]`` is described as X;
      unknown or unresolvable hints default to "string")
    - Required vs optional parameters (required = no default value)
    - Parameter descriptions from a Google-style ``Args:`` section, including
      ``name (type): desc`` entries and indented continuation lines

    ``self``, ``*args`` and ``**kwargs`` are omitted because they cannot be
    expressed as named JSON properties.

    Args:
        tool_func: The tool callable (function or bound method).

    Returns:
        ``{"name", "description", "inputSchema"}``. ``inputSchema["required"]``
        is present only when at least one parameter is required.
    """
    sig = inspect.signature(tool_func)

    try:
        hints = get_type_hints(tool_func)
    except Exception:
        # e.g. unresolvable forward references: treat every parameter as untyped.
        hints = {}

    # cleandoc removes the docstring's source indentation from the text that
    # ends up in the schema.
    doc = inspect.cleandoc(tool_func.__doc__ or "")
    description = doc.split("Args:", 1)[0].strip()
    param_descriptions = _parse_arg_descriptions(doc)

    properties = {}
    required = []
    variadic_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)

    for param_name, param in sig.parameters.items():
        if param_name == "self" or param.kind in variadic_kinds:
            continue

        param_type = _json_schema_type(hints[param_name]) if param_name in hints else "string"
        prop_schema = {"type": param_type}
        if param_name in param_descriptions:
            prop_schema["description"] = param_descriptions[param_name]
        properties[param_name] = prop_schema

        if param.default is inspect.Parameter.empty:
            required.append(param_name)

    schema = {
        "name": tool_func.__name__,
        "description": description,
        "inputSchema": {
            "type": "object",
            "properties": properties
        }
    }

    if required:
        schema["inputSchema"]["required"] = required

    return schema


# ============================================================================
# JSON-RPC 2.0 Message Handlers
# ============================================================================

def create_jsonrpc_response(result: Any, request_id: Any) -> dict:
    """Wrap ``result`` in a JSON-RPC 2.0 success envelope addressed to ``request_id``.

    Key order is ``jsonrpc``, ``result``, ``id``.
    """
    envelope = dict(jsonrpc="2.0")
    envelope["result"] = result
    envelope["id"] = request_id
    return envelope


def create_jsonrpc_error(code: int, message: str, request_id: Any, data: Any = None) -> dict:
    """Build a JSON-RPC 2.0 error envelope.

    Args:
        code: JSON-RPC error code (e.g. -32601 for "method not found").
        message: Human-readable error summary.
        request_id: ID of the request being answered (may be None).
        data: Optional extra payload. It is omitted entirely when None, but
            kept for other falsy values such as 0 or "".
    """
    error_obj = dict(code=code, message=message)
    if data is not None:
        error_obj["data"] = data
    return dict(jsonrpc="2.0", error=error_obj, id=request_id)


async def handle_initialize(params: dict, session_id: str) -> dict:
    """Handle the ``initialize`` method: record the session and advertise capabilities.

    Args:
        params: JSON-RPC params; ``clientInfo`` and ``protocolVersion`` are read
            if present.
        session_id: Session key under which the client's details are stored in
            the module-level ``sessions`` dict.

    Returns:
        The ``initialize`` result: protocol version, capabilities (tools only,
        no list-change notifications), server info and usage instructions.
    """
    client_info = params.get("clientInfo")
    if not isinstance(client_info, dict):
        # Tolerate a missing, null or malformed clientInfo instead of failing
        # with AttributeError below.
        client_info = {}
    protocol_version = params.get("protocolVersion", MCP_PROTOCOL_VERSION)

    logger.info(f"Initializing session {session_id} with client: {client_info.get('name', 'unknown')}")

    # NOTE(review): nothing in this module ever removes entries from
    # `sessions`, so it grows for the lifetime of the process.
    sessions[session_id] = {
        "initialized": True,
        "client_info": client_info,
        "protocol_version": protocol_version,
        # Timezone-aware UTC timestamp (datetime.utcnow() is deprecated in 3.12+).
        "created_at": datetime.now(timezone.utc).isoformat()
    }

    # The server always answers with its own protocol version.
    return {
        "protocolVersion": MCP_PROTOCOL_VERSION,
        "capabilities": {
            "tools": {
                "listChanged": False  # We don't support dynamic tool updates
            }
        },
        "serverInfo": {
            "name": "Healthcare API MCP",
            "version": "1.0.0"
        },
        "instructions": "This MCP provides access to public healthcare data APIs including OpenFDA (drug safety, labels, recalls), clinical guidelines (USPSTF, CDC), and CMS pricing data."
    }


async def handle_tools_list(params: dict) -> dict:
    """Handle ``tools/list``: return the schema of every tool from every provider.

    Tools are listed in provider registration order, then in the order each
    provider's ``get_tools()`` yields them. ``params`` is accepted for protocol
    compatibility but not read (no pagination cursor support).
    """
    return {
        "tools": [
            build_tool_schema(tool_func)
            for provider in providers.values()
            for tool_func in provider.get_tools()
        ]
    }


async def handle_tools_call(params: dict) -> dict:
    """Handle ``tools/call``: run the named tool and wrap its output as MCP text content.

    Args:
        params: JSON-RPC params with ``name`` (tool name) and an optional
            ``arguments`` object (omitted or null means no arguments).

    Returns:
        ``{"content": [{"type": "text", "text": ...}]}``. String results are
        passed through unchanged. Any other result is JSON-encoded, because MCP
        text content must be a string.

    Raises:
        ValueError: missing name, unknown tool, non-object ``arguments``, or
            arguments that do not fit the tool's signature. The dispatcher
            reports these as -32602.
        Exception: anything raised by the tool itself propagates unchanged.
    """
    tool_name = params.get("name")
    arguments = params.get("arguments") or {}

    if not tool_name:
        raise ValueError("Missing 'name' parameter in tools/call")
    if not isinstance(arguments, dict):
        raise ValueError(f"'arguments' must be an object, got {type(arguments).__name__}")

    logger.info(f"Calling tool: {tool_name} with args: {arguments}")

    # First provider tool whose function name matches, in registration order.
    tool_func = next(
        (
            func
            for provider in providers.values()
            for func in provider.get_tools()
            if func.__name__ == tool_name
        ),
        None,
    )
    if tool_func is None:
        raise ValueError(f"Tool not found: {tool_name}")

    # Check the arguments against the signature *before* calling. Otherwise a
    # TypeError raised inside the tool body would be misreported as a caller
    # mistake, and unexpected keywords would surface as internal errors.
    try:
        inspect.signature(tool_func).bind(**arguments)
    except TypeError as e:
        error_msg = str(e)
        provided = list(arguments.keys())
        logger.error(f"Invalid arguments for {tool_name}: {error_msg}")
        if "missing" in error_msg and "required" in error_msg:
            raise ValueError(f"Missing required arguments. {error_msg}. Provided arguments: {provided}") from e
        raise ValueError(f"Invalid arguments for {tool_name}: {error_msg}. Provided arguments: {provided}") from e

    result = await tool_func(**arguments)

    # default=str deliberately: best-effort text for values json can't encode
    # natively (e.g. dates) rather than failing the whole call.
    text = result if isinstance(result, str) else json.dumps(result, ensure_ascii=False, default=str)

    # Return MCP tool response format
    return {
        "content": [
            {
                "type": "text",
                "text": text
            }
        ]
    }


async def handle_jsonrpc_request(request_data: dict, session_id: Optional[str] = None) -> dict:
    """Validate one JSON-RPC 2.0 request and dispatch it to its MCP handler.

    Args:
        request_data: Decoded JSON request body.
        session_id: ``Mcp-Session-Id`` header value, if the client sent one.

    Returns:
        A JSON-RPC response dict; errors are returned, never raised. For
        ``initialize`` the dict also carries a private ``_session_id`` key,
        which the HTTP layer pops and sends as the ``Mcp-Session-Id`` header.
    """
    # A batch array or scalar cannot be dispatched; answer instead of crashing
    # on `.get` outside the error handling below.
    if not isinstance(request_data, dict):
        return create_jsonrpc_error(-32600, "Invalid Request: expected a JSON object", None)

    request_id = request_data.get("id")

    if request_data.get("jsonrpc") != "2.0":
        return create_jsonrpc_error(-32600, "Invalid Request: jsonrpc must be '2.0'", request_id)

    method = request_data.get("method")
    if not method:
        return create_jsonrpc_error(-32600, "Invalid Request: missing method", request_id)
    if not isinstance(method, str):
        return create_jsonrpc_error(-32600, "Invalid Request: method must be a string", request_id)

    # `params` is optional and may be null. Handlers read it with `.get`, so it
    # must be an object.
    params = request_data.get("params")
    if params is None:
        params = {}
    elif not isinstance(params, dict):
        return create_jsonrpc_error(-32602, "Invalid params: params must be an object", request_id)

    try:
        if method == "initialize":
            # Reuse the client's session ID if it sent one; otherwise mint one.
            if not session_id:
                session_id = str(uuid.uuid4())
            result = await handle_initialize(params, session_id)
            response = create_jsonrpc_response(result, request_id)
            response["_session_id"] = session_id  # popped by the HTTP layer for the header
            return response

        if method == "ping":
            # MCP `ping` utility request: answer with an empty result.
            return create_jsonrpc_response({}, request_id)

        if method == "tools/list":
            return create_jsonrpc_response(await handle_tools_list(params), request_id)

        if method == "tools/call":
            return create_jsonrpc_response(await handle_tools_call(params), request_id)

        return create_jsonrpc_error(-32601, f"Method not found: {method}", request_id)

    except ValueError as e:
        return create_jsonrpc_error(-32602, f"Invalid params: {str(e)}", request_id)
    except Exception as e:
        logger.exception(f"Error handling {method}")
        return create_jsonrpc_error(-32603, f"Internal error: {str(e)}", request_id)


# ============================================================================
# MCP Streamable HTTP Endpoint
# ============================================================================

@app.post("/mcp")
async def mcp_endpoint(
    request: Request,
    mcp_session_id: Optional[str] = Header(None, alias="Mcp-Session-Id")
):
    """
    MCP Streamable HTTP endpoint (JSON-RPC 2.0).

    Handles:
    - initialize: Capability negotiation (session ID returned in the
      Mcp-Session-Id response header)
    - tools/list: List available tools
    - tools/call: Execute a tool
    - notifications (messages without "id"): acknowledged with 202 and no body

    Error statuses:
    - malformed JSON: HTTP 400 with a -32700 parse error;
    - non-object payloads (e.g. batch arrays): HTTP 400 with -32600.
    """
    try:
        body = await request.json()
    except Exception as e:
        return JSONResponse(
            content=create_jsonrpc_error(-32700, f"Parse error: {str(e)}", None),
            status_code=400
        )

    if not isinstance(body, dict):
        # Batched (array) or scalar payloads are not supported here.
        return JSONResponse(
            content=create_jsonrpc_error(-32600, "Invalid Request: expected a single JSON-RPC object", None),
            status_code=400
        )

    if "method" in body and "id" not in body:
        # JSON-RPC notification (e.g. "notifications/initialized"). It must not
        # receive a response object; Streamable HTTP acknowledges it with 202.
        logger.info(f"Received notification: {body.get('method')}")
        return Response(status_code=202)

    response = await handle_jsonrpc_request(body, mcp_session_id)

    # Private marker set by the initialize handler; never sent in the body.
    session_id = response.pop("_session_id", None)

    headers = {}
    if session_id:
        headers["Mcp-Session-Id"] = session_id

    return JSONResponse(content=response, headers=headers)


# ============================================================================
# Legacy REST Endpoints (for backward compatibility)
# ============================================================================

@app.get("/")
async def root_info():
    """Describe the server: identity, protocol version, loaded providers and routes (GET only)."""
    legacy_routes = {
        "tools": "/mcp/tools (GET) - List tools (legacy)",
        "call": "/mcp/call (POST) - Execute tool (legacy)",
        "stream": "/mcp/stream (POST) - Stream results (legacy)",
    }
    info = {
        "name": "Healthcare API MCP",
        "version": "1.0.0",
        "protocol": "MCP Streamable HTTP (JSON-RPC 2.0)",
        "protocolVersion": MCP_PROTOCOL_VERSION,
        "providers": [name for name in providers],
        "endpoints": {
            "mcp": "/ or /mcp (POST) - MCP Streamable HTTP endpoint",
            "legacy": legacy_routes,
        },
    }
    return info


@app.post("/")
async def root_mcp(
    request: Request,
    mcp_session_id: Optional[str] = Header(None, alias="Mcp-Session-Id")
):
    """Serve MCP JSON-RPC POSTs at the site root by delegating to the ``/mcp`` handler.

    Clients configured with the bare base URL reach the same logic as ``/mcp``.
    """
    return await mcp_endpoint(request=request, mcp_session_id=mcp_session_id)


@app.get("/health")
async def health_check():
    """Liveness probe: always "healthy", listing every registered provider as "active"."""
    provider_status = dict.fromkeys(providers, "active")
    return {"status": "healthy", "providers": provider_status}


@app.get("/mcp/tools")
async def list_tools():
    """Legacy REST: list every tool's schema, tagged with the provider that owns it."""
    catalog = [
        {**build_tool_schema(tool_func), "provider": provider_name}
        for provider_name, provider in providers.items()
        for tool_func in provider.get_tools()
    ]
    return {"tools": catalog, "total": len(catalog)}


@app.post("/mcp/call")
async def call_tool(request: Request):
    """Legacy REST: call a tool by name and return its raw result.

    Expects a JSON body ``{"tool": <name>, "arguments": {...}}``. The
    ``arguments`` key may be omitted or null. Problems are reported in the
    response body rather than as HTTP errors:
    - a malformed or non-object body, or a missing tool name, yields
      ``{"error": ...}``;
    - tool failures and unknown tools yield ``success: False``.
    """
    try:
        body = await request.json()
    except Exception as e:
        return {"error": f"Invalid JSON body: {e}"}
    if not isinstance(body, dict):
        return {"error": "Request body must be a JSON object"}

    tool_name = body.get("tool")
    arguments = body.get("arguments") or {}

    if not tool_name:
        return {"error": "Missing 'tool' parameter"}

    logger.info(f"Calling tool: {tool_name} with args: {arguments}")

    # Find and execute the first tool with a matching name
    for provider in providers.values():
        for tool_func in provider.get_tools():
            if tool_func.__name__ != tool_name:
                continue
            try:
                result = await tool_func(**arguments)
                return {
                    "success": True,
                    "tool": tool_name,
                    "result": result
                }
            except Exception as e:
                logger.exception(f"Error calling tool {tool_name}")
                return {
                    "success": False,
                    "tool": tool_name,
                    "error": str(e),
                    "error_type": type(e).__name__
                }

    return {
        "success": False,
        "tool": tool_name,
        "error": f"Tool not found: {tool_name}"
    }


async def stream_tool_result(tool_name: str, arguments: dict) -> AsyncGenerator[str, None]:
    """Run a tool and emit its progress as newline-delimited JSON events.

    Event sequence:
    - always starts with a ``start`` event;
    - then either a ``result`` event (success) or an ``error`` event (tool
      raised, or no tool has that name);
    - always ends with a ``complete`` event.
    """
    def emit(event: dict) -> str:
        # One JSON document per line (NDJSON framing).
        return json.dumps(event) + "\n"

    yield emit({"type": "start", "tool": tool_name, "arguments": arguments})

    # Lazily scan providers in registration order; stop at the first match.
    candidates = (
        func
        for provider in providers.values()
        for func in provider.get_tools()
        if func.__name__ == tool_name
    )
    tool_func = next(candidates, None)

    if tool_func is None:
        yield emit({"type": "error", "success": False, "error": f"Tool not found: {tool_name}"})
    else:
        try:
            outcome = await tool_func(**arguments)
            yield emit({"type": "result", "success": True, "data": outcome})
        except Exception as exc:
            logger.exception(f"Error executing tool {tool_name}")
            yield emit({
                "type": "error",
                "success": False,
                "error": str(exc),
                "error_type": type(exc).__name__,
            })

    yield emit({"type": "complete"})


@app.post("/mcp/stream")
async def stream_tool_call(request: Request):
    """Legacy REST: stream a tool's execution as NDJSON over chunked HTTP.

    Expects a JSON body ``{"tool": <name>, "arguments": {...}}``. The
    ``arguments`` key may be omitted or null. A malformed or non-object body,
    or a missing tool name, is answered with a plain ``{"error": ...}`` object
    instead of a stream.
    """
    try:
        body = await request.json()
    except Exception as e:
        return {"error": f"Invalid JSON body: {e}"}
    if not isinstance(body, dict):
        return {"error": "Request body must be a JSON object"}

    tool_name = body.get("tool")
    arguments = body.get("arguments") or {}

    if not tool_name:
        return {"error": "Missing 'tool' parameter"}

    logger.info(f"Streaming tool: {tool_name} with args: {arguments}")

    return StreamingResponse(
        stream_tool_result(tool_name, arguments),
        media_type="application/x-ndjson",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",  # ask nginx-style proxies not to buffer chunks
        }
    )


if __name__ == "__main__":
    import os

    import uvicorn

    # Serve on port 7860 by default; a PORT environment variable overrides it
    # for hosts that assign the port dynamically.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=int(os.environ.get("PORT", "7860")),
        log_level="info"
    )