ai-assist-sh commited on
Commit
e2d3e54
·
verified ·
1 Parent(s): 2f88b82

Upload main.py

Browse files
Files changed (1) hide show
  1. main.py +94 -50
main.py CHANGED
@@ -4,7 +4,8 @@ from typing import List, Dict, Tuple
4
 
5
  import gradio as gr
6
 
7
- # Optional imports for email classifier (loaded lazily)
 
8
  try:
9
  import torch
10
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
@@ -16,16 +17,16 @@ except Exception:
16
  # =========================
17
  # Config (env-overridable)
18
  # =========================
19
- EMAIL_CLASSIFIER_ID = os.getenv("EMAIL_CLASSIFIER_ID", "your-username/mini-phish") # swap to your HF model repo later
20
- EMAIL_BACKBONE_ID = os.getenv("EMAIL_BACKBONE_ID", "microsoft/MiniLM-L6-H384-uncased")
21
- THRESHOLD_TAU = float(os.getenv("THRESHOLD_TAU", "0.40"))
22
- MAX_SEQ_LEN = int(os.getenv("MAX_SEQ_LEN", "320"))
23
- SUBJECT_TOKEN_BUDGET= int(os.getenv("SUBJECT_TOKEN_BUDGET", "64"))
24
- FUSION_EMAIL_W = float(os.getenv("FUSION_EMAIL_W", "0.6"))
25
- FUSION_URL_W = float(os.getenv("FUSION_URL_W", "0.4"))
26
- URL_OVERRIDE_HIGH = float(os.getenv("URL_OVERRIDE_HIGH", "0.85"))
27
- URL_OVERRIDE_KW = float(os.getenv("URL_OVERRIDE_KW", "0.70"))
28
- ALLOWLIST_SAFE_CAP = float(os.getenv("ALLOWLIST_SAFE_CAP", "0.15"))
29
 
30
  # =========================
31
  # Simple data classes
@@ -40,21 +41,22 @@ class UrlResult:
40
  class EmailResult:
41
  p_email: float
42
  kw_hits: List[str]
 
43
 
44
  # =========================
45
- # URL extraction & heuristics (replace with your existing pipeline)
46
  # =========================
47
  URL_REGEX = r'(?i)\b((?:https?://|www\.)[^\s<>")]+)'
48
 
49
- SUSPICIOUS_TLDS = {".xyz", ".top", ".click", ".link", ".ru", ".cn", ".country", ".gq", ".ga", ".ml", ".tk"}
 
 
50
  SHORTENERS = {"bit.ly","t.co","tinyurl.com","goo.gl","ow.ly","is.gd","cutt.ly","tiny.one","lnkd.in"}
51
 
52
  def extract_urls(text: str) -> List[str]:
53
  if not text: return []
54
  urls = re.findall(URL_REGEX, text)
55
- # normalize
56
- uniq = []
57
- seen = set()
58
  for u in urls:
59
  u = u.strip().strip(').,;\'"')
60
  if u and u not in seen:
@@ -95,29 +97,41 @@ def score_urls(urls: List[str]) -> List[UrlResult]:
95
  _tokenizer = None
96
  _model = None
97
 
98
- LEXICAL_CUES = [
99
- "verify your account","update your password","immediately","within 24 hours",
100
- "suspended","unusual activity","confirm","login","click","invoice","payment",
101
- "otp","one-time password","unlock","reactivate","restricted","authenticate",
102
- "security alert","urgent","limited time"
 
 
103
  ]
104
 
 
 
 
 
 
 
 
 
 
105
  def load_email_model() -> Tuple[object, object]:
 
 
106
  global _tokenizer, _model
107
  if _tokenizer is not None and _model is not None:
108
  return _tokenizer, _model
109
 
110
  if AutoTokenizer is None or AutoModelForSequenceClassification is None or torch is None:
111
- # environment without torch/transformers (Space will still boot)
112
- return None, None
113
 
114
- # Try the preferred classifier first
115
  model_id = EMAIL_CLASSIFIER_ID
116
  try:
117
  _tokenizer = AutoTokenizer.from_pretrained(model_id)
118
  _model = AutoModelForSequenceClassification.from_pretrained(model_id)
119
  except Exception:
120
- # Fallback: load backbone and attach a tiny random head
121
  try:
122
  _tokenizer = AutoTokenizer.from_pretrained(EMAIL_BACKBONE_ID)
