# CSM 1B text-to-speech API server (FastAPI).
# Serves Sesame's Conversational Speech Model over HTTP.
| import base64 | |
| import io | |
| import logging | |
| from typing import List, Optional | |
| import torch | |
| import torchaudio | |
| import uvicorn | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from generator import load_csm_1b, Segment | |
# --- Module-level setup: logging, FastAPI app, CORS, model handle ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="CSM 1B API",
    description="API for Sesame's Conversational Speech Model",
    version="1.0.0",
)

# Permit browser clients from any origin to call the API.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Model handle; populated by the startup hook once loading succeeds.
generator = None
class SegmentRequest(BaseModel):
    """One prior conversation turn supplied as generation context."""

    speaker: int
    text: str
    # Base64-encoded audio clip for this turn; when omitted, the server
    # substitutes an empty waveform.
    audio_base64: Optional[str] = None
class GenerateAudioRequest(BaseModel):
    """Request body for speech generation."""

    text: str
    speaker: int
    # Prior turns of the conversation, oldest first.
    context: List[SegmentRequest] = []
    # Upper bound on the generated clip length, in milliseconds.
    max_audio_length_ms: float = 10000
    # Sampling controls passed through to the generator.
    temperature: float = 0.9
    topk: int = 50
class AudioResponse(BaseModel):
    """Response body: generated audio as a base64 WAV plus its sample rate."""

    audio_base64: str
    sample_rate: int
@app.on_event("startup")
async def startup_event():
    """Load the CSM 1B model once at application startup.

    Stores the loaded generator in the module-level ``generator`` so the
    request handlers can use it. Re-raises any loading failure so the
    server fails fast instead of silently serving 503s forever.

    FIX: the function was defined but never registered with FastAPI — without
    the ``@app.on_event("startup")`` decorator it is never called, ``generator``
    stays ``None``, and every request returns 503.
    """
    global generator
    logger.info("Loading CSM 1B model...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cpu":
        logger.warning("GPU not available. Using CPU, performance may be slow!")
    try:
        generator = load_csm_1b(device=device)
        logger.info(f"Model loaded successfully on device: {device}")
    except Exception as e:
        logger.error(f"Could not load model: {str(e)}")
        raise e
@app.post("/generate-audio", response_model=AudioResponse)
async def generate_audio(request: GenerateAudioRequest):
    """Generate speech for ``request.text`` and return it as base64 WAV.

    Decodes any base64 context clips, resamples them to the model's sample
    rate, runs the generator, and serializes the result to an in-memory WAV.

    Raises:
        HTTPException 503: the model has not finished loading yet.
        HTTPException 500: any failure while decoding context or generating.

    FIXES:
    - The endpoint had no route decorator, so it was unreachable.
    - The old return passed Response-style kwargs (``content``,
      ``media_type``, ``headers``) to ``AudioResponse``, a pydantic model
      declaring ``audio_base64``/``sample_rate`` — that fails validation on
      every successful generation. Restored the intended base64 encoding
      (which the code had left commented out).
    """
    global generator
    if generator is None:
        raise HTTPException(status_code=503, detail="Model not loaded. Please try again later.")
    try:
        context_segments = []
        for segment in request.context:
            if segment.audio_base64:
                # Decode the supplied clip and resample to the model's rate.
                audio_bytes = base64.b64decode(segment.audio_base64)
                audio_buffer = io.BytesIO(audio_bytes)
                audio_tensor, sample_rate = torchaudio.load(audio_buffer)
                # NOTE(review): squeeze(0) assumes a mono clip — confirm
                # behavior for multi-channel input.
                audio_tensor = torchaudio.functional.resample(
                    audio_tensor.squeeze(0),
                    orig_freq=sample_rate,
                    new_freq=generator.sample_rate,
                )
            else:
                # No audio for this turn: use an empty waveform placeholder.
                audio_tensor = torch.zeros(0, dtype=torch.float32)
            context_segments.append(
                Segment(text=segment.text, speaker=segment.speaker, audio=audio_tensor)
            )
        audio = generator.generate(
            text=request.text,
            speaker=request.speaker,
            context=context_segments,
            max_audio_length_ms=request.max_audio_length_ms,
            temperature=request.temperature,
            topk=request.topk,
        )
        # Serialize the waveform to an in-memory WAV file.
        buffer = io.BytesIO()
        torchaudio.save(buffer, audio.unsqueeze(0).cpu(), generator.sample_rate, format="wav")
        buffer.seek(0)
        audio_base64 = base64.b64encode(buffer.read()).decode("utf-8")
        return AudioResponse(audio_base64=audio_base64, sample_rate=generator.sample_rate)
    except Exception as e:
        logger.error(f"error when building audio: {str(e)}")
        raise HTTPException(status_code=500, detail=f"error when building audio: {str(e)}")
@app.get("/health")
async def health_check():
    """Report whether the model is loaded and the API can serve requests.

    FIX: the handler had no route decorator, so the health endpoint was
    never exposed.
    """
    if generator is None:
        return {"status": "not_ready", "message": "Model is loading"}
    return {"status": "ready", "message": "API is ready to serve"}