File size: 3,777 Bytes
87a2111
93f8930
 
87a2111
93f8930
87a2111
 
93f8930
 
020a076
 
93f8930
87a2111
 
93f8930
 
020a076
 
93f8930
020a076
 
 
 
 
 
 
 
 
2aeae98
020a076
 
 
 
87a2111
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ed102d
 
 
 
 
 
87a2111
 
 
5ed102d
 
 
 
87a2111
5ed102d
87a2111
5ed102d
87a2111
5ed102d
 
 
 
 
87a2111
 
 
5ed102d
87a2111
 
 
 
 
 
 
 
 
 
5ed102d
87a2111
 
 
 
 
 
93f8930
 
87a2111
 
 
 
 
 
 
 
 
 
 
020a076
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from fastapi import FastAPI, Query
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
import math

# Hugging Face model ids: try the 3-class RoBERTa model first, then fall
# back to the binary SST-2 DistilBERT model (see load_pipeline).
PRIMARY_MODEL = "cardiffnlp/twitter-roberta-base-sentiment"
FALLBACK_MODEL = "distilbert-base-uncased-finetuned-sst-2-english"

app = FastAPI()
# Lazily-initialized classification pipeline and the id of the model that
# was actually loaded; both stay None until the first /predict request.
clf = None
loaded_model_id = None

class Payload(BaseModel):
    """Request body for /predict: the list of sentences to classify."""

    sentences: list[str]

@app.get("/healthz")
def health():
    """Liveness probe: report service status and the currently loaded model id.

    ``model`` is None until the first /predict call loads a pipeline.
    """
    status = {"status": "healthy"}
    status["model"] = loaded_model_id
    return status

def load_pipeline():
    """Return the module-cached text-classification pipeline, loading it if needed.

    Tries PRIMARY_MODEL first, then FALLBACK_MODEL. On success the pipeline
    is stored in the module-level ``clf`` so later requests reuse it.

    Bug fixes vs. the original:
    - ``clf`` was checked as a cache but never assigned, so every call
      reloaded the model from scratch; we now assign it before returning.
    - ``loaded_model_id`` was set before ``pipeline(...)`` ran, so a failure
      while building the pipeline left it pointing at a model that never
      loaded; it is now set only after the pipeline is constructed.

    Raises:
        RuntimeError: if neither candidate model can be loaded.
    """
    global clf, loaded_model_id
    if clf is not None:
        return clf
    for model_id in (PRIMARY_MODEL, FALLBACK_MODEL):
        try:
            tok = AutoTokenizer.from_pretrained(model_id)
            mdl = AutoModelForSequenceClassification.from_pretrained(model_id)
            clf = pipeline(
                "text-classification",
                model=mdl,
                tokenizer=tok,
                return_all_scores=True,
                truncation=True,
            )
            loaded_model_id = model_id
            return clf
        except Exception as e:
            # Best-effort fallback: log and try the next candidate model.
            print(f"Failed to load {model_id}: {e}")
    raise RuntimeError("No sentiment model could be loaded")

def compute_score(pos: float, neg: float, neu: float, mode: str) -> float:
    """Collapse class probabilities into a single signed polarity score.

    Modes:
        "debias": (pos - neg) rescaled by the non-neutral probability mass.
        "raw":    plain pos - neg difference (conservative).
        "logit":  tanh-squashed difference of per-class log-odds.
    Any unrecognized mode falls back to the "raw" behavior.
    """
    if mode == "debias":
        # Rescale by the non-neutral mass; floor the denominator to avoid
        # division by zero when neu is (numerically) 1.
        non_neutral = 1.0 - neu
        return (pos - neg) / max(1e-6, non_neutral)
    if mode == "logit":
        # Difference of logits, tanh-clamped into [-1, 1]; probabilities are
        # floored at 1e-9 so log() never sees zero.
        def _logit(p: float) -> float:
            return math.log(max(1e-9, p)) - math.log(max(1e-9, 1 - p))

        return math.tanh((_logit(pos) - _logit(neg)) / 4.0)
    # "raw" and any unknown mode.
    return pos - neg

def scores_to_label(scores, mode: str, binary_hint: bool | None, min_conf: float, neutral_zone: float):
    """Convert a pipeline's per-class score list into a labelled result dict.

    Handles three label schemes: explicit "positive"/"negative"(/"neutral"),
    anonymous "label_0/1/2" (label_0 treated as negative), and a last-resort
    substring match. ``binary_hint``, when not None, overrides the
    auto-detected binary/3-class decision; binary models get their neutral
    mass forced to zero. A confidence below ``min_conf`` or a polarity score
    inside ``neutral_zone`` downgrades the label to "neutral".
    """
    probs = {entry["label"].lower(): float(entry["score"]) for entry in scores}
    names = set(probs)

    pos = neg = neu = 0.0
    detected_binary = False

    if names & {"negative", "positive"}:
        # Explicitly named labels.
        neg = probs.get("negative", 0.0)
        pos = probs.get("positive", 0.0)
        neu = probs.get("neutral", 0.0)
        detected_binary = ("neutral" not in probs) or (len(probs) == 2)
    elif any(name.startswith("label_") for name in names):
        # Anonymous LABEL_n scheme: label_0 is negative.
        neg = probs.get("label_0", 0.0)
        if "label_2" in probs or len(probs) >= 3:
            neu = probs.get("label_1", 0.0)
            pos = probs.get("label_2", 0.0)
        else:
            pos = probs.get("label_1", 0.0)
            detected_binary = True
    else:
        # Fuzzy fallback: classify each label by substring.
        for name, value in probs.items():
            if "pos" in name:
                pos = value
            if "neg" in name:
                neg = value
            if "neu" in name:
                neu = value
        detected_binary = neu == 0.0

    is_binary = detected_binary if binary_hint is None else bool(binary_hint)
    if is_binary:
        neu = 0.0

    raw = compute_score(pos, neg, neu, mode)
    score = max(-1.0, min(1.0, raw))  # clamp to [-1, 1]
    conf = max(pos, neg, neu)

    if score > 0:
        label = "positive"
    elif score < 0:
        label = "negative"
    else:
        label = "neutral"

    # Optional gating: low confidence or near-zero polarity becomes neutral.
    if conf < min_conf or abs(score) < neutral_zone:
        label = "neutral"

    return {
        "label": label,
        "score": score,
        "confidence": conf,
        "scores": {"positive": pos, "neutral": neu, "negative": neg},
    }

@app.post("/predict")
def predict(
    payload: Payload,
    mode: str = Query("raw", pattern="^(raw|debias|logit)$"),
    min_conf: float = Query(0.60, ge=0.0, le=1.0),
    neutral_zone: float = Query(0.20, ge=0.0, le=1.0)
):
    """Classify each sentence in the payload; return the model id and per-sentence results."""
    pipe = load_pipeline()
    sentences = payload.sentences or []
    raw_outputs = pipe(sentences, top_k=None)
    # The fallback SST-2 model is binary; hint the scorer so neutral is zeroed.
    is_fallback = (loaded_model_id == FALLBACK_MODEL)
    results = []
    for raw in raw_outputs:
        results.append(scores_to_label(raw, mode, is_fallback, min_conf, neutral_zone))
    return {"model": loaded_model_id, "results": results}