sentiment-eval / app.py
CadenShokat's picture
Update app.py
87a2111 verified
raw
history blame
3.78 kB
from fastapi import FastAPI, Query
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
import math
# Checkpoint tried first by load_pipeline().
PRIMARY_MODEL: str = "cardiffnlp/twitter-roberta-base-sentiment"
# Checkpoint tried if the primary fails to load; predict() forces binary
# (no-neutral) handling when this one is serving.
FALLBACK_MODEL: str = "distilbert-base-uncased-finetuned-sst-2-english"

app = FastAPI()

# Pipeline cache checked by load_pipeline() (reused whenever it is not None).
clf = None
# Id of the checkpoint load_pipeline() selected; None until it has run.
loaded_model_id: str | None = None
class Payload(BaseModel):
    """Request body for POST /predict: the sentences to classify, in order."""

    # One entry per input text; results are returned in the same order.
    sentences: list[str]
@app.get("/healthz")
def health():
    """Liveness probe for the service.

    Always reports status "healthy". The "model" field echoes the module-level
    ``loaded_model_id`` as last set by load_pipeline(), which is None until
    load_pipeline() has run.
    """
    body = {"status": "healthy"}
    body["model"] = loaded_model_id
    return body
def load_pipeline():
    """Return the shared text-classification pipeline, building it on first use.

    PRIMARY_MODEL is tried first. If loading it raises, FALLBACK_MODEL is tried
    next. The first pipeline that builds successfully is stored in the
    module-level ``clf``, so later calls reuse it instead of reloading the
    weights. ``loaded_model_id`` is set only once that pipeline exists, so it
    always names the checkpoint actually serving.

    Returns:
        The cached ``transformers`` text-classification pipeline. It returns
        scores for every label (``top_k=None``) and truncates over-long inputs.

    Raises:
        RuntimeError: if neither model could be loaded. The last loading error
            is chained as ``__cause__``.
    """
    global clf, loaded_model_id
    if clf is not None:
        return clf

    last_error = None
    for model_id in (PRIMARY_MODEL, FALLBACK_MODEL):
        try:
            tok = AutoTokenizer.from_pretrained(model_id)
            mdl = AutoModelForSequenceClassification.from_pretrained(model_id)
            built = pipeline(
                "text-classification",
                model=mdl,
                tokenizer=tok,
                # Equivalent to the deprecated return_all_scores=True.
                top_k=None,
                truncation=True,
            )
        except Exception as e:  # best-effort: log and try the next checkpoint
            print(f"Failed to load {model_id}: {e}")
            last_error = e
            continue
        # Publish the cache and the id together, only after a successful build.
        # Previously clf was never assigned, so every call reloaded the model.
        clf = built
        loaded_model_id = model_id
        return clf

    raise RuntimeError("No sentiment model could be loaded") from last_error
def compute_score(pos: float, neg: float, neu: float, mode: str) -> float:
    """Collapse class probabilities into one signed polarity score.

    Modes:
        "logit":  difference of the positive and negative log-odds (each
                  probability floored at 1e-9), divided by 4 and squashed with
                  tanh into [-1, 1].
        "debias": ``pos - neg`` divided by the non-neutral mass ``1 - neu``,
                  with that denominator floored at 1e-6.
        "raw":    plain ``pos - neg``.
        other:    any unrecognised mode behaves like "raw".

    The result is not clamped here.
    """
    if mode == "logit":
        def log_odds(p: float) -> float:
            # Floors keep log() defined at p == 0 and p == 1.
            return math.log(max(1e-9, p)) - math.log(max(1e-9, 1 - p))

        return math.tanh((log_odds(pos) - log_odds(neg)) / 4.0)

    if mode == "debias":
        return (pos - neg) / max(1e-6, 1.0 - neu)

    # "raw" and any unknown mode share the plain difference.
    return pos - neg
