ai-assist-sh committed on
Commit
a380f06
·
verified ·
1 Parent(s): 9fe8482

Upload 3 files

Browse files
Files changed (3) hide show
  1. README.md +33 -12
  2. main.py +254 -335
  3. requirements.txt +7 -8
README.md CHANGED
@@ -1,12 +1,33 @@
1
- ---
2
- title: PhishingMail — Forensics
3
- emoji: 🛡️
4
- colorFrom: red
5
- colorTo: yellow
6
- sdk: gradio
7
- sdk_version: "4.44.1"
8
- app_file: main.py
9
- pinned: false
10
- ---
11
-
12
- Phishing link analysis with on-screen forensics (tokens, logits, [CLS]) and JSON export.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: PhishingMail-Lab
3
+ emoji: 🧪
4
+ colorFrom: gray
5
+ colorTo: blue
6
+ sdk: gradio
7
+ app_file: main.py
8
+ python_version: "3.10"
9
+ pinned: false
10
+ ---
11
+
12
+ # PhishingMail‑Lab (POC)
13
+ A lightweight **POC** Space that extends your original project with **email+URL fusion** while staying Hugging Face free‑tier friendly.
14
+
15
+ ## What’s inside
16
+ - Gradio UI
17
+ - URL extraction + heuristic risk (demo)
18
+ - Email classifier with **fallback loader** (MiniLM backbone if your HF checkpoint is missing)
19
+ - Fusion & overrides (weights and τ are configurable)
20
+
21
+ ## Configure
22
+ Set these in **Settings → Variables & secrets**:
23
+ - `EMAIL_CLASSIFIER_ID` → your fine‑tuned MiniLM classifier on HF (e.g. `your-username/mini-phish`)
24
+ - `EMAIL_BACKBONE_ID` → defaults to `microsoft/MiniLM-L6-H384-uncased`
25
+ - `THRESHOLD_TAU` → default `0.40`
26
+
27
+ ## Run locally
28
+ ```bash
29
+ pip install -r requirements.txt
30
+ python main.py
31
+ ```
32
+
33
+ Replace the heuristic URL scoring with your existing URL model + fusion logic when ready.
main.py CHANGED
@@ -1,335 +1,254 @@
1
- import os, re, time, json, urllib.parse
2
- import gradio as gr
3
- import torch
4
- import torch.nn.functional as F
5
-
6
- # Optional robust domain parsing; code falls back if missing.
7
- try:
8
- import tldextract
9
- except Exception:
10
- tldextract = None
11
-
12
- os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
13
-
14
- URL_MODEL_ID = "CrabInHoney/urlbert-tiny-v4-malicious-url-classifier"
15
-
16
- # Force readable labels regardless of model config
17
- ID2LABEL = {0: "benign", 1: "defacement", 2: "malware", 3: "phishing"}
18
-
19
- URL_RE = re.compile(r"""(?xi)\b(?:https?://|www\.)[^\s<>"'()]+""")
20
-
21
- KEYWORDS = {
22
- "phish","login","verify","account","secure","update","bank","wallet",
23
- "password","invoice","pay","reset","support","unlock","confirm"
24
- }
25
- SUSPICIOUS_TLDS = {
26
- "zip","mov","lol","xyz","top","country","link","click","cam","help",
27
- "gq","cf","tk","work","rest","monster","quest","live","io","ly"
28
- }
29
- URL_SHORTENERS = {
30
- "bit.ly","tinyurl.com","t.co","goo.gl","is.gd","buff.ly","ow.ly","rebrand.ly","cutt.ly"
31
- }
32
-
33
- _tok = None
34
- _mdl = None
35
-
36
- # ---------- utils ----------
37
- def _extract_urls(text: str):
38
- raw = [m.group(0).strip() for m in URL_RE.finditer(text or "")]
39
- cleaned = []
40
- for u in raw:
41
- u = u.rstrip(").,;:!?•]}>\"'")
42
- cleaned.append(u)
43
- return sorted(set(cleaned))
44
-
45
- def _load_model():
46
- global _tok, _mdl
47
- if _tok is not None and _mdl is not None:
48
- return _tok, _mdl
49
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
50
- _tok = AutoTokenizer.from_pretrained(URL_MODEL_ID)
51
- _mdl = AutoModelForSequenceClassification.from_pretrained(URL_MODEL_ID)
52
- _mdl.eval()
53
- return _tok, _mdl
54
-
55
- def _softmax(logits: torch.Tensor):
56
- return F.softmax(logits, dim=-1).tolist()
57
-
58
- def _results_table(rows):
59
- lines = [
60
- "| URL | Model | Model Prob (%) | Heuristic | Fused Risk | Decision | Reasons |",
61
- "|---|---|---:|---:|---:|:--:|---|",
62
- ]
63
- for r in rows:
64
- u, lbl, pct, h, fused, decision, reasons = r
65
- lines.append(
66
- f"| `{u}` | **{lbl}** | {pct:.2f} | {h:.2f} | {fused:.2f} | {decision} | {reasons} |"
67
- )
68
- return "\n".join(lines)
69
-
70
- def _forensic_block(url, token_ids, tokens, scores_sorted, cls_vec, elapsed_s, truncated):
71
- toks_prev = ", ".join(tokens[:64]) + (" …" if len(tokens) > 64 else "")
72
- ids_prev = ", ".join(map(str, token_ids[:64])) + (" …" if len(token_ids) > 64 else "")
73
- cls_dim = len(cls_vec)
74
- cls_prev = ", ".join(f"{v:.4f}" for v in cls_vec[:16]) + (" …" if cls_dim > 16 else "")
75
- l2 = (sum(v*v for v in cls_vec)) ** 0.5
76
- md = []
77
- md.append(f"### 🔍 Forensics for `{url}`\n")
78
- md.append(f"- tokens: **{len(tokens)}** • truncated: **{'yes' if truncated else 'no'}**")
79
- md.append(f"- inference time: **{elapsed_s:.2f}s**\n")
80
- md.append("**Top-k scores**")
81
- md.append("| Class | Prob (%) | Logit |\n|---|---:|---:|")
82
- for s in scores_sorted:
83
- md.append(f"| **{s['label']}** | {s['prob']*100:.2f} | {s['logit']:.3f} |")
84
- md.append("\n**Token IDs (preview)**")
85
- md.append("```txt\n" + ids_prev + "\n```")
86
- md.append("**Tokens (preview)**")
87
- md.append("```txt\n" + toks_prev + "\n```")
88
- md.append("**[CLS] embedding (preview)**")
89
- md.append(f"`dim={cls_dim}`, `L2={l2:.4f}`")
90
- md.append("```txt\n" + cls_prev + "\n```")
91
- return "\n".join(md)
92
-
93
- # ---------- heuristics ----------
94
- def _safe_parse(url: str):
95
- if not re.match(r"^https?://", url, re.I):
96
- url = "http://" + url
97
- return urllib.parse.urlparse(url)
98
-
99
- def _split_reg_domain(host: str):
100
- parts = host.split(".")
101
- if len(parts) >= 2:
102
- return parts[-2] + "." + parts[-1]
103
- return host
104
-
105
- def _domain_parts(host: str):
106
- if tldextract:
107
- ext = tldextract.extract(host) # subdomain, domain, suffix
108
- regdom = f"{ext.domain}.{ext.suffix}" if ext.domain and ext.suffix else host
109
- sub = ext.subdomain or ""
110
- tld = ext.suffix or ""
111
- core = ext.domain or ""
112
- else:
113
- regdom = _split_reg_domain(host)
114
- tld = regdom.split(".")[-1] if "." in regdom else ""
115
- sub = host[:-len(regdom)].rstrip(".") if host.endswith(regdom) else ""
116
- core = regdom.split(".")[0] if "." in regdom else regdom
117
- return regdom, sub, core, tld
118
-
119
- def heuristic_features(u: str):
120
- feats = {}
121
- try:
122
- p = _safe_parse(u)
123
- feats["host"] = p.hostname or ""
124
- feats["path"] = p.path or "/"
125
- feats["query"] = p.query or ""
126
- regdom, sub, core, tld = _domain_parts(feats["host"])
127
- feats["registered_domain"] = regdom
128
- feats["subdomain"] = sub
129
- feats["tld"] = tld
130
- feats["labels"] = feats["host"].count(".") + (1 if feats["host"] else 0)
131
- feats["has_at"] = "@" in u
132
- feats["has_port"] = bool(p.netloc and ":" in p.netloc.split("@")[-1])
133
- feats["has_punycode"] = "xn--" in feats["host"]
134
- feats["len_url"] = len(u)
135
- feats["hyphen_in_regdom"] = "-" in (core or "")
136
- low_host = feats["host"].lower()
137
- low_path = feats["path"].lower()
138
- feats["kw_in_path"] = int(any(k in low_path for k in KEYWORDS))
139
- feats["kw_in_host"] = int(any(k in low_host for k in KEYWORDS))
140
- feats["kw_in_subdomain_only"] = int(
141
- feats["kw_in_host"] and (core and not any(k in (core.lower()) for k in KEYWORDS))
142
- )
143
- feats["suspicious_tld"] = int((feats["tld"].split(".")[-1] or "") in SUSPICIOUS_TLDS)
144
- feats["is_shortener"] = int(regdom.lower() in URL_SHORTENERS)
145
- alnum = sum(c.isalnum() for c in feats["query"])
146
- feats["query_ratio_alnum"] = (alnum / max(1, len(feats["query"]))) if feats["query"] else 0.0
147
- feats["parse_error"] = False
148
- except Exception:
149
- feats = {"parse_error": True}
150
- return feats
151
-
152
- def heuristic_score(feats: dict) -> float:
153
- if feats.get("parse_error"):
154
- return 0.80
155
- s = 0.0
156
- s += 0.28 * feats["kw_in_path"]
157
- s += 0.24 * feats["kw_in_subdomain_only"]
158
- s += 0.10 * feats["kw_in_host"]
159
- s += 0.12 * feats["hyphen_in_regdom"]
160
- s += 0.10 * (feats["labels"] >= 4)
161
- s += 0.10 * feats["has_punycode"]
162
- s += 0.12 * feats["suspicious_tld"]
163
- s += 0.10 * feats["is_shortener"]
164
- s += 0.05 * feats["has_at"]
165
- s += 0.05 * feats["has_port"]
166
- s += 0.10 * (feats["len_url"] >= 100)
167
- if feats.get("query") and len(feats.get("query", "")) >= 40 and feats.get("query_ratio_alnum", 0) > 0.9:
168
- s += 0.10
169
- return max(0.0, min(1.0, s))
170
-
171
- def heuristic_reasons(feats: dict) -> str:
172
- if feats.get("parse_error"):
173
- return "parse error"
174
- rs = []
175
- if feats.get("is_shortener"): rs.append("URL shortener")
176
- if feats.get("kw_in_path"): rs.append("keyword in path")
177
- if feats.get("kw_in_subdomain_only"): rs.append("keyword in subdomain")
178
- if feats.get("kw_in_host") and not feats.get("kw_in_subdomain_only"): rs.append("keyword in host")
179
- if feats.get("hyphen_in_regdom"): rs.append("hyphen in registered domain")
180
- if feats.get("labels", 0) >= 4: rs.append("deep subdomain nesting")
181
- if feats.get("has_punycode"): rs.append("punycode host")
182
- if feats.get("suspicious_tld"): rs.append(f"suspicious TLD: {feats.get('tld')}")
183
- if feats.get("has_at"): rs.append("@ in URL")
184
- if feats.get("has_port"): rs.append("explicit port")
185
- if feats.get("len_url", 0) >= 100: rs.append("very long URL") # ✅ fixed
186
- if feats.get("query") and len(feats.get("query", "")) >= 40 and feats.get("query_ratio_alnum", 0) > 0.9:
187
- rs.append("long query blob")
188
- return ", ".join(rs) if rs else "no heuristic triggers"
189
-
190
- def heuristic_hard_flag(feats: dict) -> (bool, str):
191
- if feats.get("parse_error"):
192
- return True, "unparsable URL"
193
- if feats.get("kw_in_subdomain_only") and feats.get("kw_in_path"):
194
- return True, "keyword in subdomain + keyword in path"
195
- if feats.get("is_shortener") and (feats.get("kw_in_host") or feats.get("kw_in_path")):
196
- return True, "URL shortener + keyword"
197
- if feats.get("suspicious_tld") and (feats.get("kw_in_host") or feats.get("kw_in_path")):
198
- return True, "suspicious TLD + keyword"
199
- if feats.get("labels", 0) >= 4 and (feats.get("kw_in_host") or feats.get("kw_in_path")):
200
- return True, "deep subdomain nesting + keyword"
201
- return False, ""
202
-
203
- # ---------- core ----------
204
- def _parse_allowlist(s: str):
205
- items = re.split(r"[,\s]+", (s or "").strip())
206
- return {x.strip().lower() for x in items if x.strip()}
207
-
208
- def analyze(
209
- text: str,
210
- forensic: bool,
211
- show_json: bool,
212
- threshold: float,
213
- allowlist_txt: str,
214
- allowlist_override: bool
215
- ):
216
- """
217
- One Markdown output:
218
- - verdict + table (model, heuristic, fused + decision + reasons)
219
- - optional forensic blocks
220
- - optional raw JSON
221
- """
222
- text = (text or "").strip()
223
- if not text:
224
- return "Paste an email body or a URL."
225
-
226
- urls = [text] if (text.lower().startswith(("http://","https://","www.")) and " " not in text) else _extract_urls(text)
227
- if not urls:
228
- return "No URLs detected in the text."
229
-
230
- allowset = _parse_allowlist(allowlist_txt)
231
-
232
- tok, mdl = _load_model()
233
-
234
- rows = []
235
- forensic_blocks = []
236
- export_data = {"model_id": URL_MODEL_ID, "items": []}
237
- any_unsafe = False
238
-
239
- for u in urls:
240
- # model forward
241
- max_len = min(512, getattr(mdl.config, "max_position_embeddings", 512) or 512)
242
- enc = tok(u, truncation=True, max_length=max_len, return_tensors="pt", return_attention_mask=True)
243
- token_ids = enc["input_ids"][0].tolist()
244
- tokens = tok.convert_ids_to_tokens(enc["input_ids"][0])
245
- truncated = enc["input_ids"].shape[1] >= max_len and len(tokens) >= max_len
246
-
247
- t0 = time.time()
248
- with torch.no_grad():
249
- out = mdl(**enc, output_hidden_states=True)
250
- elapsed = time.time() - t0
251
-
252
- logits = out.logits.squeeze(0)
253
- probs = _softmax(logits)
254
- scores = [{"label": ID2LABEL[i], "prob": float(probs[i]), "logit": float(logits[i])}
255
- for i in range(len(probs))]
256
- scores_sorted = sorted(scores, key=lambda x: x["prob"], reverse=True)
257
- top = scores_sorted[0]
258
-
259
- # heuristics
260
- feats = heuristic_features(u)
261
- regdom = feats.get("registered_domain", "").lower()
262
- h_flag, h_reason = heuristic_hard_flag(feats)
263
- h_score = heuristic_score(feats)
264
- mdl_phish_like = sum(s["prob"] for s in scores_sorted if s["label"] in {"phishing","malware","defacement"})
265
- fused = 0.50 * mdl_phish_like + 0.50 * h_score
266
-
267
- # allowlist override (domain-based)
268
- allow_hit = regdom in allowset if regdom else False
269
- decision = "🛑 UNSAFE"
270
- reasons = (h_reason + (", " if h_reason else "") + heuristic_reasons(feats)).strip(", ")
271
-
272
- if allow_hit and allowlist_override:
273
- decision = "✅ SAFE"
274
- reasons = f"allowlisted domain ({regdom})"
275
- fused = min(fused, 0.01) # clamp down the risk for display
276
- else:
277
- decision = "🛑 UNSAFE" if (h_flag or fused >= float(threshold)) else "✅ SAFE"
278
-
279
- if decision.startswith("🛑"):
280
- any_unsafe = True
281
-
282
- rows.append([u, top["label"], top["prob"]*100.0, h_score, fused, decision, reasons])
283
-
284
- # export + forensics
285
- hidden_states = out.hidden_states
286
- cls_vec = hidden_states[-1][0, 0, :].cpu().tolist()
287
- export_data["items"].append({
288
- "url": u, "token_ids": token_ids, "tokens": tokens, "truncated": truncated,
289
- "logits": [float(x) for x in logits.cpu().tolist()], "probs": [float(p) for p in probs],
290
- "scores_sorted": scores_sorted, "cls_vector": cls_vec, "cls_dim": len(cls_vec),
291
- "elapsed_sec": elapsed, "heuristic": feats, "heuristic_score": h_score,
292
- "fused_risk": fused, "hard_flag": h_flag, "hard_reason": h_reason,
293
- "allowlisted": allow_hit
294
- })
295
-
296
- if forensic:
297
- forensic_blocks.append(
298
- _forensic_block(u, token_ids, tokens, scores_sorted, cls_vec, elapsed, truncated)
299
- )
300
-
301
- verdict = "🔴 **UNSAFE (at least one link flagged)**" if any_unsafe else "🟢 **SAFE (no link over threshold)**"
302
- body = verdict + "\n\n" + _results_table(rows)
303
-
304
- if forensic and forensic_blocks:
305
- body += "\n\n---\n\n" + "\n\n---\n\n".join(forensic_blocks)
306
-
307
- if show_json:
308
- pretty = json.dumps(export_data, ensure_ascii=False, indent=2)
309
- body += "\n\n---\n\n**Raw forensics JSON (copy & save):**\n"
310
- body += "```json\n" + pretty + "\n```"
311
-
312
- return body
313
-
314
- # ---------- UI ----------
315
- demo = gr.Interface(
316
- fn=analyze,
317
- inputs=[
318
- gr.Textbox(lines=10, label="Email or URL", placeholder="Paste a URL or a full email…"),
319
- gr.Checkbox(label="Forensic mode (tokens, logits, [CLS])", value=True),
320
- gr.Checkbox(label="Show raw JSON at the end (copy/paste)", value=False),
321
- gr.Slider(0.0, 1.0, value=0.40, step=0.01, label="Decision threshold (fused risk ≥ threshold → UNSAFE)"),
322
- gr.Textbox(lines=2, label="Allowlist (domains, comma/space/newline separated)",
323
- placeholder="example.com, github.com microsoft.com"),
324
- gr.Checkbox(label="Allowlist overrides (force SAFE if registered domain matches)", value=True),
325
- ],
326
- outputs=gr.Markdown(label="Results"),
327
- title="🛡️ PhishingMail — Model + Heuristics (HF Free CPU)",
328
- description=(
329
- "Extract links, score with a tiny HF URL model and transparent heuristics. "
330
- "Short-circuits for classic phishing patterns. Adjust the threshold, and allowlist trusted domains."
331
- ),
332
- )
333
-
334
- if __name__ == "__main__":
335
- demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
 
