IndraDThor committed on
Commit
25770b6
·
verified ·
1 Parent(s): ecffb11

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -31
app.py CHANGED
@@ -1,23 +1,24 @@
1
  import gradio as gr
 
 
2
  from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification, AutoModelForSeq2SeqLM
3
- import json # For JSON output
4
- import re # For flexible parsing
5
 
6
- # Load NER model for entity extraction
7
  ner_model_name = "sgarbi/bert-fda-nutrition-ner"
8
  ner_tokenizer = AutoTokenizer.from_pretrained(ner_model_name)
9
  ner_model = AutoModelForTokenClassification.from_pretrained(ner_model_name)
10
  ner_pipeline = pipeline("ner", model=ner_model, tokenizer=ner_tokenizer, aggregation_strategy="simple")
11
 
12
- # Load generative model for summarization (Flan-T5: fast, instruction-tuned)
13
  summary_model_name = "google/flan-t5-base"
14
  summary_tokenizer = AutoTokenizer.from_pretrained(summary_model_name)
15
  summary_model = AutoModelForSeq2SeqLM.from_pretrained(summary_model_name)
16
 
17
 
 
18
  def generate_summary(entities_text):
19
  """Generate structured summary using Flan-T5 based on extracted entities"""
20
- # Enhanced prompt: More examples for better zero-shot performance
21
  prompt = f"""
22
  Analyze these food ingredients. Output concise bullet points ONLY in this format:
23
  Benefits:
@@ -29,24 +30,36 @@ def generate_summary(entities_text):
29
  Ingredients: {entities_text}
30
  """
31
  inputs = summary_tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
32
- outputs = summary_model.generate(**inputs, max_length=250, num_beams=4, temperature=0.8, do_sample=True,
33
- early_stopping=True)
 
 
 
 
 
 
34
  summary = summary_tokenizer.decode(outputs[0], skip_special_tokens=True)
35
 
36
- # Fallback if summary is too short/placeholder/garbage (e.g., [1-1-1] patterns)
37
- if len(summary) < 20 or "[1-3" in summary or re.match(r'^\[.*\]$', summary) or '-' * 3 in summary or all(c in ' -123456789[]' for c in summary.replace('\n', '')):
38
- # Rule-based fallback for common entities (demo-proof)
39
  fallback_benefits = []
40
  fallback_avoid = []
41
- if "beef" in entities_text.lower():
 
42
  fallback_benefits.append("High-quality protein for muscle repair")
43
- fallback_avoid.extend(["Vegans/vegetarians (animal product)", "Children under 5 (choking risk)",
44
- "Gout sufferers (high purines)"])
45
- if "milk" in entities_text.lower():
 
 
 
46
  fallback_benefits.append("Calcium for bone health")
47
- fallback_avoid.extend(
48
- ["Lactose-intolerant (dairy)", "Vegans (animal product)", "Infants under 1 (potential allergy)"])
49
- if "sugar" in entities_text.lower():
 
 
 
50
  fallback_benefits.append("Quick energy source")
51
  fallback_avoid.append("Diabetics (high carbs)")
52
  summary = f"Benefits:\n- {'; '.join(fallback_benefits)}\nAvoid if:\n- {'; '.join(fallback_avoid)}"
@@ -54,51 +67,65 @@ def generate_summary(entities_text):
54
  return summary.strip()
55
 
56
 
57
- def analyze_ingredients(text):
58
  """Analyze text for nutrition NER + generate summary"""
59
  if not text.strip():
60
  return {"error": "No text provided."}
61
 
62
- # Step 1: Extract entities with NER (filter noise like quantities)
63
  ner_results = ner_pipeline(text)
64
- entities = [entity["word"].strip() for entity in ner_results if
65
- entity["score"] > 0.7] # Higher threshold for cleaner tags
66
  entities_text = ", ".join(entities) if entities else "No key ingredients found"
67
 
68
  # Step 2: Generate summary
69
  summary = generate_summary(entities_text)
70
 
71
- # Step 3: Parse summary into arrays (flexible regex for bullets)
72
  benefits = re.findall(r'Benefits:\s*-?\s*(.+?)(?=\n-|$)', summary, re.IGNORECASE | re.MULTILINE) or []
73
  avoidances = re.findall(r'Avoid if:\s*-?\s*(.+?)(?=\n-|$)', summary, re.IGNORECASE | re.MULTILINE) or []
74
 