123
  _model = AutoModelForSequenceClassification.from_pretrained(
@@ -127,7 +141,7 @@ def load_email_model() -> Tuple[object, object]:
127
  _tokenizer, _model = None, None
128
  return None, None
129
 
130
- # Dynamic quantization for CPU
131
  try:
132
  _model.eval()
133
  _model.to("cpu")
@@ -146,62 +160,78 @@ def _truncate_for_budget(tokens_subject: List[int], tokens_body: List[int], max_
146
  return subj + body
147
 
148
  def score_email(subject: str, body: str) -> EmailResult:
 
 
149
  text = (subject or "") + "\n" + (body or "")
150
- # lightweight lexical cues for reasons + kw_flag
151
- hits = [c for c in LEXICAL_CUES if c in text.lower()]
 
 
 
152
 
153
  tok, mdl = load_email_model()
154
  if tok is None or mdl is None:
155
- # fallback purely lexical probability
156
- base = 0.15 + 0.1 * len(hits)
157
- return EmailResult(p_email=float(min(base, 0.99)), kw_hits=hits)
 
 
158
 
159
- # tokenize with budget
160
  encoded_subj = tok.encode(subject or "", add_special_tokens=False)
161
  encoded_body = tok.encode(body or "", add_special_tokens=False)
162
- input_ids = _truncate_for_budget(encoded_subj, encoded_body, MAX_SEQ_LEN-2, SUBJECT_TOKEN_BUDGET)
163
  input_ids = [tok.cls_token_id] + input_ids + [tok.sep_token_id]
164
- attn_mask = [1]*len(input_ids)
165
 
166
- import torch
167
  ids = torch.tensor([input_ids], dtype=torch.long)
168
  mask = torch.tensor([attn_mask], dtype=torch.long)
 
169
  with torch.no_grad():
170
  out = mdl(input_ids=ids, attention_mask=mask)
171
  if hasattr(out, "logits"):
172
- logits = out.logits[0].detach().cpu().numpy().tolist()
173
- # assume label 1 = phishing (prob via softmax)
174
  import math
 
175
  exps = [math.exp(x) for x in logits]
176
- p1 = exps[1] / (exps[0] + exps[1])
177
  p_email = float(p1)
178
  else:
179
  p_email = 0.5
180
 
181
- # small calibration nudge from lexical cues (kept light)
182
- p_email = float(min(0.99, max(0.01, p_email + 0.03*len(hits))))
183
- return EmailResult(p_email=p_email, kw_hits=hits)
 
 
184
 
185
  # =========================
186
  # Fusion
187
  # =========================
188
  def fuse(email_res: EmailResult, url_results: List[UrlResult], allowlist_domains: List[str]) -> Dict:
189
  r_url_max = max([u.risk for u in url_results], default=0.0)
190
- kw_flag = 1 if email_res.kw_hits else 0
191
 
192
- # Allowlist check: if any URL host in allowlist
193
  allowlist_hit = False
194
  for u in url_results:
195
  h = url_host(u.url)
196
- if any(h.endswith(d.lower()) for d in allowlist_domains):
197
  allowlist_hit = True
198
  break
199
 
 
200
  r_total = FUSION_EMAIL_W * email_res.p_email + FUSION_URL_W * r_url_max
201
 
 
 
202
  if (r_url_max >= URL_OVERRIDE_HIGH) or (kw_flag and r_url_max >= URL_OVERRIDE_KW):
203
  r_total = max(r_total, 0.90)
204
 
 
 
 
 
 
205
  if allowlist_hit:
206
  r_total = min(r_total, ALLOWLIST_SAFE_CAP)
207
 
@@ -211,6 +241,8 @@ def fuse(email_res: EmailResult, url_results: List[UrlResult], allowlist_domains
211
  "R_url_max": round(r_url_max, 3),
212
  "R_total": round(r_total, 3),
213
  "kw_hits": email_res.kw_hits,
 
 
214
  "allowlist_hit": allowlist_hit,
215
  "verdict": verdict
216
  }
@@ -219,17 +251,19 @@ def fuse(email_res: EmailResult, url_results: List[UrlResult], allowlist_domains
219
  # Gradio UI
220
  # =========================
221
  with gr.Blocks(title="PhishingMail-Lab") as demo:
222
- gr.Markdown("# 🧪 PhishingMail‑Lab\nFree‑tier friendly POC with email+URL fusion")
223
 
224
  with gr.Row():
225
  subject = gr.Textbox(label="Subject", placeholder="Subject: Important account update")
226
- body = gr.Textbox(label="Email Body (paste text or HTML)", lines=10, placeholder="Paste the email content here...")
227
  with gr.Row():
228
- allowlist = gr.Textbox(label="Allowlist domains (comma-separated)", placeholder="microsoft.com, amazon.com")
229
  tau = gr.Slider(0, 1, value=THRESHOLD_TAU, step=0.01, label="Decision Threshold τ")
230
  analyze_btn = gr.Button("Analyze")
231
 
232
  verdict = gr.Label(label="Verdict")
 
 
233
  fusion_json = gr.JSON(label="Fusion & Flags")
234
  url_table = gr.Dataframe(headers=["URL","Risk","Reasons"], label="Per‑URL risk (heuristics demo)", interactive=False)
235
 
@@ -237,17 +271,27 @@ with gr.Blocks(title="PhishingMail-Lab") as demo:
237
  global THRESHOLD_TAU
238
  THRESHOLD_TAU = float(tau_val)
239
 
240
- urls = list(dict.fromkeys(extract_urls((subject_text or "") + "\n" + (body_text or "")))) # uniq while preserving order
 
241
  url_results = score_urls(urls)
242
  allow_domains = [d.strip().lower() for d in (allowlist_text or "").split(",") if d.strip()]
243
 
244
  email_res = score_email(subject_text or "", body_text or "")
245
  fused = fuse(email_res, url_results, allow_domains)
246
 
 
 
 
 
 
 
 
 
 
247
  rows = [[u.url, round(u.risk,3), ", ".join(u.reasons)] for u in url_results]
248
- return fused["verdict"], fused, rows
249
 
250
- analyze_btn.click(run, [subject, body, allowlist, tau], [verdict, fusion_json, url_table])
251
 
252
  if __name__ == "__main__":
253
  demo.launch()
 
4
 
5
  import gradio as gr
6
 
7
+ # Optional imports for email classifier (loaded lazily).
8
+ # Space still runs if these aren't available (pure lexical fallback).
9
  try:
10
  import torch
11
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
 
17
  # =========================
18
  # Config (env-overridable)
19
  # =========================
20
+ EMAIL_CLASSIFIER_ID = os.getenv("EMAIL_CLASSIFIER_ID", "your-username/mini-phish") # <- swap to your HF repo when ready
21
+ EMAIL_BACKBONE_ID = os.getenv("EMAIL_BACKBONE_ID", "microsoft/MiniLM-L6-H384-uncased")
22
+ THRESHOLD_TAU = float(os.getenv("THRESHOLD_TAU", "0.40"))
23
+ MAX_SEQ_LEN = int(os.getenv("MAX_SEQ_LEN", "320"))
24
+ SUBJECT_TOKEN_BUDGET = int(os.getenv("SUBJECT_TOKEN_BUDGET", "64"))
25
+ FUSION_EMAIL_W = float(os.getenv("FUSION_EMAIL_W", "0.6"))
26
+ FUSION_URL_W = float(os.getenv("FUSION_URL_W", "0.4"))
27
+ URL_OVERRIDE_HIGH = float(os.getenv("URL_OVERRIDE_HIGH", "0.85"))
28
+ URL_OVERRIDE_KW = float(os.getenv("URL_OVERRIDE_KW", "0.70"))
29
+ ALLOWLIST_SAFE_CAP = float(os.getenv("ALLOWLIST_SAFE_CAP", "0.15"))
30
 
31
  # =========================
32
  # Simple data classes
 
41
  class EmailResult:
42
  p_email: float
43
  kw_hits: List[str]
44
+ strong_hits: List[str] # subset of kw_hits considered strong
45
 
46
  # =========================
47
+ # URL extraction & heuristics (swap with your real URL model when ready)
48
  # =========================
49
  URL_REGEX = r'(?i)\b((?:https?://|www\.)[^\s<>")]+)'
50
 
51
+ SUSPICIOUS_TLDS = {
52
+ ".xyz", ".top", ".click", ".link", ".ru", ".cn", ".country", ".gq", ".ga", ".ml", ".tk"
53
+ }
54
  SHORTENERS = {"bit.ly","t.co","tinyurl.com","goo.gl","ow.ly","is.gd","cutt.ly","tiny.one","lnkd.in"}
55
 
56
  def extract_urls(text: str) -> List[str]:
57
  if not text: return []
58
  urls = re.findall(URL_REGEX, text)
59
+ uniq, seen = [], set()
 
 
60
  for u in urls:
61
  u = u.strip().strip(').,;\'"')
62
  if u and u not in seen:
 
97
  _tokenizer = None
98
  _model = None
99
 
100
+ # Strong vs normal cues (lowercase)
101
+ STRONG_CUES = [
102
+ "otp", "one-time password", "one time password", "cvv", "pin", "pan",
103
+ "password", "bank details", "netbanking", "debit card", "credit card",
104
+ "lottery", "jackpot", "prize", "reward", "winner", "you have won",
105
+ "send otp", "share otp", "confirm otp", "verify otp",
106
+ "account restricted", "reactivate account", "unlock your account"
107
  ]
108
 
109
+ NORMAL_CUES = [
110
+ "verify your account", "update your password", "immediately",
111
+ "within 24 hours", "suspended", "unusual activity", "confirm",
112
+ "login", "click", "invoice", "payment", "security alert",
113
+ "urgent", "limited time"
114
+ ]
115
+
116
+ LEXICAL_CUES = sorted(set(STRONG_CUES + NORMAL_CUES))
117
+
118
  def load_email_model() -> Tuple[object, object]:
119
+ """Try to load EMAIL_CLASSIFIER_ID; on failure, fall back to backbone with small head.
120
+ Apply dynamic int8 quantization for CPU if available."""
121
  global _tokenizer, _model
122
  if _tokenizer is not None and _model is not None:
123
  return _tokenizer, _model
124
 
125
  if AutoTokenizer is None or AutoModelForSequenceClassification is None or torch is None:
126
+ return None, None # environment without torch/transformers
 
127
 
128
+ # Preferred classifier
129
  model_id = EMAIL_CLASSIFIER_ID
130
  try:
131
  _tokenizer = AutoTokenizer.from_pretrained(model_id)
132
  _model = AutoModelForSequenceClassification.from_pretrained(model_id)
133
  except Exception:
134
+ # Fallback: backbone + fresh 2-class head
135
  try:
136
  _tokenizer = AutoTokenizer.from_pretrained(EMAIL_BACKBONE_ID)
137
  _model = AutoModelForSequenceClassification.from_pretrained(
 
141
  _tokenizer, _model = None, None
142
  return None, None
143
 
144
+ # Dynamic quantization (CPU)
145
  try:
146
  _model.eval()
147
  _model.to("cpu")
 
160
  return subj + body
161
 
162
  def score_email(subject: str, body: str) -> EmailResult:
163
+ """Return EmailResult with probability + hit lists.
164
+ Strong cues push higher risk even without a model (email-only scams)."""
165
  text = (subject or "") + "\n" + (body or "")
166
+ low = text.lower()
167
+
168
+ strong_hits = [c for c in STRONG_CUES if c in low]
169
+ normal_hits = [c for c in NORMAL_CUES if c in low]
170
+ all_hits = sorted(set(strong_hits + normal_hits))
171
 
172
  tok, mdl = load_email_model()
173
  if tok is None or mdl is None:
174
+ # Pure lexical fallback (no model available):
175
+ base = 0.10
176
+ p_email = base + 0.18 * len(strong_hits) + 0.07 * len(normal_hits)
177
+ p_email = float(max(0.01, min(0.99, p_email)))
178
+ return EmailResult(p_email=p_email, kw_hits=all_hits, strong_hits=strong_hits)
179
 
180
+ # Model path (MiniLM or your classifier)
181
  encoded_subj = tok.encode(subject or "", add_special_tokens=False)
182
  encoded_body = tok.encode(body or "", add_special_tokens=False)
183
+ input_ids = _truncate_for_budget(encoded_subj, encoded_body, MAX_SEQ_LEN - 2, SUBJECT_TOKEN_BUDGET)
184
  input_ids = [tok.cls_token_id] + input_ids + [tok.sep_token_id]
185
+ attn_mask = [1] * len(input_ids)
186
 
 
187
  ids = torch.tensor([input_ids], dtype=torch.long)
188
  mask = torch.tensor([attn_mask], dtype=torch.long)
189
+
190
  with torch.no_grad():
191
  out = mdl(input_ids=ids, attention_mask=mask)
192
  if hasattr(out, "logits"):
 
 
193
  import math
194
+ logits = out.logits[0].detach().cpu().numpy().tolist()
195
  exps = [math.exp(x) for x in logits]
196
+ p1 = exps[1] / (exps[0] + exps[1]) # assume label 1 = phishing
197
  p_email = float(p1)
198
  else:
199
  p_email = 0.5
200
 
201
+ # Nudge with cues: stronger boost for strong hits
202
+ p_email += 0.10 * len(strong_hits) + 0.03 * len(normal_hits)
203
+ p_email = float(max(0.01, min(0.99, p_email)))
204
+
205
+ return EmailResult(p_email=p_email, kw_hits=all_hits, strong_hits=strong_hits)
206
 
207
  # =========================
208
  # Fusion
209
  # =========================
210
  def fuse(email_res: EmailResult, url_results: List[UrlResult], allowlist_domains: List[str]) -> Dict:
211
  r_url_max = max([u.risk for u in url_results], default=0.0)
212
+ no_urls = (len(url_results) == 0)
213
 
214
+ # Allowlist check: if any URL host in allowlist (only matters when URLs exist)
215
  allowlist_hit = False
216
  for u in url_results:
217
  h = url_host(u.url)
218
+ if any(h.endswith(d.strip().lower()) for d in allowlist_domains if d.strip()):
219
  allowlist_hit = True
220
  break
221
 
222
+ # Base fusion
223
  r_total = FUSION_EMAIL_W * email_res.p_email + FUSION_URL_W * r_url_max
224
 
225
+ # URL-driven overrides
226
+ kw_flag = 1 if email_res.kw_hits else 0
227
  if (r_url_max >= URL_OVERRIDE_HIGH) or (kw_flag and r_url_max >= URL_OVERRIDE_KW):
228
  r_total = max(r_total, 0.90)
229
 
230
+ # Email-only strong-cue override
231
+ if no_urls and len(email_res.strong_hits) > 0:
232
+ r_total = max(r_total, 0.85)
233
+
234
+ # Allowlist cap
235
  if allowlist_hit:
236
  r_total = min(r_total, ALLOWLIST_SAFE_CAP)
237
 
 
241
  "R_url_max": round(r_url_max, 3),
242
  "R_total": round(r_total, 3),
243
  "kw_hits": email_res.kw_hits,
244
+ "strong_hits": email_res.strong_hits,
245
+ "no_urls": no_urls,
246
  "allowlist_hit": allowlist_hit,
247
  "verdict": verdict
248
  }
 
251
  # Gradio UI
252
  # =========================
253
  with gr.Blocks(title="PhishingMail-Lab") as demo:
254
+ gr.Markdown("# 🧪 PhishingMail‑Lab\n**POC** — Free‑tier friendly hybrid (email + URL) with explainable cues.")
255
 
256
  with gr.Row():
257
  subject = gr.Textbox(label="Subject", placeholder="Subject: Important account update")
258
+ body = gr.Textbox(label="Email Body (paste text or HTML)", lines=12, placeholder="Paste the email content here...")
259
  with gr.Row():
260
+ allowlist = gr.Textbox(label="Allowlist domains (comma-separated)", placeholder="microsoft.com, amazon.in")
261
  tau = gr.Slider(0, 1, value=THRESHOLD_TAU, step=0.01, label="Decision Threshold τ")
262
  analyze_btn = gr.Button("Analyze")
263
 
264
  verdict = gr.Label(label="Verdict")
265
+ # NEW: context banner right under verdict
266
+ context_banner = gr.Markdown(visible=False)
267
  fusion_json = gr.JSON(label="Fusion & Flags")
268
  url_table = gr.Dataframe(headers=["URL","Risk","Reasons"], label="Per‑URL risk (heuristics demo)", interactive=False)
269
 
 
271
  global THRESHOLD_TAU
272
  THRESHOLD_TAU = float(tau_val)
273
 
274
+ # Extract URLs from both subject and body (keeps it simple)
275
+ urls = list(dict.fromkeys(extract_urls((subject_text or "") + "\n" + (body_text or "")))) # uniq & ordered
276
  url_results = score_urls(urls)
277
  allow_domains = [d.strip().lower() for d in (allowlist_text or "").split(",") if d.strip()]
278
 
279
  email_res = score_email(subject_text or "", body_text or "")
280
  fused = fuse(email_res, url_results, allow_domains)
281
 
282
+ # Build banner text/visibility
283
+ banners = []
284
+ if fused.get("no_urls"):
285
+ banners.append("⚠️ **No URLs found** — decision based **only on email body**.")
286
+ if fused.get("allowlist_hit"):
287
+ banners.append("🛈 **Allowlist active** — risk **capped** for trusted domain.")
288
+ banner_text = "<br>".join(banners) if banners else ""
289
+ banner_visible = bool(banners)
290
+
291
  rows = [[u.url, round(u.risk,3), ", ".join(u.reasons)] for u in url_results]
292
+ return fused["verdict"], gr.update(value=banner_text, visible=banner_visible), fused, rows
293
 
294
+ analyze_btn.click(run, [subject, body, allowlist, tau], [verdict, context_banner, fusion_json, url_table])
295
 
296
  if __name__ == "__main__":
297
  demo.launch()