1
+
2
+ import os, re, json
3
+ from dataclasses import dataclass
4
+ from typing import List, Dict, Tuple
5
+
6
+ import gradio as gr
7
+
8
+ # Optional imports for email classifier (loaded lazily)
9
+ try:
10
+ import torch
11
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
12
+ except Exception:
13
+ torch = None
14
+ AutoTokenizer = None
15
+ AutoModelForSequenceClassification = None
16
+
17
+ # =========================
18
+ # Config (env-overridable)
19
+ # =========================
20
# Every setting below is read once at import time; an environment variable of
# the same name overrides the default shown here.

# Model identifiers (swap EMAIL_CLASSIFIER_ID to your HF model repo later).
EMAIL_CLASSIFIER_ID = os.environ.get("EMAIL_CLASSIFIER_ID", "your-username/mini-phish")
EMAIL_BACKBONE_ID = os.environ.get("EMAIL_BACKBONE_ID", "microsoft/MiniLM-L6-H384-uncased")

# Decision threshold compared against the fused risk.
THRESHOLD_TAU = float(os.environ.get("THRESHOLD_TAU", "0.40"))

# Token budgets used when building the classifier input.
MAX_SEQ_LEN = int(os.environ.get("MAX_SEQ_LEN", "320"))
SUBJECT_TOKEN_BUDGET = int(os.environ.get("SUBJECT_TOKEN_BUDGET", "64"))

