# app.py
"""Gradio dashboard that turns a small biomarker panel into a Markdown report.

The app loads a causal language model from the Hugging Face Hub, builds a
prompt from the values entered in the UI, generates a report and shows it in
two panels: the text up to the "Tabular Mapping" section on the left, and
everything from that section onward on the right.
"""
import logging
import os
import re

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

logger = logging.getLogger(__name__)

MODEL_ID = "Muhammadidrees/MedicalInsights"

# Upper bound on generated tokens per report.
MAX_NEW_TOKENS = 3000

# -----------------------
# Load tokenizer + model (GPU or CPU)
# -----------------------
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

try:
    # First attempt: library defaults for dtype and loading strategy.
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
except Exception:
    # Second attempt: explicit float32 with reduced peak RAM while loading.
    # The first failure is logged so it is not silently swallowed.
    logger.warning(
        "Default model load failed; retrying with float32 and low_cpu_mem_usage.",
        exc_info=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    )

# Device placement happens here: first CUDA device if available, else CPU.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=0 if torch.cuda.is_available() else -1,
)

# -----------------------
# Helper: robust section splitter
# -----------------------
# Markers that indicate the start of the tabular/insights part of the report.
_SPLIT_MARKERS = (
    "5. Tabular Mapping",
    "5. Tabular",
    "Tabular Mapping",
    "Tabular & AI Insights",
    "📊 Tabular",
    "## 5",
)

# Tried in order, only when none of the primary markers is present.
_FALLBACK_MARKERS = ("Enhanced AI Insights", "Enhanced AI")


def split_report(text):
    """Split model output into a left part and a right part.

    The split point is the earliest occurrence of any primary marker; if none
    is found, the first fallback marker that occurs is used. When only
    non-letter characters (Markdown decoration such as ``#``, ``*``, ``>``,
    list numbering or whitespace) precede the marker on its line, the split
    is moved back to the start of that line so heading markup is not torn
    between the two panels.

    Args:
        text: Raw generated report text.

    Returns:
        A ``(left, right)`` tuple of stripped strings. If no marker is found,
        ``left`` is the whole stripped text and ``right`` is ``""``.
    """
    text = text.strip()

    hits = [pos for pos in (text.find(m) for m in _SPLIT_MARKERS) if pos != -1]
    if hits:
        idx = min(hits)
    else:
        idx = None
        for marker in _FALLBACK_MARKERS:
            pos = text.find(marker)
            if pos != -1:
                idx = pos
                break

    if idx is None:
        # Couldn't find a split marker -> put everything in the left panel.
        return text, ""

    # Snap to the beginning of the line when nothing but decoration or
    # numbering sits between the line start and the marker
    # (e.g. "### 5. Tabular Mapping" or "**5. Tabular Mapping**").
    line_start = text.rfind("\n", 0, idx) + 1
    if not any(ch.isalpha() for ch in text[line_start:idx]):
        idx = line_start

    return text[:idx].strip(), text[idx:].strip()


# -----------------------
# The analyze function
# -----------------------
# System prompt (enforces the six report headings). It is static, so it is
# built once at import time instead of on every request.
SYSTEM_PROMPT = (
    "You are a professional AI Medical Assistant.\n"
    "You are analyzing patient demographics (age, height, weight) and the Levine biomarker panel.\n\n"
    "STRICT RULES:\n"
    "- Use ONLY the 9 biomarkers (Albumin, Creatinine, Glucose, CRP, MCV, RDW, ALP, WBC, Lymphocytes) + Age/Height/Weight.\n"
    "- Do NOT use or invent other labs (cholesterol, ferritin, vitamin D, etc.).\n"
    "- If data missing: explicitly write 'Not available from current biomarkers.'\n"
    "- Always cover ALL SIX SECTIONS with detail:\n"
    " 1. Executive Summary\n"
    " 2. System-Specific Analysis\n"
    " 3. Personalized Action Plan\n"
    " 4. Interaction Alerts\n"
    " 5. Tabular Mapping\n"
    " 6. Enhanced AI Insights & Longitudinal Risk\n"
    "- Use Markdown formatting for readability.\n"
    "- Keep tone professional, clear, and client-friendly.\n"
    "- Tables must be clean Markdown tables.\n"
)


