# adtrack-v2 / main.py
# Last commit: "add mode to model 3" (e824b96) by cracker0935
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
from typing import Dict, Optional
from models.base import BaseModelWrapper
from models.model_v1.wrapper import HybridDebertaWrapper
from models.model_v2.wrapper import ModelV2Wrapper
from models.model_v3.wrapper import MultimodalWrapper
# Registry mapping user-facing model names -> wrapper instances.
# Instances are constructed at import time; each wrapper's load() is invoked
# at application startup (see `lifespan`), so heavy initialization is deferred.
AVAILABLE_MODELS: Dict[str, BaseModelWrapper] = {
"Model V1": HybridDebertaWrapper(),
"Model V2": ModelV2Wrapper(),
"Model V3 (Multimodal)": MultimodalWrapper()
}
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load every registered model at startup; drop the registry on shutdown.

    A model that fails to load is reported and skipped so the remaining
    models (and the API itself) still come up.
    """
    print("Initializing available models...")

    def _load_one(model_name, model):
        # One failing model must not abort loading of the others.
        try:
            print(f"Loading {model_name}...")
            model.load()
            print(f"{model_name} loaded successfully.")
        except Exception as e:
            print(f"Failed to load {model_name}: {e}")

    for model_name, model in AVAILABLE_MODELS.items():
        _load_one(model_name, model)
    yield
    # Shutdown: release references to all wrappers.
    AVAILABLE_MODELS.clear()
# Application instance; `lifespan` handles model load/unload around serving.
app = FastAPI(lifespan=lifespan)
# CORS: allow the deployed frontend and a local dev frontend (port 5173,
# presumably a Vite dev server — confirm) to call this API with credentials.
app.add_middleware(
CORSMiddleware,
allow_origins=["https://adtrack.onrender.com", "http://localhost:5173"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.get("/models")
def list_models():
    """Return the display names of every registered model."""
    return {"models": [model_name for model_name in AVAILABLE_MODELS]}
@app.post("/predict")
async def predict(
    model_name: str = Form(...),
    file: Optional[UploadFile] = File(None),  # .cha file (Optional)
    audio_file: Optional[UploadFile] = File(None),  # Audio file (Optional)
    segmentation_file: Optional[UploadFile] = File(None)  # Segmentation .csv (Optional)
):
    """Run inference with the selected model on the uploaded file(s).

    Returns whatever dict the model wrapper's predict() produces.

    Raises:
        HTTPException 404: unknown model name.
        HTTPException 400: missing/invalid input files, or the wrapper
            rejected the input with ValueError.
        HTTPException 500: any other failure inside the wrapper.
    """
    if model_name not in AVAILABLE_MODELS:
        raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")

    # --- Validation Logic ---
    if model_name in ["Model V1", "Model V2"]:
        # UploadFile.filename may be None; guard so a missing name yields a
        # clean 400 instead of an AttributeError (-> 500).
        if not file or not (file.filename or "").endswith('.cha'):
            raise HTTPException(status_code=400, detail=f"{model_name} requires a .cha file.")
    if not file and not audio_file:
        raise HTTPException(status_code=400, detail="Please provide input files (CHA or Audio).")

    # --- Read Files ---
    text_content = b""
    filename = "upload"
    if file:
        text_content = await file.read()
        filename = file.filename or "upload"  # never pass None downstream
    elif audio_file:
        filename = audio_file.filename or "upload"

    audio_content = None
    if audio_file:
        audio_content = await audio_file.read()

    segmentation_content = None
    if segmentation_file:
        segmentation_content = await segmentation_file.read()

    # --- Prediction ---
    try:
        # Pass all contents to the wrapper; unused kinds arrive as None.
        result = AVAILABLE_MODELS[model_name].predict(
            file_content=text_content,
            filename=filename,
            audio_content=audio_content,
            segmentation_content=segmentation_content
        )
        return result
    except HTTPException:
        # A deliberate HTTP error raised inside a wrapper must keep its
        # status code instead of being masked as a 500 below.
        raise
    except ValueError as e:
        # Wrappers signal invalid input via ValueError -> client error.
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        print(f"Prediction Error: {e}")
        raise HTTPException(status_code=500, detail="Internal Server Error")
@app.get("/health")
def health_check():
    """Liveness probe; also reports which model names are registered.

    NOTE(review): names are taken from the registry, so a model whose
    load() failed at startup is still listed — confirm this is intended.
    """
    registered = [model_name for model_name in AVAILABLE_MODELS]
    return {"status": "active", "loaded_models": registered}