"""Ingredient NER + health-summary service exposed via FastAPI and Gradio."""

import json
import re

import gradio as gr
from fastapi import FastAPI, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoModelForTokenClassification,
    AutoTokenizer,
    pipeline,
)

# ---------- Load Models ----------
ner_model_name = "sgarbi/bert-fda-nutrition-ner"
ner_tokenizer = AutoTokenizer.from_pretrained(ner_model_name)
ner_model = AutoModelForTokenClassification.from_pretrained(ner_model_name)
ner_pipeline = pipeline(
    "ner", model=ner_model, tokenizer=ner_tokenizer, aggregation_strategy="simple"
)

summary_model_name = "google/flan-t5-base"
summary_tokenizer = AutoTokenizer.from_pretrained(summary_model_name)
summary_model = AutoModelForSeq2SeqLM.from_pretrained(summary_model_name)

# Entities with a score at or below this value are ignored.
NER_CONFIDENCE_THRESHOLD = 0.7

# ---------- Summary post-processing ----------

# Keyword -> (benefits, warnings) used when the model output is unusable.
_FALLBACK_RULES = (
    (
        "beef",
        ("High-quality protein for muscle repair",),
        (
            "Vegans/vegetarians (animal product)",
            "Children under 5 (choking risk)",
            "Gout sufferers (high purines)",
        ),
    ),
    (
        "milk",
        ("Calcium for bone health",),
        (
            "Lactose-intolerant (dairy)",
            "Vegans (animal product)",
            "Infants under 1 (potential allergy)",
        ),
    ),
    ("sugar", ("Quick energy source",), ("Diabetics (high carbs)",)),
)

# Prompt placeholders that indicate the model just echoed the template.
_PLACEHOLDER_MARKERS = ("[1-2", "[1-3")
_BENEFITS_HEADER_RE = re.compile(r"benefits\s*:", re.IGNORECASE)
_AVOID_HEADER_RE = re.compile(r"avoid\s+if\s*:", re.IGNORECASE)
# Any section header the model may emit; "ingredients" is matched only so that
# an echoed "Ingredients:" line terminates the preceding section.
_SECTION_HEADER_RE = re.compile(
    r"\b(benefits|avoid\s+if|ingredients)\s*:", re.IGNORECASE
)
# A bullet marker: a dash at the start of the text or after whitespace, followed
# by whitespace. Hyphenated words ("High-quality") are therefore not split.
_BULLET_SPLIT_RE = re.compile(r"(?:^|\s)-\s+")


def _is_unusable(summary):
    """Return True when the generated summary is empty, templated or unstructured."""
    return (
        len(summary) < 20
        or any(marker in summary for marker in _PLACEHOLDER_MARKERS)
        or re.match(r"^\[.*\]$", summary) is not None
        or "---" in summary
        or not _BENEFITS_HEADER_RE.search(summary)
        or not _AVOID_HEADER_RE.search(summary)
    )


def _fallback_summary(entities_text):
    """Build a rule-based summary from keyword matches in ``entities_text``."""
    text = entities_text.lower()
    benefits, avoid = [], []
    for keyword, rule_benefits, rule_avoid in _FALLBACK_RULES:
        if keyword in text:
            benefits.extend(rule_benefits)
            avoid.extend(rule_avoid)
    # Never emit an empty "- " bullet when no keyword matched.
    benefits_line = "; ".join(benefits) or "No specific benefits identified"
    avoid_line = "; ".join(avoid) or "No specific warnings identified"
    return f"Benefits:\n- {benefits_line}\nAvoid if:\n- {avoid_line}"


def _split_bullets(body):
    """Split one section body into its non-empty bullet items."""
    items = (item.strip(" \t\n;") for item in _BULLET_SPLIT_RE.split(body))
    return [item for item in items if item]


def _parse_sections(summary):
    """Extract (benefits, avoidances) bullet lists from a summary.

    Works whether sections are separated by newlines (fallback text) or all
    on one line (T5 decoding does not produce newlines), because each section
    runs from its header to the next header rather than to end-of-line.
    """
    headers = list(_SECTION_HEADER_RE.finditer(summary))
    benefits, avoidances = [], []
    for index, header in enumerate(headers):
        end = headers[index + 1].start() if index + 1 < len(headers) else len(summary)
        items = _split_bullets(summary[header.end():end])
        name = header.group(1).lower()
        if name == "benefits":
            benefits.extend(items)
        elif name.startswith("avoid"):
            avoidances.extend(items)
    return benefits, avoidances