# Weights of the email probability and the maximum URL risk in the fusion.
FUSION_EMAIL_W = float(os.environ.get("FUSION_EMAIL_W", "0.6"))
FUSION_URL_W = float(os.environ.get("FUSION_URL_W", "0.4"))

# URL-risk levels that trigger an override, and the cap applied on allowlist hits.
URL_OVERRIDE_HIGH = float(os.environ.get("URL_OVERRIDE_HIGH", "0.85"))
URL_OVERRIDE_KW = float(os.environ.get("URL_OVERRIDE_KW", "0.70"))
ALLOWLIST_SAFE_CAP = float(os.environ.get("ALLOWLIST_SAFE_CAP", "0.15"))
30
+
31
+ # =========================
32
+ # Simple data classes
33
+ # =========================
34
@dataclass
class UrlResult:
    """Heuristic assessment of one URL found in the message."""

    # The URL string that was scored.
    url: str
    # Risk value assigned to the URL by the scorer.
    risk: float
    # Names of the heuristics that contributed to the risk.
    reasons: List[str]
39
+
40
@dataclass
class EmailResult:
    """Outcome of scoring the email text (subject + body)."""

    # Probability-like score for the email being phishing.
    p_email: float
    # Lexical cue phrases that were found in the text.
    kw_hits: List[str]