def analyze(
    albumin, creatinine, glucose, crp, mcv, rdw, alp, wbc, lymph,
    age, gender, height, weight
):
    """Generate the two-panel report for one set of inputs.

    Args:
        albumin, creatinine, glucose, crp, mcv, rdw, alp, wbc, lymph:
            Biomarker values, interpolated into the prompt as given.
        age, gender: Demographics, interpolated into the prompt as given.
        height: Height in centimetres; used for BMI when convertible to float.
        weight: Weight in kilograms; used for BMI when convertible to float.

    Returns:
        A ``(left_markdown, right_markdown)`` tuple for the two output panels.
        If both parts are shorter than 50 characters, a warning message that
        includes the patient input block is returned on the left and ``""``
        on the right.
    """
    # BMI is "N/A" when height/weight cannot be converted or height is not
    # positive. height/weight are rebound on purpose so the prompt shows the
    # converted values.
    try:
        height = float(height)
        weight = float(weight)
        bmi = round(weight / ((height / 100) ** 2), 2) if height > 0 else "N/A"
    except Exception:
        bmi = "N/A"

    # Patient input block
    patient_input = (
        f"Patient Profile:\n"
        f"- Age: {age}\n"
        f"- Gender: {gender}\n"
        f"- Height: {height} cm\n"
        f"- Weight: {weight} kg\n"
        f"- BMI: {bmi}\n\n"
        "Lab Values:\n"
        f"- Albumin: {albumin} g/dL\n"
        f"- Creatinine: {creatinine} mg/dL\n"
        f"- Glucose: {glucose} mg/dL\n"
        f"- CRP: {crp} mg/L\n"
        f"- MCV: {mcv} fL\n"
        f"- RDW: {rdw} %\n"
        f"- ALP: {alp} U/L\n"
        f"- WBC: {wbc} K/uL\n"
        f"- Lymphocytes: {lymph} %\n"
    )

    prompt = SYSTEM_PROMPT + "\n" + patient_input

    # Greedy decoding. Sampling-only knobs (temperature, top_p) are not
    # passed: they have no effect when do_sample=False and only trigger
    # generation-config warnings.
    gen = pipe(
        prompt,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=False,          # deterministic
        repetition_penalty=1.1,   # reduce repetition
        return_full_text=False,
    )

    # Extract text
    generated = gen[0].get("generated_text") or gen[0].get("text") or ""
    generated = generated.strip()

    # Remove possible echoes of the prompt: keep only what follows the last
    # occurrence of each prompt chunk.
    for chunk in (patient_input, SYSTEM_PROMPT):
        needle = chunk.strip()
        if needle in generated:
            generated = generated.split(needle)[-1].strip()

    # Split into panels
    left_md, right_md = split_report(generated)

    # Fallback if the response is effectively empty
    if len(left_md) < 50 and len(right_md) < 50:
        return (
            "⚠️ Model response too short. Please re-run.\n\n**Patient Profile:**\n" + patient_input,
            "",
        )

    return left_md, right_md


# -----------------------
# Build Gradio app
# -----------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🏥 AI Medical Biomarker Dashboard")
    gr.Markdown("Enter lab values and demographics — Report is generated in two panels (Summary & Table/Insights).")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 👤 Demographics")
            age = gr.Number(label="Age", value=45)
            gender = gr.Dropdown(["Male", "Female"], label="Gender", value="Male")
            height = gr.Number(label="Height (cm)", value=174)
            weight = gr.Number(label="Weight (kg)", value=75)

            gr.Markdown("### 🩸 Blood Panel")
            wbc = gr.Number(label="WBC (K/uL)", value=6.5)
            lymph = gr.Number(label="Lymphocytes (%)", value=30)
            mcv = gr.Number(label="MCV (fL)", value=88)
            rdw = gr.Number(label="RDW (%)", value=13)

        with gr.Column(scale=1):
            gr.Markdown("### 🧬 Chemistry Panel")
            albumin = gr.Number(label="Albumin (g/dL)", value=4.2)
            creatinine = gr.Number(label="Creatinine (mg/dL)", value=0.9)
            glucose = gr.Number(label="Glucose (mg/dL)", value=92)
            crp = gr.Number(label="CRP (mg/L)", value=1.0)
            alp = gr.Number(label="ALP (U/L)", value=70)

    analyze_btn = gr.Button("🔬 Generate Report", variant="primary")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 📝 Summary & Action Plan")
            left_output = gr.Markdown(value="Press *Generate Report* to create the analysis.")
        with gr.Column(scale=1):
            gr.Markdown("### 📊 Tabular & AI Insights")
            right_output = gr.Markdown(value="Tabular mapping and enhanced insights will appear here.")

    # Connect button to function; input order must match analyze()'s signature.
    analyze_btn.click(
        fn=analyze,
        inputs=[albumin, creatinine, glucose, crp, mcv, rdw, alp, wbc, lymph, age, gender, height, weight],
        outputs=[left_output, right_output],
    )

# -------------------------
# Launch app with error visibility
# -------------------------
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=int(os.environ.get("PORT", 7860)),
        show_error=True,   # surface errors instead of hiding them
        share=False,       # keep private; set True only for public links
    )