# sentiment-eval / app.py
# Author: CadenShokat (Hugging Face Space). Last commit bef81d1 "Update app.py" (verified), 4.77 kB at that commit.
from fastapi import FastAPI, Query
from pydantic import BaseModel
from typing import Literal
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
import math
# Candidate checkpoints, tried in this order by load_pipeline().
PRIMARY_MODEL = "cardiffnlp/twitter-roberta-base-sentiment"
FALLBACK_MODEL = "distilbert-base-uncased-finetuned-sst-2-english"

app = FastAPI()

# Lazily populated by load_pipeline(): the cached pipeline object and the id of
# the checkpoint it was built from. Both start as None.
clf = None
loaded_model_id: str | None = None
class Payload(BaseModel):
    """Request body for ``POST /predict``: the sentences to classify."""

    sentences: list[str]
@app.get("/healthz")
def health():
    """Liveness probe; also reports the id of the loaded model (None if not loaded yet)."""
    body = {"status": "healthy"}
    body["model"] = loaded_model_id
    return body
def load_pipeline():
    """Return the cached text-classification pipeline, building it on first use.

    Tries PRIMARY_MODEL, then FALLBACK_MODEL, and caches the first candidate
    whose tokenizer, model and pipeline all build successfully in the module
    globals ``clf`` and ``loaded_model_id``.

    Loader choices:
    - low_cpu_mem_usage=False is forced to avoid the meta-tensor issue.
    - No device_map; the pipeline is pinned to CPU (device=-1).

    Raises:
        RuntimeError: if no candidate could be loaded; the last loader error
            is attached as ``__cause__``.

    NOTE(review): there is no lock here. FastAPI runs plain ``def`` endpoints
    in a threadpool, so concurrent first requests may each build a pipeline;
    whichever finishes last ends up cached. Wasteful but not incorrect.
    """
    global clf, loaded_model_id
    if clf is not None:
        return clf

    last_error = None
    for model_id in (PRIMARY_MODEL, FALLBACK_MODEL):
        try:
            tok = AutoTokenizer.from_pretrained(model_id)  # use_fast default is fine
            mdl = AutoModelForSequenceClassification.from_pretrained(
                model_id,
                low_cpu_mem_usage=False,  # <-- important to avoid meta tensors
                trust_remote_code=False,
            )
            pipe = pipeline(
                "text-classification",
                model=mdl,
                tokenizer=tok,
                device=-1,  # CPU
            )
        except Exception as e:  # best-effort: fall through to the next candidate
            print(f"Failed to load {model_id}: {e}")
            last_error = e
            continue

        # Publish the id and the pipeline together, only once both exist, so
        # /healthz never reports a model whose pipeline failed to build.
        loaded_model_id = model_id
        clf = pipe
        return clf

    raise RuntimeError("No sentiment model could be loaded") from last_error
def _log_odds(p: float, eps: float = 1e-9) -> float:
    """Return log(p) - log(1 - p), flooring both arguments at ``eps`` so the result stays finite."""
    return math.log(max(eps, p)) - math.log(max(eps, 1 - p))


def compute_score(pos: float, neg: float, neu: float, mode: str) -> float:
    """Combine class probabilities into a signed sentiment score.

    Args:
        pos: probability of the positive class.
        neg: probability of the negative class.
        neu: probability of the neutral class (only used by "debias").
        mode: "debias" divides (pos - neg) by the non-neutral mass
            (1 - neu, floored at 1e-6); "logit" takes the difference of the
            pos/neg log-odds, scales by 1/4 and squashes with tanh into
            [-1, 1]; "raw" and any unrecognised mode return pos - neg.

    Returns:
        The score. "debias" is not bounded here; callers clamp it.
    """
    if mode == "debias":  # original behaviour: rescale by the non-neutral mass
        return (pos - neg) / max(1e-6, 1.0 - neu)
    if mode == "logit":  # slightly squashed; tanh keeps it within [-1, 1]
        return math.tanh((_log_odds(pos) - _log_odds(neg)) / 4.0)
    # "raw" (conservative) and fallback for unknown modes
    return pos - neg
