Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, UploadFile, File, Form, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from contextlib import asynccontextmanager | |
| from typing import Dict, Optional | |
| from models.base import BaseModelWrapper | |
| from models.model_v1.wrapper import HybridDebertaWrapper | |
| from models.model_v2.wrapper import ModelV2Wrapper | |
| from models.model_v3.wrapper import MultimodalWrapper | |
# Registry of model wrappers, keyed by the display name that clients send as
# `model_name`. The wrappers are constructed here at import time; `load()` is
# called on each of them later, from the application lifespan handler.
AVAILABLE_MODELS: Dict[str, BaseModelWrapper] = {
    "Model V1": HybridDebertaWrapper(),
    "Model V2": ModelV2Wrapper(),
    "Model V3 (Multimodal)": MultimodalWrapper(),
}
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load every registered model on startup and empty the registry on shutdown.

    Intended to be passed as ``FastAPI(lifespan=lifespan)``: the code before
    ``yield`` runs before the application starts serving requests, and the
    code after it runs when the application shuts down.

    A wrapper whose ``load()`` raises is reported on stdout and stays in
    ``AVAILABLE_MODELS``, so one failing model does not stop the service
    from starting.

    Args:
        app: The FastAPI application being started (not used here).

    Yields:
        None. No lifespan state is handed to request handlers.
    """
    print("Initializing available models...")
    for name, wrapper in AVAILABLE_MODELS.items():
        try:
            print(f"Loading {name}...")
            wrapper.load()
            print(f"{name} loaded successfully.")
        except Exception as e:
            # Best-effort startup: report the failure and keep loading the
            # remaining models instead of aborting the whole application.
            print(f"Failed to load {name}: {e}")
    try:
        yield
    finally:
        # Run the cleanup even when an exception is thrown into the context
        # manager at the yield point during shutdown.
        AVAILABLE_MODELS.clear()
# Application instance; `lifespan` is registered as its startup/shutdown handler.
app = FastAPI(lifespan=lifespan)

# CORS policy: cross-origin requests are accepted only from the origins listed
# below. Credentials are allowed, and no restriction is placed on HTTP methods
# or request headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "https://adtrack.onrender.com",
        "http://localhost:5173",
    ],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
def list_models():
    """Return the names of all models currently in the registry.

    Returns:
        A dict of the form ``{"models": [...]}`` whose list holds the keys of
        ``AVAILABLE_MODELS`` in insertion order.
    """
    model_names = [*AVAILABLE_MODELS]
    return {"models": model_names}
async def predict(
    model_name: str = Form(...),
    file: Optional[UploadFile] = File(None),  # .cha file (Optional)
    audio_file: Optional[UploadFile] = File(None),  # Audio file (Optional)
    segmentation_file: Optional[UploadFile] = File(None),  # Segmentation .csv (Optional)
):
    """Run the selected model on the uploaded files and return its result.

    Args:
        model_name: Key of the model in ``AVAILABLE_MODELS``.
        file: Optional transcript upload. Required, with a ``.cha`` filename,
            for "Model V1" and "Model V2".
        audio_file: Optional audio upload.
        segmentation_file: Optional segmentation upload.

    Returns:
        Whatever the selected wrapper's ``predict`` method returns.

    Raises:
        HTTPException: 404 if ``model_name`` is not registered; 400 if the
            required inputs are missing or the wrapper raises ``ValueError``;
            500 for any other exception raised by the wrapper.
    """
    if model_name not in AVAILABLE_MODELS:
        raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")

    # --- Validation Logic ---
    if model_name in ("Model V1", "Model V2"):
        # An upload part may arrive without a filename (`filename` is None);
        # treat that as "not a .cha file" so the client gets a 400 instead of
        # an unhandled AttributeError.
        if not file or not (file.filename or "").endswith(".cha"):
            raise HTTPException(status_code=400, detail=f"{model_name} requires a .cha file.")
    if not file and not audio_file:
        raise HTTPException(status_code=400, detail="Please provide input files (CHA or Audio).")

    # --- Read Files ---
    text_content = b""
    filename = "upload"
    if file:
        text_content = await file.read()
        # Keep the "upload" placeholder when the client sent no filename.
        filename = file.filename or filename
    elif audio_file:
        filename = audio_file.filename or filename

    audio_content = None
    if audio_file:
        audio_content = await audio_file.read()

    segmentation_content = None
    if segmentation_file:
        segmentation_content = await segmentation_file.read()

    # --- Prediction ---
    try:
        # Pass all contents to the wrapper.
        # NOTE(review): this call is not awaited, so it runs to completion on
        # the event loop thread; consider offloading it if inference is slow.
        result = AVAILABLE_MODELS[model_name].predict(
            file_content=text_content,
            filename=filename,
            audio_content=audio_content,
            segmentation_content=segmentation_content,
        )
        return result
    except ValueError as e:
        # The wrapper's ValueError message is returned to the client as a 400.
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        # Details go to stdout only; the client receives a generic message.
        print(f"Prediction Error: {e}")
        raise HTTPException(status_code=500, detail="Internal Server Error") from e
def health_check():
    """Report that the service is running, together with the registered models.

    Returns:
        A dict with ``"status"`` set to ``"active"`` and ``"loaded_models"``
        set to the list of keys currently in ``AVAILABLE_MODELS``. The list
        reflects registry membership only; it does not verify that each
        wrapper loaded successfully.
    """
    registered_names = list(AVAILABLE_MODELS)
    return {"status": "active", "loaded_models": registered_names}