Spaces:
Running
Running
| import os | |
| import torch | |
| from loguru import logger | |
| from pydantic import BaseModel, Field | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from sentence_transformers import CrossEncoder | |
| from typing import List, Optional | |
# Initialize FastAPI app with documentation metadata
app = FastAPI(
    title="Document Reranker API",
    description="An API for reranking documents using a CrossEncoder model.",
    version="1.0",
    docs_url="/docs",  # Swagger UI
    redoc_url="/redoc",  # ReDoc UI
)

# Enable CORS so browser frontends served from other origins can call the API.
# allow_credentials is False: FastAPI's CORS documentation states that
# allow_origins=["*"] must not be combined with allow_credentials=True
# (Starlette would otherwise reflect the caller's Origin and accept
# credentialed requests from any site). Nothing visible in this service reads
# cookies or auth headers. If that changes, list explicit origins instead.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins (modify as needed)
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Pick the compute device: CUDA when torch reports it is available, else CPU.
DEVICE = torch.device("cpu" if not torch.cuda.is_available() else "cuda")

# Build the human-readable suffix for the startup log line.
if DEVICE.type == "cuda":
    _device_detail = "GPU: " + torch.cuda.get_device_name(0)
else:
    _device_detail = "Running on CPU"
logger.warning(f"Using device: {DEVICE} ({_device_detail})")
# Directory the CrossEncoder is told to cache model files in. It is created
# up front so the load below has a writable location.
MODEL_CACHE_DIR = "models_cache"
os.makedirs(MODEL_CACHE_DIR, exist_ok=True)

# Load the model once at import time so requests reuse it instead of reloading.
try:
    model = CrossEncoder(
        "jinaai/jina-reranker-v1-turbo-en",
        # Allows code from the model repository to run. Presumably required by
        # this checkpoint — confirm before removing.
        trust_remote_code=True,
        device=DEVICE,
        cache_dir=MODEL_CACHE_DIR,
    )
except Exception as e:
    # logger.exception records the full traceback, which logger.error dropped.
    # `from e` marks the load failure as the direct cause of the RuntimeError.
    logger.exception(f"Failed to load model: {e}")
    raise RuntimeError("Model loading failed. Check logs for details.") from e
class RerankerRequest(BaseModel):
    """Request body: a query and the candidate documents to rerank against it."""

    query: str = Field(..., description="The search query string")
    documents: List[str] = Field(..., description="List of documents to rerank")
    return_documents: bool = Field(
        True, description="Whether to return document content in results"
    )
    # Forwarded to model.rank as top_k. ge=0 rejects negative counts, which have
    # no meaning as "number of results"; 0 is still accepted as before.
    top_k: int = Field(3, ge=0, description="Number of top results to return")


class RankedResult(BaseModel):
    """One entry of the reranked output."""

    # Relevance score the model assigned to this document.
    score: float
    # Taken from the model's "corpus_id"; presumably the document's position in
    # the request's `documents` list — verify against the model library.
    index: int
    # Document text. The handler fills it only when return_documents is set.
    document: Optional[str] = None


class RerankerResponse(BaseModel):
    """Response body: reranked results in the order the model returned them."""

    results: List[RankedResult]
# NOTE(review): no route decorator is visible on this handler, so as written
# FastAPI never exposes it. If the decorator was not simply lost, register it,
# e.g. `@app.post("/rerank", response_model=RerankerResponse)`. The path here is
# a guess — confirm against existing clients.
async def rerank_documents(request: RerankerRequest):
    """
    Reranks the given list of documents based on their relevance to the query.

    - **query**: The input query string.
    - **documents**: A list of documents to be reranked (an empty list yields no results).
    - **return_documents**: Whether to include document content in results.
    - **top_k**: Number of top-ranked documents to return.

    Returns:
    - A list of ranked documents with scores and indexes, in the order produced by the model.

    Raises:
    - HTTPException (500) if the model call or result formatting fails.
    """
    if not request.documents:
        # Nothing to score. Return early rather than hand the model an empty
        # batch, which it is not guaranteed to accept.
        return RerankerResponse(results=[])

    try:
        # NOTE(review): model.rank is a synchronous call made directly inside
        # this async handler, so it occupies the event loop until it returns.
        ranked = model.rank(
            request.query,
            request.documents,
            return_documents=request.return_documents,
            top_k=request.top_k,
        )
        formatted_results = [
            RankedResult(
                score=item["score"],
                index=item["corpus_id"],
                # The "text" key is only read when the request asked for documents.
                document=item["text"] if request.return_documents else None,
            )
            for item in ranked
        ]
        return RerankerResponse(results=formatted_results)
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error discarded.
        logger.exception(f"Error in reranking: {e}")
        raise HTTPException(
            status_code=500, detail=f"Error in reranking: {str(e)}"
        ) from e
# Script entry point: `python <this file>` serves the app with Uvicorn,
# listening on every interface at port 7860.
if __name__ == "__main__":
    from uvicorn import run as _serve

    _serve(app, host="0.0.0.0", port=7860)