44
+
45
+ # =========================
46
+ # URL extraction & heuristics (replace with your existing pipeline)
47
+ # =========================
48
# Matches an http(s):// or www.-prefixed run up to whitespace or one of < > " ).
URL_REGEX = r'(?i)\b((?:https?://|www\.)[^\s<>")]+)'
# Compiled once at import time instead of on every extract_urls() call.
_URL_PATTERN = re.compile(URL_REGEX)
# Sentence punctuation that commonly ends up glued to the end of a URL in prose.
_URL_TRAILING_PUNCT = ").,;:!?'\""

SUSPICIOUS_TLDS = {".xyz", ".top", ".click", ".link", ".ru", ".cn", ".country", ".gq", ".ga", ".ml", ".tk"}
SHORTENERS = {"bit.ly", "t.co", "tinyurl.com", "goo.gl", "ow.ly", "is.gd", "cutt.ly", "tiny.one", "lnkd.in"}


def extract_urls(text: str) -> List[str]:
    """Return the distinct URLs found in *text*, in order of first appearance.

    A URL is any run matched by ``URL_REGEX``. Surrounding whitespace and
    trailing sentence punctuation (``).,;:!?'"``) are removed before
    de-duplication, so ``"http://a.com!"`` and ``"http://a.com"`` count as the
    same URL.

    Args:
        text: Arbitrary text (may be empty or ``None``).

    Returns:
        A list of cleaned, unique URL strings; empty when nothing matches.
    """
    if not text:
        return []
    cleaned = (
        match.strip().rstrip(_URL_TRAILING_PUNCT)
        for match in _URL_PATTERN.findall(text)
    )
    # dict.fromkeys keeps first-seen order while dropping duplicates.
    return [url for url in dict.fromkeys(cleaned) if url]