def _probs_from_labels(by_label: dict[str, float]) -> tuple[float, float, float, bool]:
    """Map a lower-cased {label: score} dict to (neg, neu, pos, looks_binary).

    Three label vocabularies are recognised:
      * names "negative" / "neutral" / "positive";
      * generic ids "label_0" .. "label_2", read as negative/neutral/positive
        when a third label is present, otherwise as negative/positive;
      * anything else, matched by the substrings "pos", "neg", "neu". When
        several keys match, the last one in iteration order wins.
    """
    label_set = set(by_label)

    if label_set & {"negative", "positive"}:
        looks_binary = "neutral" not in by_label or len(by_label) == 2
        return (
            by_label.get("negative", 0.0),
            by_label.get("neutral", 0.0),
            by_label.get("positive", 0.0),
            looks_binary,
        )

    if any(name.startswith("label_") for name in label_set):
        first = by_label.get("label_0", 0.0)
        if "label_2" in by_label or len(by_label) >= 3:
            return first, by_label.get("label_1", 0.0), by_label.get("label_2", 0.0), False
        return first, 0.0, by_label.get("label_1", 0.0), True

    neg = neu = pos = 0.0
    for name, value in by_label.items():
        if "pos" in name:
            pos = value
        if "neg" in name:
            neg = value
        if "neu" in name:
            neu = value
    return neg, neu, pos, neu == 0.0


def scores_to_label(scores, mode: str, binary_hint: bool | None, min_conf: float, neutral_zone: float):
    """Turn one input's per-label pipeline scores into a sentiment verdict.

    Args:
        scores: iterable of entries with "label" and "score" keys, for a single
            input text.
        mode: scoring mode forwarded to compute_score.
        binary_hint: True forces two-class handling (neutral dropped). False
            forces three-class handling. None lets the label names decide.
        min_conf: if the largest of pos/neg/neu is below this, the label
            becomes "neutral".
        neutral_zone: if |score| is below this, the label becomes "neutral".

    Returns:
        dict with "label" ("positive" / "negative" / "neutral"), "score"
        (clamped to [-1, 1]), "confidence" (max of pos/neg/neu, computed after
        neu is zeroed in binary mode) and the per-class "scores".
    """
    by_label = {entry["label"].lower(): float(entry["score"]) for entry in scores}
    neg, neu, pos, looks_binary = _probs_from_labels(by_label)

    binary = looks_binary if binary_hint is None else bool(binary_hint)
    if binary:
        # Two-class models carry no neutral mass.
        neu = 0.0

    score = max(-1.0, min(1.0, compute_score(pos, neg, neu, mode)))
    confidence = max(pos, neg, neu)

    # Low confidence or a near-zero score gates the verdict to "neutral".
    if confidence < min_conf or abs(score) < neutral_zone:
        label = "neutral"
    elif score > 0:
        label = "positive"
    elif score < 0:
        label = "negative"
    else:
        label = "neutral"

    return {
        "label": label,
        "score": score,
        "confidence": confidence,
        "scores": {"positive": pos, "neutral": neu, "negative": neg},
    }
@app.post("/predict")
def predict(
    payload: Payload,
    mode: str = Query("raw", pattern="^(raw|debias|logit)$"),
    min_conf: float = Query(0.60, ge=0.0, le=1.0),
    neutral_zone: float = Query(0.20, ge=0.0, le=1.0)
):
    """Classify every sentence in the request body.

    ``mode``, ``min_conf`` and ``neutral_zone`` are validated query parameters
    forwarded unchanged to scores_to_label. The response carries the id of the
    serving checkpoint and one result dict per sentence, in input order.
    """
    classifier = load_pipeline()
    sentences = payload.sentences or []

    # Binary handling is forced for the fallback checkpoint. For the primary
    # checkpoint this is False (not None), so scores_to_label uses three-class
    # handling rather than auto-detecting from label names.
    treat_as_binary = loaded_model_id == FALLBACK_MODEL

    per_sentence_scores = classifier(sentences, top_k=None)
    results = [
        scores_to_label(label_scores, mode, treat_as_binary, min_conf, neutral_zone)
        for label_scores in per_sentence_scores
    ]
    return {"model": loaded_model_id, "results": results}