import math
from typing import Literal

from fastapi import FastAPI, Query
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline

PRIMARY_MODEL = "cardiffnlp/twitter-roberta-base-sentiment"
FALLBACK_MODEL = "distilbert-base-uncased-finetuned-sst-2-english"

app = FastAPI()
clf = None
loaded_model_id = None


class Payload(BaseModel):
    sentences: list[str]


@app.get("/healthz")
def health():
    return {"status": "healthy", "model": loaded_model_id}


def load_pipeline():
    """
    Robust loader:
    - Avoid the meta-tensor issue by forcing low_cpu_mem_usage=False
    - No device_map; keep the model on CPU
    - Cache the pipeline in the global `clf`
    """
    global clf, loaded_model_id
    if clf is not None:
        return clf
    for model_id in (PRIMARY_MODEL, FALLBACK_MODEL):
        try:
            tok = AutoTokenizer.from_pretrained(model_id)  # use_fast default is fine
            mdl = AutoModelForSequenceClassification.from_pretrained(
                model_id,
                low_cpu_mem_usage=False,  # <-- important to avoid meta tensors
                trust_remote_code=False,
            )
            loaded_model_id = model_id
            clf = pipeline(
                "text-classification",
                model=mdl,
                tokenizer=tok,
                device=-1,  # CPU
            )
            return clf
        except Exception as e:
            # Log and fall through to the next candidate model
            print(f"Failed to load {model_id}: {e}")
    raise RuntimeError("No sentiment model could be loaded")


def compute_score(pos: float, neg: float, neu: float, mode: str) -> float:
    if mode == "debias":  # original: renormalize away the neutral mass
        denom = max(1e-6, 1.0 - neu)
        return (pos - neg) / denom
    elif mode == "raw":  # conservative
        return pos - neg
    elif mode == "logit":  # optional, slightly squashed
        # difference of logits, then tanh to clamp to [-1, 1]
        lp = math.log(max(1e-9, pos)) - math.log(max(1e-9, 1 - pos))
        ln = math.log(max(1e-9, neg)) - math.log(max(1e-9, 1 - neg))
        return math.tanh((lp - ln) / 4.0)
    else:
        return pos - neg


def scores_to_label(scores, mode: str, binary_hint: bool | None,
                    min_conf: float, neutral_zone: float):
    m = {s["label"].lower(): float(s["score"]) for s in scores}
    keys = set(m.keys())
    neg = neu = pos = 0.0
    detected_binary = False
    if {"negative", "positive"} & keys:
        # Human-readable labels (e.g. the SST-2 DistilBERT head)
        neg, pos = m.get("negative", 0.0), m.get("positive", 0.0)
        neu = m.get("neutral", 0.0)
        detected_binary = ("neutral" not in m) or (len(m) == 2)
    elif any(k.startswith("label_") for k in keys):
        # Generic LABEL_0/LABEL_1/LABEL_2 heads (e.g. the Cardiff NLP model)
        neg = m.get("label_0", 0.0)
        if "label_2" in m or len(m) >= 3:
            neu = m.get("label_1", 0.0)
            pos = m.get("label_2", 0.0)
            detected_binary = False
        else:
            pos = m.get("label_1", 0.0)
            neu = 0.0
            detected_binary = True
    else:
        # Last resort: substring matching on whatever labels the model emits
        for k, v in m.items():
            if "pos" in k:
                pos = v
            if "neg" in k:
                neg = v
            if "neu" in k:
                neu = v
        detected_binary = (neu == 0.0)
    is_binary = detected_binary if binary_hint is None else bool(binary_hint)
    if is_binary:
        neu = 0.0
    score = compute_score(pos, neg, neu, mode)
    # clamp to [-1, 1]
    score = max(-1.0, min(1.0, score))
    conf = max(pos, neg, neu)
    label = "positive" if score > 0 else ("negative" if score < 0 else "neutral")
    # Optional gating: low confidence or a near-zero score falls back to neutral
    if conf < min_conf or abs(score) < neutral_zone:
        label = "neutral"
    return {
        "label": label,
        "score": score,
        "confidence": conf,
        "scores": {"positive": pos, "neutral": neu, "negative": neg},
    }


@app.post("/predict")
def predict(
    payload: Payload,
    mode: Literal["raw", "debias", "logit"] = Query("raw"),
    min_conf: float = Query(0.60, ge=0.0, le=1.0),
    neutral_zone: float = Query(0.20, ge=0.0, le=1.0),
):
    """
    - Use top_k=None (replacement for the deprecated return_all_scores=True)
    - Force truncation/padding/max_length to avoid crashes on long inputs
      (e.g. 631 tokens > the model's 514-position limit)
    """
    clf = load_pipeline()
    texts = payload.sentences or []
    outs = clf(
        texts,
        top_k=None,       # replaces return_all_scores=True
        truncation=True,  # <-- important for long inputs
        padding=True,
        max_length=512,
    )
    # If a single string was passed, HF may return a single item;
    # normalize so that `outs` is always list[list[dict]]
    if isinstance(outs, dict) or (outs and isinstance(outs[0], dict)):
        outs = [outs]
    binary_hint = (loaded_model_id == FALLBACK_MODEL)
    results = [scores_to_label(s, mode, binary_hint, min_conf, neutral_zone)
               for s in outs]
    return {"model": loaded_model_id, "results": results}
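
# ---------------------------------------------------------------------------
# Quick smoke test (a minimal sketch, not part of the service itself).
# Assumes fastapi's TestClient is available (it requires httpx) and that the
# model weights can be fetched from the Hugging Face Hub on first run.
# The example sentences and parameter values are illustrative only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from fastapi.testclient import TestClient

    client = TestClient(app)
    resp = client.post(
        "/predict",
        params={"mode": "debias", "min_conf": 0.5, "neutral_zone": 0.1},
        json={"sentences": ["I love this!", "This is terrible."]},
    )
    print(resp.json())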