65
+
66
def url_host(url: str) -> str:
    """Return the lower-cased host name of *url*.

    Userinfo (``user@``), the port, and any path, query or fragment are
    removed, so ``http://paypal.com@evil.xyz:8080?x=1`` yields ``evil.xyz``.
    Scheme-less inputs such as ``www.example.com/login`` are accepted.

    Args:
        url: The URL string (may be empty).

    Returns:
        The host name, or ``""`` when none can be determined.
    """
    # Local import: urllib is not part of this file's top-level imports.
    from urllib.parse import urlsplit

    candidate = (url or "").strip()
    # urlsplit only recognises a network location after "//", so prepend it
    # when the input carries no scheme.
    if not re.match(r"^[a-z][a-z0-9+.\-]*://", candidate, flags=re.I):
        candidate = "//" + candidate
    try:
        host = urlsplit(candidate).hostname or ""
    except ValueError:
        # urlsplit rejects malformed bracketed hosts; fall back to a plain
        # scheme strip + first path segment.
        host = re.sub(r"^https?://", "", url or "", flags=re.I).split("/")[0]
    # A trailing dot (fully-qualified form) would defeat suffix comparisons.
    return host.lower().rstrip(".")
69
+
70
def score_url_heuristic(url: str) -> UrlResult:
    """Score one URL with a handful of additive heuristics.

    Starts from a baseline of 0.05 and adds a fixed increment for each rule
    that fires; the total is capped at 1.0. The name of every rule that fired
    is recorded in ``reasons``.

    Args:
        url: The URL to score.

    Returns:
        A ``UrlResult`` carrying the URL, its capped risk and the rule names.
    """
    host = url_host(url)
    score = 0.05
    reasons = []

    if len(url) > 140:
        score += 0.15
        reasons.append("very_long_url")
    if "@" in url or "%" in url:
        score += 0.2
        reasons.append("special_chars")
    if host.endswith(tuple(SUSPICIOUS_TLDS)):
        score += 0.35
        reasons.append("suspicious_tld")
    # Match the shortener as the host itself or as a parent domain. A plain
    # substring test would flag e.g. "microsoft.com" because it contains "t.co".
    if any(host == s or host.endswith("." + s) for s in SHORTENERS):
        score += 0.5
        reasons.append("shortener")
    if host.count(".") >= 3:
        score += 0.2
        reasons.append("deep_subdomain")
    # Fires when the URL contains more than 16 ASCII uppercase letters.
    if len(re.findall(r"[A-Z]", url)) > 16:
        score += 0.1
        reasons.append("mixed_case")

    return UrlResult(url=url, risk=min(score, 1.0), reasons=reasons)