# ---------- Core Logic ----------
def generate_summary(entities_text):
    """Generate a structured Benefits / Avoid-if summary with Flan-T5.

    Falls back to a keyword-based summary when the model output is too short,
    echoes the prompt template, or lacks either section header.

    Args:
        entities_text: Comma-separated extracted ingredient entities.

    Returns:
        The summary text, stripped of surrounding whitespace.
    """
    prompt = (
        "Analyze these food ingredients. "
        "Output concise bullet points ONLY in this format:\n"
        "Benefits:\n"
        "- [1-2 benefits, e.g., High protein for muscle building and energy]\n"
        "Avoid if:\n"
        "- [1-3 warnings for age/health/cultural/allergies, e.g., Infants "
        "(choking risk), vegans (animal products), nut allergy]\n"
        f"Ingredients: {entities_text}\n"
    )
    inputs = summary_tokenizer(
        prompt, return_tensors="pt", max_length=512, truncation=True
    )
    outputs = summary_model.generate(
        **inputs,
        max_length=250,
        num_beams=4,
        temperature=0.8,
        do_sample=True,
        early_stopping=True,
    )
    summary = summary_tokenizer.decode(outputs[0], skip_special_tokens=True)

    if _is_unusable(summary):
        summary = _fallback_summary(entities_text)

    return summary.strip()


def analyze_ingredients(text: str):
    """Run nutrition NER on ``text`` and attach a benefits/warnings summary.

    Args:
        text: Free-form ingredient text. Non-string or blank input yields an
            ``{"error": ...}`` dict instead of raising.

    Returns:
        A JSON-serializable dict with the input, confident entities, the raw
        summary, and the parsed ``benefits`` / ``avoidances`` lists.
    """
    if not isinstance(text, str) or not text.strip():
        return {"error": "No text provided."}

    # Step 1: Extract entities, keeping only confident predictions.
    ner_results = ner_pipeline(text)
    confident = [
        entity
        for entity in ner_results
        if float(entity["score"]) > NER_CONFIDENCE_THRESHOLD
    ]
    words = [entity["word"].strip() for entity in confident]
    entities_text = ", ".join(words) if words else "No key ingredients found"

    # Step 2: Generate summary.
    summary = generate_summary(entities_text)

    # Step 3: Parse benefits and avoidances out of the summary sections.
    benefits, avoidances = _parse_sections(summary)

    return {
        "input_text": text,
        "extracted_entities": [
            {
                "word": entity["word"].strip(),
                "entity_type": entity["entity_group"],
                # Cast numpy.float32 to a plain float so the payload is JSON-serializable.
                "confidence": round(float(entity["score"]), 3),
            }
            for entity in confident
        ],
        "summary": summary,
        "benefits": benefits,
        "avoidances": avoidances,
    }


# ---------- FastAPI Wrapper ----------
app = FastAPI(title="Ingredient NER API")


@app.post("/api/analyze")
async def analyze_api(request: Request):
    """JSON API endpoint: expects a body like ``{"text": "..."}``."""
    try:
        body = await request.json()
    except (json.JSONDecodeError, UnicodeDecodeError):
        return JSONResponse(
            status_code=400, content={"error": "Request body must be valid JSON."}
        )
    text = body.get("text", "") if isinstance(body, dict) else ""
    # Model inference is blocking; run it off the event loop so other requests
    # (including the mounted Gradio UI) are not stalled.
    result = await run_in_threadpool(analyze_ingredients, text)
    return JSONResponse(content=result)


# ---------- Gradio Interface ----------
iface = gr.Interface(
    fn=analyze_ingredients,
    inputs=gr.Textbox(
        label="Enter Ingredients Text",
        placeholder="e.g., Wheat, milk 1kg, sugar, nuts",
    ),
    outputs=gr.JSON(label="Full Analysis (Entities + Summary)"),
    title="Ingredient NER & Health Analyzer",
    description=(
        "Extracts nutrition entities from food labels and generates a summary "
        "of benefits and health warnings."
    ),
)

# Mounted after the API route is registered; Starlette matches routes in
# registration order, so /api/analyze is not shadowed by the "/" mount.
gr_app = gr.mount_gradio_app(app, iface, path="/")

# ---------- Launch ----------
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)