"""Gradio app that predicts DVT treatment response with a trained XGBoost model.

The user enters basic lab parameters plus symptom gradings. The gradings are
summed into a "Days 1 symptom score", every feature is min-max normalized with
fixed ranges, and the model prediction is displayed together with a SHAP force
plot explaining it.
"""

import logging
import tempfile

import gradio as gr
import joblib
import matplotlib.pyplot as plt
import numpy as np
import shap

logger = logging.getLogger(__name__)

# Load the trained XGBoost model.
# NOTE(security): joblib.load unpickles the file, which can execute arbitrary
# code -- only point this at a trusted model file.
model_path = r"xgboost.joblib"
loaded_dict = joblib.load(model_path)
model = loaded_dict['model']  # Assumes the model is stored under the dict's 'model' key

# Define feature ranges for min-max normalization.
# The key order is the column order of the input row built in predict() and is
# also used as the SHAP feature names; presumably it matches the training
# column order -- verify against the training script.
feature_ranges = {
    "Age": (8, 96),
    "WBC(10⁹/L)": (3.21, 16.12),
    "APTT(s)": (20, 48.8),
    "TT(s)": (11.4, 39.2),
    "Surgery: T&T": (0, 1),
    "Days 1 symptom score": (4, 38),
}

# Points awarded for each symptom grading offered in the UI. "normal" is shared
# by several symptom categories and always scores 0.
SYMPTOM_SCORES = {
    "normal": 0,
    "slightly hot": 2, "mild fever": 4, "significantly increased": 6,
    "reddish": 2, "lavender": 4, "dark red": 6,
    "Occasionally": 2, "often": 4, "always": 6,
    "Swelling ≥ 1 cm": 2, "Swelling ≥ 2 cm": 4, "Swelling ≥ 4 cm": 6, "Swelling ≥ 5 cm": 8,
    "mild pain": 2, "pain": 4, "significant pain": 6,
    "completely patency again": 0,
    "Re-patency rate ≥ 80%": 2,
    "Re-patency rate ≥ 60%": 4,
    "Re-patency rate ≥ 40%": 6,
    "Re-patency rate ≥ 20%": 8,
    "Re-patency rate < 20%": 10,
}

PARAMETER_NOTES_MD = """
**Notes on Parameters:**

1. For Skin Temperature:
   - Normal: Body surface temperature is basically the same as the contralateral limb, usually between 33°C - 35°C
   - Slightly Hot: Body surface temperature is 0.5°C - 1°C higher than the contralateral limb, i.e. 35.5°C -36°C
   - Mild Fever: Body surface temperature is 1°C - 2°C higher than the contralateral limb, i.e. 36.5°C - 37°C
   - Significantly Increased: Body surface temperature is more than 2°C higher than the contralateral limb, i.e. 37.5°C and above
2. For Pain:
   - Normal: No pain or occasional discomfort
   - Occasionally: Pain attacks are not frequent, 1-2 times a day
   - Often: Pain attacks are relatively frequent, 3-5 times a day
   - Always: Pain persists, occurs almost all day or more than 5 times a day
3. For Homans Sign:
   - Normal: NRS (Numerical Rating Scale) score 0
   - Mild Pain: Slight discomfort when flexing the calf, NRS score 1-3
   - Pain: Obvious pain when flexing the calf, NRS score 4-6
   - Significant Pain: Severe pain when flexing the calf, NRS score 7-10
"""


# Score calculation function
def calculate_total_score(skin_temp, skin_color, pain, swelling, homans_sign, vascular_ultrasound):
    """Return the summed symptom score of the six graded symptoms.

    Each argument is a grading label (as produced by a gr.Radio) or a list of
    labels, in which case only the first label counts and an empty list
    scores 0. Labels not present in SYMPTOM_SCORES score 0.
    """
    def get_score(value):
        if isinstance(value, list):
            return SYMPTOM_SCORES.get(value[0], 0) if value else 0
        return SYMPTOM_SCORES.get(value, 0)

    gradings = (skin_temp, skin_color, pain, swelling, homans_sign, vascular_ultrasound)
    return sum(get_score(grading) for grading in gradings)