89
+
90
def score_urls(urls: List[str]) -> List[UrlResult]:
    """Apply ``score_url_heuristic`` to each URL, preserving input order."""
    results = []
    for candidate in urls:
        results.append(score_url_heuristic(candidate))
    return results
92
+
93
+ # =========================
94
+ # Email classifier with fallback
95
+ # =========================
96
# Module-level cache filled in by load_email_model().
_tokenizer = None
_model = None

# Lower-case phrases looked up (as substrings) in the email text.
LEXICAL_CUES = [
    "verify your account",
    "update your password",
    "immediately",
    "within 24 hours",
    "suspended",
    "unusual activity",
    "confirm",
    "login",
    "click",
    "invoice",
    "payment",
    "otp",
    "one-time password",
    "unlock",
    "reactivate",
    "restricted",
    "authenticate",
    "security alert",
    "urgent",
    "limited time",
]
105
+
106
def load_email_model() -> Tuple[object, object]:
    """Load (once) and return the email classifier as ``(tokenizer, model)``.

    Order of attempts:
      1. ``EMAIL_CLASSIFIER_ID`` (the preferred checkpoint).
      2. ``EMAIL_BACKBONE_ID`` with a fresh 2-label classification head.

    After a successful load the model is switched to eval mode on CPU and,
    best-effort, dynamically quantized. The pair is cached in the module
    globals ``_tokenizer`` / ``_model`` and reused by later calls.

    Returns:
        ``(tokenizer, model)``, or ``(None, None)`` when torch/transformers are
        unavailable or neither checkpoint can be loaded.
    """
    global _tokenizer, _model
    # Local import: logging is not part of this file's top-level imports.
    import logging

    log = logging.getLogger(__name__)

    if _tokenizer is not None and _model is not None:
        return _tokenizer, _model

    if AutoTokenizer is None or AutoModelForSequenceClassification is None or torch is None:
        # environment without torch/transformers (Space will still boot)
        return None, None

    # Build into locals so the globals are only published once both parts exist.
    try:
        tokenizer = AutoTokenizer.from_pretrained(EMAIL_CLASSIFIER_ID)
        model = AutoModelForSequenceClassification.from_pretrained(EMAIL_CLASSIFIER_ID)
    except Exception:
        # The backbone fallback gets a newly created classification head, so
        # make the downgrade visible instead of silent.
        log.warning(
            "Could not load email classifier %r; falling back to backbone %r "
            "with a new 2-label head.",
            EMAIL_CLASSIFIER_ID,
            EMAIL_BACKBONE_ID,
            exc_info=True,
        )
        try:
            tokenizer = AutoTokenizer.from_pretrained(EMAIL_BACKBONE_ID)
            model = AutoModelForSequenceClassification.from_pretrained(
                EMAIL_BACKBONE_ID, num_labels=2, problem_type="single_label_classification"
            )
        except Exception:
            log.warning(
                "Could not load fallback backbone %r; email model disabled.",
                EMAIL_BACKBONE_ID,
                exc_info=True,
            )
            _tokenizer, _model = None, None
            return None, None

    # Dynamic quantization for CPU (best-effort: the unquantized model is kept
    # if any step here fails).
    try:
        model.eval()
        model.to("cpu")
        if hasattr(torch, "quantization"):
            from torch.quantization import quantize_dynamic
            model = quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)  # type: ignore
    except Exception:
        log.warning("Model preparation/quantization failed; continuing.", exc_info=True)

    _tokenizer, _model = tokenizer, model
    return _tokenizer, _model
