"""FastAPI service exposing the registered model wrappers over HTTP.

Endpoints:
    GET  /models   -- names of the models that can be used for prediction.
    POST /predict  -- run one model on the uploaded input files.
    GET  /health   -- liveness probe plus the names of the loaded models.
"""

import threading
import traceback
from contextlib import asynccontextmanager
from typing import Dict, Optional

from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware

from models.base import BaseModelWrapper
from models.model_v1.wrapper import HybridDebertaWrapper
from models.model_v2.wrapper import ModelV2Wrapper
from models.model_v3.wrapper import MultimodalWrapper

# Registry of display name -> wrapper instance. The keys are the values
# clients send as ``model_name``. Entries whose ``load()`` fails at startup
# are removed by ``lifespan`` so that every listed model is usable.
AVAILABLE_MODELS: Dict[str, BaseModelWrapper] = {
    "Model V1": HybridDebertaWrapper(),
    "Model V2": ModelV2Wrapper(),
    "Model V3 (Multimodal)": MultimodalWrapper(),
}

# Models that must receive a ``.cha`` upload in the ``file`` field.
_CHA_REQUIRED_MODELS = frozenset({"Model V1", "Model V2"})

# Inference runs in a worker thread so it no longer blocks the event loop.
# This lock keeps predictions one-at-a-time, as they were when predict()
# ran directly on the loop; the wrappers are not known to be thread-safe.
_PREDICT_LOCK = threading.Lock()


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load every registered model at startup; clear the registry on shutdown.

    A model whose ``load()`` raises is removed from ``AVAILABLE_MODELS`` so
    that ``/models`` and ``/health`` only report models that loaded, and
    ``/predict`` answers 404 for it instead of calling an unloaded wrapper.
    """
    print("Initializing available models...")
    # Iterate over a snapshot because failed entries are popped in the loop.
    for name, wrapper in list(AVAILABLE_MODELS.items()):
        try:
            print(f"Loading {name}...")
            wrapper.load()
            print(f"{name} loaded successfully.")
        except Exception as e:
            # Best-effort: one broken model must not stop the whole service.
            print(f"Failed to load {name}: {e}")
            traceback.print_exc()
            AVAILABLE_MODELS.pop(name, None)
    yield
    AVAILABLE_MODELS.clear()


app = FastAPI(lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["https://adtrack.onrender.com", "http://localhost:5173"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/models")
def list_models():
    """Return the names accepted as ``model_name`` by ``/predict``."""
    return {"models": list(AVAILABLE_MODELS.keys())}


def _run_prediction(wrapper: BaseModelWrapper, **inputs):
    """Call ``wrapper.predict(**inputs)`` while holding ``_PREDICT_LOCK``.

    Executed in a worker thread by ``predict``. The lock serializes calls
    into the wrappers.
    """
    with _PREDICT_LOCK:
        return wrapper.predict(**inputs)


@app.post("/predict")
async def predict(
    model_name: str = Form(...),
    file: Optional[UploadFile] = File(None),  # .cha file (Optional)
    audio_file: Optional[UploadFile] = File(None),  # Audio file (Optional)
    segmentation_file: Optional[UploadFile] = File(None),  # Segmentation .csv (Optional)
):
    """Run ``model_name`` on the uploaded files and return its result.

    Args:
        model_name: Key of ``AVAILABLE_MODELS`` to use.
        file: Optional ``.cha`` upload; mandatory for the models in
            ``_CHA_REQUIRED_MODELS``.
        audio_file: Optional audio upload.
        segmentation_file: Optional segmentation ``.csv`` upload.

    Returns:
        Whatever the selected wrapper's ``predict`` returns.

    Raises:
        HTTPException: 404 if the model is unknown or failed to load.
            400 for missing or invalid inputs, or when the wrapper raises
            ``ValueError``. 500 for any other wrapper failure.
    """
    # Look the wrapper up once; the registry could change across the awaits.
    wrapper = AVAILABLE_MODELS.get(model_name)
    if wrapper is None:
        raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")

    # --- Validation Logic ---
    if model_name in _CHA_REQUIRED_MODELS:
        # UploadFile.filename may be None; guard before calling endswith.
        if file is None or not (file.filename or "").endswith('.cha'):
            raise HTTPException(status_code=400, detail=f"{model_name} requires a .cha file.")

    if file is None and audio_file is None:
        raise HTTPException(status_code=400, detail="Please provide input files (CHA or Audio).")

    # --- Read Files ---
    text_content = b""
    filename = "upload"
    if file is not None:
        text_content = await file.read()
        filename = file.filename or filename
    elif audio_file is not None:
        filename = audio_file.filename or filename

    audio_content = None
    if audio_file is not None:
        audio_content = await audio_file.read()

    segmentation_content = None
    if segmentation_file is not None:
        segmentation_content = await segmentation_file.read()

    # --- Prediction ---
    # The wrapper call is synchronous; run it in the threadpool so the event
    # loop keeps serving other requests (health checks, uploads) meanwhile.
    try:
        result = await run_in_threadpool(
            _run_prediction,
            wrapper,
            file_content=text_content,
            filename=filename,
            audio_content=audio_content,
            segmentation_content=segmentation_content,
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        print(f"Prediction Error: {e}")
        traceback.print_exc()
        raise HTTPException(status_code=500, detail="Internal Server Error") from e
    return result


@app.get("/health")
def health_check():
    """Liveness probe; also reports the models that loaded successfully."""
    return {"status": "active", "loaded_models": list(AVAILABLE_MODELS.keys())}