# Min-max normalization function
def normalize(value, min_val, max_val):
    """Min-max scale *value* into the [min_val, max_val] -> [0, 1] range.

    Values are not clipped: inputs outside the range map outside [0, 1].
    Raises ZeroDivisionError if min_val == max_val.
    """
    return (value - min_val) / (max_val - min_val)


def _predict_rate(inputs):
    """Return the model output for the single normalized row in *inputs* as an int."""
    try:
        raw = model.predict(inputs)[0]
    except AttributeError:
        # Fallback: `model` may be a plain dict mapping the stringified feature
        # row to a precomputed prediction. float() keeps the key format
        # "[0.1, 0.2, ...]" stable under NumPy 2, whose scalar repr would
        # otherwise be "np.float64(0.1)".
        raw = model[str([float(v) for v in inputs[0]])]
    # Round instead of truncating, so e.g. 79.9 is shown as 80 rather than 79.
    return int(round(float(raw)))


def _render_shap_plot(inputs):
    """Render a SHAP force plot for the single row in *inputs* to a new PNG.

    If the explanation cannot be produced, a placeholder image is written
    instead. Returns the path of the PNG file.
    """
    # One file per request: a fixed "shap_plot.png" could be overwritten by a
    # concurrent request before Gradio has served it.
    with tempfile.NamedTemporaryFile(prefix="shap_plot_", suffix=".png", delete=False) as tmp:
        plot_path = tmp.name
    try:
        explainer = shap.TreeExplainer(model)
        # Round SHAP values to 4 decimal places for a readable plot
        shap_values = np.round(explainer.shap_values(inputs), 4)
        # force_plot(matplotlib=True) draws on a figure it creates itself, so no
        # plt.figure() beforehand -- a pre-created figure was never closed and
        # leaked one figure per request.
        shap.force_plot(
            np.round(explainer.expected_value, 4),
            shap_values[0],
            np.round(inputs[0], 4),
            feature_names=list(feature_ranges.keys()),
            matplotlib=True,
            show=False,
            text_rotation=0,
        )
        plt.tight_layout()
        plt.savefig(plot_path, bbox_inches='tight', dpi=150)
    except Exception:
        logger.exception("SHAP explanation failed; rendering placeholder image")
        plt.close('all')  # discard any half-drawn figure
        plt.figure(figsize=(10, 2))
        plt.text(0.5, 0.5, 'SHAP explanation unavailable',
                 horizontalalignment='center', verticalalignment='center')
        plt.axis('off')
        plt.savefig(plot_path, bbox_inches='tight', dpi=150)
    finally:
        # pyplot keeps global figure state; release everything this call drew.
        plt.close('all')
    return plot_path


# Prediction function
def predict(age, wbc, aptt, tt, surgery_plan, symptom_inputs):
    """Score symptoms, normalize all features, predict, and explain.

    Args:
        age, wbc, aptt, tt: numeric lab/basic values.
        surgery_plan: "Planned" (encoded as 1) or anything else (encoded as 0).
        symptom_inputs: the six symptom gradings, in calculate_total_score's order.

    Returns:
        (prediction text, path to SHAP plot PNG, symptom score text)
    """
    symptom_score = calculate_total_score(*symptom_inputs)
    # Convert surgery plan to binary
    surgery = 1 if surgery_plan == "Planned" else 0

    # Raw values in the same order as feature_ranges' keys.
    raw_values = [age, wbc, aptt, tt, surgery, symptom_score]
    inputs = np.array(
        [normalize(value, *feature_ranges[name]) for value, name in zip(raw_values, feature_ranges)]
    ).reshape(1, -1)

    prediction = _predict_rate(inputs)
    plot_path = _render_shap_plot(inputs)

    # Format the integer prediction as a percentage
    return (
        f"Symptom improvement rate: {prediction}%",
        plot_path,
        f"Days 1 symptom score: {symptom_score}",
    )