75
- # Format full output as JSON for Flutter
76
  formatted = {
77
  "input_text": text,
78
  "extracted_entities": [
79
  {
80
  "word": entity["word"].strip(),
81
  "entity_type": entity["entity_group"],
82
- "confidence": round(entity["score"], 3)
83
  }
84
- for entity in ner_results if entity["score"] > 0.7
85
  ],
86
  "summary": summary,
87
- "benefits": benefits, # Now populated, e.g., ["High in protein for muscle repair"]
88
- "avoidances": avoidances # e.g., ["Diabetics (high sugar), vegans (beef)"]
89
  }
90
 
91
- return formatted # Gradio auto-JSONifies for API
92
 
93
 
94
- # Gradio interface (web UI + API)
 
 
 
 
 
 
 
 
 
 
 
 
95
  iface = gr.Interface(
96
  fn=analyze_ingredients,
97
  inputs=gr.Textbox(label="Enter Ingredients Text", placeholder="e.g., Wheat, milk 1kg, sugar, nuts"),
98
  outputs=gr.JSON(label="Full Analysis (Entities + Summary)"),
99
  title="Ingredient NER & Health Analyzer",
100
- description="Extracts nutrition entities from food labels and generates a summary of benefits, avoidance warnings (age, health, cultural, allergies)."
101
  )
102
 
 
 
 
103
  if __name__ == "__main__":
104
- iface.launch(server_name="0.0.0.0", server_port=7860) # For HF compatibility
 
 
1
  import gradio as gr
2
+ from fastapi import FastAPI, Request
3
+ from fastapi.responses import JSONResponse
4
  from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification, AutoModelForSeq2SeqLM
5
+ import json
6
+ import re
7
 
8
# ---------- Load Models ----------
# Token-classification checkpoint used for entity extraction. The pipeline's
# results are read by analyze_ingredients via the "word", "score" and
# "entity_group" keys.
ner_model_name = "sgarbi/bert-fda-nutrition-ner"
ner_tokenizer = AutoTokenizer.from_pretrained(ner_model_name)
ner_model = AutoModelForTokenClassification.from_pretrained(ner_model_name)
ner_pipeline = pipeline(
    "ner",
    model=ner_model,
    tokenizer=ner_tokenizer,
    aggregation_strategy="simple",
)

# Seq2seq checkpoint that generate_summary prompts and decodes.
summary_model_name = "google/flan-t5-base"
summary_tokenizer = AutoTokenizer.from_pretrained(summary_model_name)
summary_model = AutoModelForSeq2SeqLM.from_pretrained(summary_model_name)
17
 
18
 
19
+ # ---------- Core Logic ----------
20
  def generate_summary(entities_text):
21
  """Generate structured summary using Flan-T5 based on extracted entities"""
 
22
  prompt = f"""
23
  Analyze these food ingredients. Output concise bullet points ONLY in this format:
24
  Benefits:
 
30
  Ingredients: {entities_text}
31
  """
32
  inputs = summary_tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
33
+ outputs = summary_model.generate(
34
+ **inputs,
35
+ max_length=250,
36
+ num_beams=4,
37
+ temperature=0.8,
38
+ do_sample=True,
39
+ early_stopping=True
40
+ )
41
  summary = summary_tokenizer.decode(outputs[0], skip_special_tokens=True)
42
 
43
+ # Fallback handling for empty or malformed output
44
+ if len(summary) < 20 or "[1-3" in summary or re.match(r'^\[.*\]$', summary) or '-' * 3 in summary:
 
45
  fallback_benefits = []
46
  fallback_avoid = []
47
+ text = entities_text.lower()
48
+ if "beef" in text:
49
  fallback_benefits.append("High-quality protein for muscle repair")
50
+ fallback_avoid.extend([
51
+ "Vegans/vegetarians (animal product)",
52
+ "Children under 5 (choking risk)",
53
+ "Gout sufferers (high purines)"
54
+ ])
55
+ if "milk" in text:
56
  fallback_benefits.append("Calcium for bone health")
57
+ fallback_avoid.extend([
58
+ "Lactose-intolerant (dairy)",
59
+ "Vegans (animal product)",
60
+ "Infants under 1 (potential allergy)"
61
+ ])
62
+ if "sugar" in text:
63
  fallback_benefits.append("Quick energy source")
64
  fallback_avoid.append("Diabetics (high carbs)")
65
  summary = f"Benefits:\n- {'; '.join(fallback_benefits)}\nAvoid if:\n- {'; '.join(fallback_avoid)}"
 
67
  return summary.strip()
68
 
69
 