142
+
143
def _truncate_for_budget(tokens_subject: List[int], tokens_body: List[int], max_len: int, subj_budget: int):
    """Concatenate subject and body token ids within a total length budget.

    The subject gets at most ``subj_budget`` tokens (never more than
    ``max_len``); the body fills whatever room is left.

    Args:
        tokens_subject: Token ids of the subject.
        tokens_body: Token ids of the body.
        max_len: Maximum length of the returned sequence; values below zero
            are treated as zero.
        subj_budget: Maximum number of subject tokens to keep; values below
            zero are treated as zero.

    Returns:
        A new list of at most ``max_len`` token ids, subject first.
    """
    total_cap = max(0, max_len)
    # Clamp the subject share so a budget larger than max_len cannot overflow
    # the total, and a negative value cannot turn into a from-the-end slice.
    subject_cap = max(0, min(subj_budget, total_cap))
    subj = tokens_subject[:subject_cap]
    body = tokens_body[: total_cap - len(subj)]
    return subj + body
148
+
149
def score_email(subject: str, body: str) -> EmailResult:
    """Estimate how likely the email (subject + body) is phishing.

    Lexical cues from ``LEXICAL_CUES`` are always collected. When the model
    from ``load_email_model`` is unavailable, the score is derived from the
    cue count alone. Otherwise the model's probability for label index 1 is
    used and nudged upwards by 0.03 per cue, clamped to ``[0.01, 0.99]``.

    Args:
        subject: Email subject (may be empty).
        body: Email body text (may be empty).

    Returns:
        An ``EmailResult`` with the score and the cues that matched.
    """
    text = (subject or "") + "\n" + (body or "")
    # Lower-case once instead of once per cue.
    lowered = text.lower()
    # lightweight lexical cues for reasons + kw_flag
    hits = [cue for cue in LEXICAL_CUES if cue in lowered]

    tok, mdl = load_email_model()
    if tok is None or mdl is None:
        # fallback purely lexical probability
        base = 0.15 + 0.1 * len(hits)
        return EmailResult(p_email=float(min(base, 0.99)), kw_hits=hits)

    # tokenize with budget (two positions are reserved for the special tokens)
    encoded_subj = tok.encode(subject or "", add_special_tokens=False)
    encoded_body = tok.encode(body or "", add_special_tokens=False)
    input_ids = _truncate_for_budget(encoded_subj, encoded_body, MAX_SEQ_LEN - 2, SUBJECT_TOKEN_BUDGET)
    input_ids = [tok.cls_token_id] + input_ids + [tok.sep_token_id]

    import torch
    ids = torch.tensor([input_ids], dtype=torch.long)
    mask = torch.ones_like(ids)
    with torch.no_grad():
        out = mdl(input_ids=ids, attention_mask=mask)

    logits = getattr(out, "logits", None)
    if logits is None:
        p_email = 0.5
    else:
        row = logits[0].detach().float().cpu().reshape(-1)
        if row.numel() == 1:
            # Single-logit head: interpret the logit through a sigmoid.
            p_email = float(torch.sigmoid(row)[0])
        else:
            # assume label 1 = phishing; torch.softmax is numerically stable
            # for large logits and normalises over every class.
            p_email = float(torch.softmax(row, dim=-1)[1])

    # small calibration nudge from lexical cues (kept light)
    p_email = float(min(0.99, max(0.01, p_email + 0.03 * len(hits))))
    return EmailResult(p_email=p_email, kw_hits=hits)