# Gradio interface
def interface(age, wbc, aptt, tt, surgery_plan, skin_temp, skin_color, pain, swelling, homans_sign, vascular_ultrasound):
    """Gradio click handler: validate the numeric fields, then delegate to predict()."""
    # gr.Number yields None for an empty field; report it clearly instead of
    # failing with a TypeError inside normalize().
    numeric_fields = (("Age", age), ("WBC(10⁹/L)", wbc), ("APTT(s)", aptt), ("TT(s)", tt))
    missing = [label for label, value in numeric_fields if value is None]
    if missing:
        raise gr.Error("Please fill in: " + ", ".join(missing))

    symptom_inputs = [skin_temp, skin_color, pain, swelling, homans_sign, vascular_ultrasound]
    prediction, plot_path, score_text = predict(age, wbc, aptt, tt, surgery_plan, symptom_inputs)
    return prediction, plot_path, score_text


# Build the Gradio UI
with gr.Blocks() as demo:
    gr.Markdown('# DVT Treatment Response Management')

    # Basic input section
    gr.Markdown("### Basic Parameters")
    with gr.Row():
        age = gr.Number(label="Age")
        wbc = gr.Number(label="WBC(10⁹/L)")
        aptt = gr.Number(label="APTT(s)")
        tt = gr.Number(label="TT(s)")
        surgery = gr.Radio(
            choices=["Planned", "Not Planned"],
            label="Surgery: T&T",
            value="Not Planned",
        )
    # Add note about Surgery: T&T
    gr.Markdown("*Note: Surgery: T&T refers to Thrombolysis or Thrombectomy*")

    # Symptom score calculation section
    gr.Markdown("### Symptom Score Parameters")
    with gr.Row():
        skin_temp = gr.Radio(
            choices=["normal", "slightly hot", "mild fever", "significantly increased"],
            label="Skin Temperature",
            value="normal",
        )
        skin_color = gr.Radio(
            choices=["normal", "reddish", "lavender", "dark red"],
            label="Skin Color",
            value="normal",
        )
    with gr.Row():
        # Pain is graded by frequency and Homans sign by NRS, matching the notes
        # below (the two choice lists were previously swapped; both score 2/4/6,
        # so the total score is unaffected).
        pain = gr.Radio(
            choices=["normal", "Occasionally", "often", "always"],
            label="Pain Level",
            value="normal",
        )
        swelling = gr.Radio(
            choices=["normal", "Swelling ≥ 1 cm", "Swelling ≥ 2 cm", "Swelling ≥ 4 cm", "Swelling ≥ 5 cm"],
            label="Swelling Level",
            value="normal",
        )
    with gr.Row():
        homans_sign = gr.Radio(
            choices=["normal", "mild pain", "pain", "significant pain"],
            label="Homans Sign",
            value="normal",
        )
        vascular_ultrasound = gr.Radio(
            choices=[
                "completely patency again",
                "Re-patency rate ≥ 80%",
                "Re-patency rate ≥ 60%",
                "Re-patency rate ≥ 40%",
                "Re-patency rate ≥ 20%",
                "Re-patency rate < 20%",
            ],
            label="Vascular Ultrasound",
            value="completely patency again",
        )

    with gr.Row():
        # Add detailed notes about parameters
        gr.Markdown(PARAMETER_NOTES_MD)

    with gr.Row():
        predict_button = gr.Button("Predict")

    # Output section
    gr.Markdown("### Prediction Result")
    with gr.Row():
        with gr.Column():
            score_output = gr.Textbox(label="Symptom Score at Admission")
            prediction_output = gr.Textbox(label="Predicted Result")
            shap_plot = gr.Image(label="SHAP Force Plot")

    predict_button.click(
        interface,
        inputs=[
            age, wbc, aptt, tt, surgery,
            skin_temp, skin_color, pain, swelling, homans_sign, vascular_ultrasound,
        ],
        outputs=[prediction_output, shap_plot, score_output],
    )

demo.launch(share=True)