CadenShokat commited on
Commit
5ed102d
·
verified ·
1 Parent(s): a3803ab

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -13
app.py CHANGED
@@ -34,26 +34,63 @@ def load_pipeline():
34
  print(f"Failed to load {model_id}: {e}")
35
  raise RuntimeError("No sentiment model could be loaded")
36
 
37
- def scores_to_label(scores, binary=False):
38
- m = {s["label"].lower(): s["score"] for s in scores}
39
- if binary or ("neutral" not in m):
40
- neg, pos = m.get("negative", 0.0), m.get("positive", 0.0)
41
- score = pos - neg # [-1,1]
42
- conf = max(neg, pos)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  label = "positive" if score > 0 else ("negative" if score < 0 else "neutral")
44
- return {"label": label, "score": max(-1.0, min(1.0, score)), "confidence": conf,
45
- "scores": {"positive": pos, "neutral": 0.0, "negative": neg}}
 
 
 
 
46
  else:
47
- neg, neu, pos = m.get("negative", 0.0), m.get("neutral", 0.0), m.get("positive", 0.0)
48
  denom = max(1e-6, 1.0 - neu)
49
- score = (pos - neg) / denom
50
  score = max(-1.0, min(1.0, score))
51
- conf = max(neg, neu, pos)
52
  label = "positive" if pos >= max(neg, neu) else ("negative" if neg >= max(pos, neu) else "neutral")
53
  if conf < 0.55:
54
  label = "neutral"
55
- return {"label": label, "score": score, "confidence": conf,
56
- "scores": {"positive": pos, "neutral": neu, "negative": neg}}
 
 
 
 
 
57
 
58
  @app.post("/predict")
59
  def predict(payload: Payload):
 
34
  print(f"Failed to load {model_id}: {e}")
35
  raise RuntimeError("No sentiment model could be loaded")
36
 
37
+ def scores_to_label(scores, binary=None):
38
+ m = {s["label"].lower(): float(s["score"]) for s in scores}
39
+ keys = set(m.keys())
40
+
41
+ neg = neu = pos = 0.0
42
+ detected_binary = False
43
+
44
+ if {"negative", "positive"} & keys:
45
+ # Named labels present; detect if neutral exists
46
+ neg = m.get("negative", 0.0)
47
+ pos = m.get("positive", 0.0)
48
+ neu = m.get("neutral", 0.0)
49
+ detected_binary = ("neutral" not in m) or (len(m) == 2)
50
+ elif any(k.startswith("label_") for k in keys):
51
+ neg = m.get("label_0", 0.0)
52
+ if "label_2" in m or len(m) >= 3:
53
+ neu = m.get("label_1", 0.0)
54
+ pos = m.get("label_2", 0.0)
55
+ detected_binary = False
56
+ else:
57
+ pos = m.get("label_1", 0.0)
58
+ neu = 0.0
59
+ detected_binary = True
60
+ else:
61
+ for k, v in m.items():
62
+ if "pos" in k: pos = v
63
+ if "neg" in k: neg = v
64
+ if "neu" in k: neu = v
65
+ detected_binary = (neu == 0.0)
66
+
67
+ is_binary = detected_binary if binary is None else bool(binary)
68
+
69
+ if is_binary:
70
+ score = pos - neg
71
+ conf = max(pos, neg)
72
  label = "positive" if score > 0 else ("negative" if score < 0 else "neutral")
73
+ return {
74
+ "label": label,
75
+ "score": max(-1.0, min(1.0, score)),
76
+ "confidence": conf,
77
+ "scores": {"positive": pos, "neutral": 0.0, "negative": neg},
78
+ }
79
  else:
 
80
  denom = max(1e-6, 1.0 - neu)
81
+ score = (pos - neg) / denom
82
  score = max(-1.0, min(1.0, score))
83
+ conf = max(pos, neg, neu)
84
  label = "positive" if pos >= max(neg, neu) else ("negative" if neg >= max(pos, neu) else "neutral")
85
  if conf < 0.55:
86
  label = "neutral"
87
+ return {
88
+ "label": label,
89
+ "score": score,
90
+ "confidence": conf,
91
+ "scores": {"positive": pos, "neutral": neu, "negative": neg},
92
+ }
93
+
94
 
95
  @app.post("/predict")
96
  def predict(payload: Payload):