70
def _extract_section(summary, header, next_header):
    """Return the non-empty bullet items listed under *header* in *summary*.

    The section runs from just after *header* (matched case-insensitively) to
    the start of *next_header*, or to the end of the text when *next_header*
    does not follow. Each line is stripped of surrounding whitespace and of a
    leading bullet marker; lines that end up empty are dropped.
    """
    start = re.search(re.escape(header), summary, re.IGNORECASE)
    if start is None:
        return []
    section = summary[start.end():]
    stop = re.search(re.escape(next_header), section, re.IGNORECASE)
    if stop is not None:
        section = section[:stop.start()]
    cleaned = (line.strip().lstrip("-•*").strip() for line in section.splitlines())
    return [item for item in cleaned if item]


def analyze_ingredients(text: str):
    """Analyze text for nutrition NER + generate summary.

    Args:
        text: Free-form ingredient text, e.g. the contents of a food label.

    Returns:
        A JSON-serialisable dict. When *text* is missing, not a string, or
        blank it is ``{"error": "No text provided."}``. Otherwise it has the
        keys ``input_text``, ``extracted_entities`` (``word``, ``entity_type``
        and ``confidence`` for every entity scoring above 0.7), ``summary``,
        ``benefits`` and ``avoidances``; the last two are lists with one
        string per bullet found under the matching summary header.
    """
    # Callers may pass None or a non-string (e.g. a JSON null forwarded by the
    # API route); treat that like empty input instead of raising on .strip().
    if not isinstance(text, str) or not text.strip():
        return {"error": "No text provided."}

    # Step 1: Extract entities, applying the confidence filter once so the
    # prompt text and the reported entity list cannot drift apart.
    confident = [entity for entity in ner_pipeline(text) if float(entity["score"]) > 0.7]
    entities = [entity["word"].strip() for entity in confident]
    entities_text = ", ".join(entities) if entities else "No key ingredients found"

    # Step 2: Generate summary
    summary = generate_summary(entities_text)

    # Step 3: Build the response. Scores are cast to builtin float so the dict
    # serialises cleanly (the pipeline may hand back numpy.float32 values);
    # with only builtin types inside, no JSON round-trip is needed.
    return {
        "input_text": text,
        "extracted_entities": [
            {
                "word": entity["word"].strip(),
                "entity_type": str(entity["entity_group"]),
                "confidence": round(float(entity["score"]), 3),
            }
            for entity in confident
        ],
        "summary": summary,
        # Each section is cut at the other header, so an empty "Benefits:"
        # section can no longer swallow the "Avoid if:" line, and every bullet
        # is returned rather than only the first one.
        "benefits": _extract_section(summary, "Benefits:", "Avoid if:"),
        "avoidances": _extract_section(summary, "Avoid if:", "Benefits:"),
    }
103
 
104
 
105
+ # ---------- FastAPI Wrapper ----------
106
+ app = FastAPI(title="Ingredient NER API")
107
+
108
@app.post("/api/analyze")
async def analyze_api(request: Request):
    """JSON API endpoint: analyse the ``text`` field of the request body.

    Expects a JSON object such as ``{"text": "..."}`` and responds with the
    dict returned by ``analyze_ingredients``. A body that is not valid JSON,
    or is valid JSON but not an object, gets a 400 response carrying an
    ``error`` message instead of an unhandled exception.
    """
    try:
        body = await request.json()
    except ValueError:
        # json.JSONDecodeError and decoding errors are ValueError subclasses.
        return JSONResponse(
            status_code=400,
            content={"error": "Request body must be valid JSON."},
        )

    if not isinstance(body, dict):
        return JSONResponse(
            status_code=400,
            content={"error": "Request body must be a JSON object with a 'text' field."},
        )

    text = body.get("text", "")
    if not isinstance(text, str):
        # A null or non-string "text" is handled like a missing one, so the
        # analyzer answers with its usual "No text provided." error.
        text = ""

    # NOTE(review): this is a synchronous call inside an async handler, so the
    # analysis runs on the event-loop thread for its whole duration.
    result = analyze_ingredients(text)
    return JSONResponse(content=result)
115
+
116
+
117
# ---------- Gradio Interface ----------
# Web UI around analyze_ingredients: one text box in, the result dict shown
# as JSON out.
iface = gr.Interface(
    fn=analyze_ingredients,
    inputs=gr.Textbox(
        label="Enter Ingredients Text",
        placeholder="e.g., Wheat, milk 1kg, sugar, nuts",
    ),
    outputs=gr.JSON(label="Full Analysis (Entities + Summary)"),
    title="Ingredient NER & Health Analyzer",
    description="Extracts nutrition entities from food labels and generates a summary of benefits and health warnings.",
)

# Mount the Gradio UI on the FastAPI app at the root path.
gr_app = gr.mount_gradio_app(app, iface, path="/")

# ---------- Launch ----------
if __name__ == "__main__":
    # Imported here because uvicorn is only needed when run as a script.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)