# sentiment-eval / app.py
# Last change by CadenShokat: "Update app.py" (commit bef81d1, verified)
import logging
import math
import threading
from typing import Literal

from fastapi import FastAPI, Query
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
# Model tried first by load_pipeline(). predict() does not force binary handling
# for it, so its neutral probability is kept.
PRIMARY_MODEL: str = "cardiffnlp/twitter-roberta-base-sentiment"
# Tried only if PRIMARY_MODEL fails to load. predict() passes a binary hint for
# it, so its neutral probability is forced to 0.
FALLBACK_MODEL: str = "distilbert-base-uncased-finetuned-sst-2-english"
# FastAPI application; the routes below are registered on it via decorators.
app = FastAPI()
# Lazily initialised pipeline cache. load_pipeline() fills it on first use and
# returns it on every later call.
clf = None
# Id of the model chosen by load_pipeline(); reported by /healthz and /predict.
# None before load_pipeline() has run.
loaded_model_id: str | None = None
class Payload(BaseModel):
    """Request body for POST /predict."""

    # Sentences to classify; predict() runs the pipeline over this list.
    sentences: list[str]
@app.get("/healthz")
def health():
    """Liveness probe that also reports the loaded model id.

    The model id is None before load_pipeline() has run.
    """
    return dict(status="healthy", model=loaded_model_id)
logger = logging.getLogger(__name__)
# Serialises the first load. FastAPI runs plain `def` endpoints in a worker
# threadpool, so without this lock concurrent first requests could each load
# (and hold in memory) their own copy of the model.
_load_lock = threading.Lock()


def load_pipeline():
    """Return the cached text-classification pipeline, loading it on first use.

    Tries PRIMARY_MODEL, then FALLBACK_MODEL. The first one that loads is cached
    in the module-level ``clf``, and its id is stored in ``loaded_model_id``.
    ``loaded_model_id`` is only set once a pipeline has been built successfully.

    Loader choices:
    - Avoid the meta-tensor issue by forcing ``low_cpu_mem_usage=False``.
    - No ``device_map``; the pipeline stays on CPU (``device=-1``).

    Returns:
        The cached ``transformers`` pipeline object.

    Raises:
        RuntimeError: if neither model can be loaded. The last loader error is
            chained as ``__cause__``.
    """
    global clf, loaded_model_id
    if clf is not None:  # fast path: already loaded, no locking needed
        return clf
    with _load_lock:
        # Another thread may have finished loading while we waited for the lock.
        if clf is not None:
            return clf
        last_error = None
        for model_id in (PRIMARY_MODEL, FALLBACK_MODEL):
            try:
                tok = AutoTokenizer.from_pretrained(model_id)  # use_fast default is fine
                mdl = AutoModelForSequenceClassification.from_pretrained(
                    model_id,
                    low_cpu_mem_usage=False,  # important to avoid meta tensors
                    trust_remote_code=False,
                )
                pipe = pipeline(
                    "text-classification",
                    model=mdl,
                    tokenizer=tok,
                    device=-1,  # CPU
                )
            except Exception as e:  # any load failure -> fall through to the next model
                last_error = e
                logger.warning("Failed to load %s: %s", model_id, e)
                continue
            # Publish the id before the pipeline: `clf is not None` is the
            # "ready" signal that other threads check without taking the lock.
            loaded_model_id = model_id
            clf = pipe
            return clf
        raise RuntimeError("No sentiment model could be loaded") from last_error
def compute_score(pos: float, neg: float, neu: float, mode: str) -> float:
    """Collapse class probabilities into one signed polarity score.

    Args:
        pos: probability of the positive class.
        neg: probability of the negative class.
        neu: probability of the neutral class (0.0 for binary models).
        mode: scoring strategy:
            - "debias": (pos - neg) / (1 - neu). This rescales polarity by the
              non-neutral mass, so a mostly-neutral prediction is not shrunk
              toward 0.
            - "raw": pos - neg (the conservative choice).
            - "logit": tanh((logit(pos) - logit(neg)) / 4), a slightly squashed
              log-odds difference.
            Any other value falls back to "raw".

    Returns:
        The score. "raw" and "logit" stay in [-1, 1] for probabilities in
        [0, 1]. "debias" can leave that range when pos + neg + neu != 1, so
        callers should clamp.
    """
    if mode == "debias":
        # Guard against division by ~0 when the model is (almost) fully neutral.
        return (pos - neg) / max(1e-6, 1.0 - neu)
    if mode == "logit":
        # Probabilities are floored at 1e-9 so log() never sees 0.
        lp = math.log(max(1e-9, pos)) - math.log(max(1e-9, 1 - pos))
        ln = math.log(max(1e-9, neg)) - math.log(max(1e-9, 1 - neg))
        return math.tanh((lp - ln) / 4.0)
    # "raw" and any unrecognised mode.
    return pos - neg
