CadenShokat committed
Commit 87a2111 · verified · 1 Parent(s): 5ed102d

Update app.py

Files changed (1)
  1. app.py +57 -54
app.py CHANGED
@@ -1,25 +1,22 @@
-from fastapi import FastAPI
+from fastapi import FastAPI, Query
 from pydantic import BaseModel
 from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
+import math
 
-PRIMARY_MODEL = "cardiffnlp/twitter-roberta-base-sentiment" # stable 3-class
-FALLBACK_MODEL = "distilbert-base-uncased-finetuned-sst-2-english" # binary
+PRIMARY_MODEL = "cardiffnlp/twitter-roberta-base-sentiment"
+FALLBACK_MODEL = "distilbert-base-uncased-finetuned-sst-2-english"
 
 app = FastAPI()
 clf = None
 loaded_model_id = None
 
-@app.get("/")
-def root():
-    return {"status": "ok", "model": loaded_model_id}
+class Payload(BaseModel):
+    sentences: list[str]
 
 @app.get("/healthz")
 def health():
     return {"status": "healthy", "model": loaded_model_id}
 
-class Payload(BaseModel):
-    sentences: list[str]
-
 def load_pipeline():
     global clf, loaded_model_id
     if clf is not None:
@@ -34,70 +31,76 @@ def load_pipeline():
             print(f"Failed to load {model_id}: {e}")
     raise RuntimeError("No sentiment model could be loaded")
 
-def scores_to_label(scores, binary=None):
+def compute_score(pos: float, neg: float, neu: float, mode: str) -> float:
+    if mode == "debias": # original
+        denom = max(1e-6, 1.0 - neu)
+        return (pos - neg) / denom
+    elif mode == "raw": # conservative
+        return pos - neg
+    elif mode == "logit": # optional slightly squashed
+        # difference of logits, then tanh to clamp to [-1,1]
+        lp = math.log(max(1e-9, pos)) - math.log(max(1e-9, 1-pos))
+        ln = math.log(max(1e-9, neg)) - math.log(max(1e-9, 1-neg))
+        return math.tanh((lp - ln) / 4.0)
+    else:
+        return pos - neg
+
+def scores_to_label(scores, mode: str, binary_hint: bool | None, min_conf: float, neutral_zone: float):
     m = {s["label"].lower(): float(s["score"]) for s in scores}
     keys = set(m.keys())
 
     neg = neu = pos = 0.0
     detected_binary = False
 
-    if {"negative", "positive"} & keys:
-        # Named labels present; detect if neutral exists
-        neg = m.get("negative", 0.0)
-        pos = m.get("positive", 0.0)
-        neu = m.get("neutral", 0.0)
+    if {"negative","positive"} & keys:
+        neg, pos = m.get("negative",0.0), m.get("positive",0.0)
+        neu = m.get("neutral",0.0)
         detected_binary = ("neutral" not in m) or (len(m) == 2)
     elif any(k.startswith("label_") for k in keys):
         neg = m.get("label_0", 0.0)
         if "label_2" in m or len(m) >= 3:
-            neu = m.get("label_1", 0.0)
-            pos = m.get("label_2", 0.0)
-            detected_binary = False
+            neu = m.get("label_1", 0.0); pos = m.get("label_2", 0.0); detected_binary = False
        else:
-            pos = m.get("label_1", 0.0)
-            neu = 0.0
-            detected_binary = True
+            pos = m.get("label_1", 0.0); neu = 0.0; detected_binary = True
     else:
-        for k, v in m.items():
+        for k,v in m.items():
             if "pos" in k: pos = v
             if "neg" in k: neg = v
             if "neu" in k: neu = v
         detected_binary = (neu == 0.0)
 
-    is_binary = detected_binary if binary is None else bool(binary)
+    is_binary = detected_binary if binary_hint is None else bool(binary_hint)
+    if is_binary:
+        neu = 0.0
 
-    if is_binary:
-        score = pos - neg
-        conf = max(pos, neg)
-        label = "positive" if score > 0 else ("negative" if score < 0 else "neutral")
-        return {
-            "label": label,
-            "score": max(-1.0, min(1.0, score)),
-            "confidence": conf,
-            "scores": {"positive": pos, "neutral": 0.0, "negative": neg},
-        }
-    else:
-        denom = max(1e-6, 1.0 - neu)
-        score = (pos - neg) / denom
-        score = max(-1.0, min(1.0, score))
-        conf = max(pos, neg, neu)
-        label = "positive" if pos >= max(neg, neu) else ("negative" if neg >= max(pos, neu) else "neutral")
-        if conf < 0.55:
-            label = "neutral"
-        return {
-            "label": label,
-            "score": score,
-            "confidence": conf,
-            "scores": {"positive": pos, "neutral": neu, "negative": neg},
-        }
+    score = compute_score(pos, neg, neu, mode)
+    # clamp to [-1,1]
+    score = max(-1.0, min(1.0, score))
+
+    conf = max(pos, neg, neu)
+    label = "positive" if score > 0 else ("negative" if score < 0 else "neutral")
+
+    # Optional gating
+    if conf < min_conf or abs(score) < neutral_zone:
+        label = "neutral"
 
+    return {
+        "label": label,
+        "score": score,
+        "confidence": conf,
+        "scores": {"positive": pos, "neutral": neu, "negative": neg},
+    }
 
 @app.post("/predict")
-def predict(payload: Payload):
-    classifier = load_pipeline()
-    if not payload.sentences:
-        return {"model": loaded_model_id, "results": []}
-    outputs = classifier(payload.sentences, top_k=None)
-    binary = (loaded_model_id == FALLBACK_MODEL)
-    results = [scores_to_label(scores, binary=binary) for scores in outputs]
+def predict(
+    payload: Payload,
+    mode: str = Query("raw", pattern="^(raw|debias|logit)$"),
+    min_conf: float = Query(0.60, ge=0.0, le=1.0),
+    neutral_zone: float = Query(0.20, ge=0.0, le=1.0)
+):
+    clf = load_pipeline()
+    texts = payload.sentences or []
+    outs = clf(texts, top_k=None)
+    binary_hint = (loaded_model_id == FALLBACK_MODEL)
+    results = [scores_to_label(s, mode, binary_hint, min_conf, neutral_zone) for s in outs]
     return {"model": loaded_model_id, "results": results}