185
+
186
+ # =========================
187
+ # Fusion
188
+ # =========================
189
def fuse(email_res: EmailResult, url_results: List[UrlResult], allowlist_domains: List[str]) -> Dict:
    """Combine the email score and URL risks into one verdict.

    Steps:
      1. Weighted sum of ``email_res.p_email`` and the highest URL risk.
      2. Override: raise the total to at least 0.90 when the highest URL risk
         reaches ``URL_OVERRIDE_HIGH``, or reaches ``URL_OVERRIDE_KW`` while
         lexical cues were found.
      3. Allowlist: cap the total at ``ALLOWLIST_SAFE_CAP`` when a URL host
         equals an allowlisted domain or is a subdomain of one.
      4. Verdict is ``"UNSAFE"`` when the total is >= ``THRESHOLD_TAU``.

    Args:
        email_res: Result of the email scorer.
        url_results: Per-URL results (may be empty).
        allowlist_domains: Domain names considered trusted.

    Returns:
        A dict with keys ``P_email``, ``R_url_max``, ``R_total``, ``kw_hits``,
        ``allowlist_hit`` and ``verdict``.
    """
    r_url_max = max((u.risk for u in url_results), default=0.0)
    kw_flag = bool(email_res.kw_hits)

    # Normalise entries and drop blanks: an empty entry would otherwise match
    # every host via endswith("").
    trusted = [d.strip().lower().lstrip(".") for d in allowlist_domains]
    trusted = [d for d in trusted if d]

    def _is_trusted(host: str) -> bool:
        # Compare on a label boundary so "evilmicrosoft.com" does not pass for
        # "microsoft.com".
        return any(host == d or host.endswith("." + d) for d in trusted)

    # Allowlist check: if any URL host in allowlist
    # NOTE(review): one allowlisted URL caps the risk for the whole message even
    # if other URLs are not allowlisted - confirm this is intended.
    allowlist_hit = any(_is_trusted(url_host(u.url)) for u in url_results)

    r_total = FUSION_EMAIL_W * email_res.p_email + FUSION_URL_W * r_url_max

    if (r_url_max >= URL_OVERRIDE_HIGH) or (kw_flag and r_url_max >= URL_OVERRIDE_KW):
        r_total = max(r_total, 0.90)

    if allowlist_hit:
        r_total = min(r_total, ALLOWLIST_SAFE_CAP)

    verdict = "UNSAFE" if r_total >= THRESHOLD_TAU else "SAFE"
    return {
        "P_email": round(email_res.p_email, 3),
        "R_url_max": round(r_url_max, 3),
        "R_total": round(r_total, 3),
        "kw_hits": email_res.kw_hits,
        "allowlist_hit": allowlist_hit,
        "verdict": verdict,
    }
218
+
219
+ # =========================
220
+ # Gradio UI
221
+ # =========================
222
# Build the Gradio page: inputs on top, an Analyze button, then three outputs.
with gr.Blocks(title="PhishingMail-Lab") as demo:
    gr.Markdown("# 🧪 PhishingMail‑Lab\nFree‑tier friendly POC with email+URL fusion")

    with gr.Row():
        subject = gr.Textbox(
            label="Subject",
            placeholder="Subject: Important account update",
        )
        body = gr.Textbox(
            label="Email Body (paste text or HTML)",
            lines=10,
            placeholder="Paste the email content here...",
        )
    with gr.Row():
        allowlist = gr.Textbox(
            label="Allowlist domains (comma-separated)",
            placeholder="microsoft.com, amazon.com",
        )
        tau = gr.Slider(0, 1, value=THRESHOLD_TAU, step=0.01, label="Decision Threshold τ")
    analyze_btn = gr.Button("Analyze")

    verdict = gr.Label(label="Verdict")
    fusion_json = gr.JSON(label="Fusion & Flags")
    url_table = gr.Dataframe(
        headers=["URL", "Risk", "Reasons"],
        label="Per‑URL risk (heuristics demo)",
        interactive=False,
    )

    def run(subject_text, body_text, allowlist_text, tau_val):
        """Button callback: score the email and its URLs, return the three outputs."""
        # The slider value replaces the module-level threshold read by fuse().
        global THRESHOLD_TAU
        THRESHOLD_TAU = float(tau_val)

        subj = subject_text or ""
        content = body_text or ""

        # uniq while preserving order
        found = extract_urls(subj + "\n" + content)
        urls = list(dict.fromkeys(found))
        url_results = score_urls(urls)

        allow_domains = []
        for entry in (allowlist_text or "").split(","):
            if entry.strip():
                allow_domains.append(entry.strip().lower())

        email_res = score_email(subj, content)
        fused = fuse(email_res, url_results, allow_domains)

        rows = []
        for item in url_results:
            rows.append([item.url, round(item.risk, 3), ", ".join(item.reasons)])
        return fused["verdict"], fused, rows

    analyze_btn.click(
        run,
        [subject, body, allowlist, tau],
        [verdict, fusion_json, url_table],
    )

if __name__ == "__main__":
    demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,8 +1,7 @@
1
- gradio==4.44.1
2
- transformers==4.55.2
3
-
4
- # optional but recommended; code falls back if missing
5
- tldextract==5.1.2
6
-
7
- --extra-index-url https://download.pytorch.org/whl/cpu
8
- torch==2.4.0+cpu
 
1
+ gradio>=4.19,<5
2
+ transformers>=4.41,<4.45
3
+ torch>=2.2,<2.4
4
+ tokenizers>=0.15,<0.20
5
+ beautifulsoup4>=4.12,<5
6
+ tldextract>=3.6,<4
7
+ emoji>=2.10,<3