def _split_probabilities(probs: dict[str, float]) -> tuple[float, float, float, bool]:
    """Map a lower-cased {label: probability} dict to (pos, neu, neg, looks_binary).

    Label vocabularies are recognised in this order:
      1. named labels "negative"/"positive" (plus optional "neutral");
      2. generic ids "label_0".."label_2", read as 0=neg, 1=neu, 2=pos when a
         third class is present, otherwise 0=neg, 1=pos;
      3. anything else, matched by the substrings "pos"/"neg"/"neu"; when
         several labels match the same substring, the last one wins.
    Classes that are absent default to 0.0.
    """
    if "negative" in probs or "positive" in probs:
        looks_binary = "neutral" not in probs or len(probs) == 2
        return (
            probs.get("positive", 0.0),
            probs.get("neutral", 0.0),
            probs.get("negative", 0.0),
            looks_binary,
        )

    if any(name.startswith("label_") for name in probs):
        if "label_2" in probs or len(probs) >= 3:
            return (
                probs.get("label_2", 0.0),
                probs.get("label_1", 0.0),
                probs.get("label_0", 0.0),
                False,
            )
        return probs.get("label_1", 0.0), 0.0, probs.get("label_0", 0.0), True

    slots = {"pos": 0.0, "neu": 0.0, "neg": 0.0}
    for name, value in probs.items():
        for stem in slots:
            if stem in name:
                slots[stem] = value
    return slots["pos"], slots["neu"], slots["neg"], slots["neu"] == 0.0


def scores_to_label(scores, mode: str, binary_hint: bool | None, min_conf: float, neutral_zone: float):
    """Turn one input's per-class pipeline output into a sentiment verdict.

    Args:
        scores: iterable of {"label": str, "score": float} entries for one input.
        mode: scoring formula, forwarded to compute_score.
        binary_hint: True/False forces two-class/three-class handling; None
            uses the layout detected from the labels.
        min_conf: if the largest class probability is below this, the label
            is forced to "neutral".
        neutral_zone: if |score| is below this, the label is forced to "neutral".

    Returns:
        dict with "label", "score" (clamped to [-1, 1]), "confidence" (largest
        class probability after binary handling) and the per-class "scores".
    """
    probs = {entry["label"].lower(): float(entry["score"]) for entry in scores}
    pos, neu, neg, looks_binary = _split_probabilities(probs)

    binary = looks_binary if binary_hint is None else bool(binary_hint)
    if binary:
        neu = 0.0  # two-class output: ignore any neutral mass

    score = max(-1.0, min(1.0, compute_score(pos, neg, neu, mode)))
    confidence = max(pos, neg, neu)

    if confidence < min_conf or abs(score) < neutral_zone:
        label = "neutral"  # gated: too uncertain or too close to zero
    elif score > 0:
        label = "positive"
    elif score < 0:
        label = "negative"
    else:
        label = "neutral"

    return {
        "label": label,
        "score": score,
        "confidence": confidence,
        "scores": {"positive": pos, "neutral": neu, "negative": neg},
    }
@app.post("/predict")
def predict(
    payload: Payload,
    mode: Literal["raw","debias","logit"] = Query("raw"),
    min_conf: float = Query(0.60, ge=0.0, le=1.0),
    neutral_zone: float = Query(0.20, ge=0.0, le=1.0)
):
    """Classify every sentence in ``payload`` and return one result per sentence.

    - Use top_k=None (replacement for deprecated return_all_scores=True) so
      every class probability is returned, not just the top one.
    - Force truncation/padding/max_length to avoid 631>514 crashes on long inputs.

    Query params:
        mode: scoring formula passed down to compute_score.
        min_conf: top-class probability below which a result is forced neutral.
        neutral_zone: |score| below which a result is forced neutral.

    Returns:
        {"model": <loaded model id>, "results": [...]} with results in input order.
    """
    classifier = load_pipeline()
    texts = payload.sentences or []
    if not texts:
        # Nothing to classify; don't invoke the pipeline on an empty batch.
        return {"model": loaded_model_id, "results": []}

    outs = classifier(
        texts,
        top_k=None,        # replaces return_all_scores=True
        truncation=True,   # <-- important for long inputs
        padding=True,
        max_length=512,
    )

    # Normalize to list[list[dict]]: one list of per-class scores per input.
    if isinstance(outs, dict):
        # A single bare result: one input, one class entry.
        outs = [[outs]]
    elif outs and isinstance(outs[0], dict):
        # A flat list of dicts is either all classes of a single input, or one
        # top entry per input; disambiguate by the number of inputs.
        outs = [outs] if len(texts) == 1 else [[entry] for entry in outs]

    binary_hint = (loaded_model_id == FALLBACK_MODEL)
    results = [scores_to_label(s, mode, binary_hint, min_conf, neutral_zone) for s in outs]
    return {"model": loaded_model_id, "results": results}