def _split_probs(probs):
    """Extract (pos, neg, neu, looks_binary) from a lower-cased {label: prob} map.

    Label schemes are recognised in this order:
    1. Named labels: "positive" and/or "negative", optionally "neutral".
       Treated as binary when "neutral" is absent or there are exactly two
       entries.
    2. Generic "label_N" ids. If "label_2" exists or there are three or more
       entries, they are read as 0=negative, 1=neutral, 2=positive. Otherwise
       they are read as binary 0=negative, 1=positive.
    3. Anything else: substring match on "pos" / "neg" / "neu" (the last
       matching key wins). Treated as binary when no neutral probability is
       found (or it is exactly 0.0).

    Missing classes default to 0.0.
    """
    if "negative" in probs or "positive" in probs:
        looks_binary = "neutral" not in probs or len(probs) == 2
        return (
            probs.get("positive", 0.0),
            probs.get("negative", 0.0),
            probs.get("neutral", 0.0),
            looks_binary,
        )

    if any(name.startswith("label_") for name in probs):
        if "label_2" in probs or len(probs) >= 3:
            return probs.get("label_2", 0.0), probs.get("label_0", 0.0), probs.get("label_1", 0.0), False
        return probs.get("label_1", 0.0), probs.get("label_0", 0.0), 0.0, True

    pos = neg = neu = 0.0
    for name, prob in probs.items():
        if "pos" in name:
            pos = prob
        if "neg" in name:
            neg = prob
        if "neu" in name:
            neu = prob
    return pos, neg, neu, neu == 0.0


def scores_to_label(scores, mode: str, binary_hint: bool | None, min_conf: float, neutral_zone: float):
    """Turn one input's class scores into a label, a signed score and a confidence.

    Args:
        scores: iterable of {"label": str, "score": number} dicts for a single
            input (presumably one per class, as the pipeline returns with
            top_k=None). Labels are matched case-insensitively.
        mode: scoring mode forwarded to compute_score().
        binary_hint: True/False forces binary/non-binary handling. None trusts
            the detection done from the label scheme. In binary handling the
            neutral probability is zeroed.
        min_conf: if the largest class probability is below this, the label
            is "neutral".
        neutral_zone: if |score| is below this, the label is "neutral".

    Returns:
        {"label": "positive" | "negative" | "neutral",
         "score": polarity clamped to [-1, 1],
         "confidence": max of the three class probabilities,
         "scores": {"positive": ..., "neutral": ..., "negative": ...}}
    """
    probs = {entry["label"].lower(): float(entry["score"]) for entry in scores}
    pos, neg, neu, looks_binary = _split_probs(probs)

    if binary_hint is None:
        binary = looks_binary
    else:
        binary = bool(binary_hint)
    if binary:
        neu = 0.0

    raw_score = compute_score(pos, neg, neu, mode)
    score = max(-1.0, min(1.0, raw_score))
    confidence = max(pos, neg, neu)

    # Gating first: low confidence or a near-zero score means "neutral".
    if confidence < min_conf or abs(score) < neutral_zone:
        label = "neutral"
    elif score > 0:
        label = "positive"
    elif score < 0:
        label = "negative"
    else:
        label = "neutral"

    return {
        "label": label,
        "score": score,
        "confidence": confidence,
        "scores": {"positive": pos, "neutral": neu, "negative": neg},
    }
@app.post("/predict")
def predict(
    payload: Payload,
    mode: Literal["raw", "debias", "logit"] = Query("raw"),
    min_conf: float = Query(0.60, ge=0.0, le=1.0),
    neutral_zone: float = Query(0.20, ge=0.0, le=1.0),
):
    """Classify every sentence in the request body.

    - Uses top_k=None (the replacement for the deprecated
      return_all_scores=True), so each sentence gets every class score.
    - Forces truncation/padding/max_length=512 to avoid the "631 > 514"
      sequence-length crashes on long inputs.

    Query params ``mode``, ``min_conf`` and ``neutral_zone`` are forwarded to
    scores_to_label().

    Returns:
        {"model": <id of the loaded model>, "results": [one dict per scored input]}
    """
    pipe = load_pipeline()
    texts = payload.sentences or []
    if not texts:
        # Nothing to classify; don't hand an empty batch to the pipeline.
        return {"model": loaded_model_id, "results": []}

    outs = pipe(
        texts,
        top_k=None,  # replaces return_all_scores=True
        truncation=True,  # important for long inputs
        padding=True,
        max_length=512,
    )
    # Normalise to list[list[dict]] (one score list per input). A bare dict
    # needs two levels of wrapping; a flat list of dicts is a single input's
    # score list and needs one.
    if isinstance(outs, dict):
        outs = [[outs]]
    elif outs and isinstance(outs[0], dict):
        outs = [outs]

    # The fallback model is handled as binary (its neutral probability is zeroed).
    binary_hint = loaded_model_id == FALLBACK_MODEL
    results = [scores_to_label(s, mode, binary_hint, min_conf, neutral_zone) for s in outs]
    return {"model": loaded_model_id, "results": results}