File size: 4,108 Bytes
e2028f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# app.py
import os
import io
import zipfile
import shutil
import tempfile
import pathlib

import pandas as pd
from PIL import Image
import gradio as gr
from huggingface_hub import hf_hub_download
import autogluon.multimodal as agmm

MODEL_REPO_ID = "its-zion-18/sign-image-autogluon-predictor"
ZIP_FILENAME = "autogluon_image_predictor_dir.zip"
CACHE_DIR = pathlib.Path("hf_assets")  # Hub download target (relative to the working directory)
EXTRACT_DIR = CACHE_DIR / "predictor_native"  # where the predictor zip is unpacked
PREVIEW_SIZE = (224, 224)  # (width, height) of the preview image shown in the UI
MAX_UPLOAD_BYTES = 20 * 1024 * 1024  # 20 MB: inputs whose PNG encoding exceeds this get downscaled

# Resolve the bundled example images next to this script, so the app still
# starts when it is launched from a different working directory.
_APP_DIR = pathlib.Path(__file__).resolve().parent
ex1_path = str(_APP_DIR / 'IMG_0059.png')
ex2_path = str(_APP_DIR / 'IMG_0064.png')
ex3_path = str(_APP_DIR / 'IMG_8689.jpg')


def _load_example_image(path: str) -> Image.Image:
    """Open *path*, decode it fully, and return an independent in-memory copy.

    ``Image.open`` is lazy and keeps the file handle open until the pixels are
    read; copying inside the ``with`` block forces decoding and closes the file
    right away.
    """
    with Image.open(path) as img:
        return img.copy()


ex1 = _load_example_image(ex1_path)
ex2 = _load_example_image(ex2_path)
ex3 = _load_example_image(ex3_path)
EXAMPLE_IMAGES = [ex1, ex2, ex3]

# Maps the predictor's class index to the human-readable label shown in the UI.
CLASS_LABELS = {0: "Does not have stop sign", 1: "Has stop sign"}

# Download & load predictor
def _find_predictor_root(extract_dir: pathlib.Path) -> pathlib.Path:
    """Return the directory inside *extract_dir* that holds the predictor files.

    Archives are often built from a single top-level folder; in that case the
    predictor lives inside that folder. Archive metadata entries that some
    tools add (``__MACOSX`` folders, dot-files such as ``.DS_Store``) are
    ignored so they cannot defeat that detection. If what remains is not
    exactly one directory, *extract_dir* itself is returned.
    """
    entries = [
        entry
        for entry in extract_dir.iterdir()
        if entry.name != "__MACOSX" and not entry.name.startswith(".")
    ]
    if len(entries) == 1 and entries[0].is_dir():
        return entries[0]
    return extract_dir


def _download_and_extract_predictor() -> str:
    """Download the zipped predictor from the Hub and unpack it locally.

    The zip is fetched into ``CACHE_DIR`` and extracted into a freshly
    recreated ``EXTRACT_DIR``.

    Returns:
        Path (as ``str``) of the directory to hand to
        ``MultiModalPredictor.load``.
    """
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    local_zip = hf_hub_download(
        repo_id=MODEL_REPO_ID,
        filename=ZIP_FILENAME,
        repo_type="model",
        local_dir=str(CACHE_DIR),
        local_dir_use_symlinks=False,
    )
    # Start from an empty directory so leftovers from a previous run cannot
    # mix with the files of this archive.
    if EXTRACT_DIR.exists():
        shutil.rmtree(EXTRACT_DIR)
    EXTRACT_DIR.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(local_zip, "r") as zf:
        zf.extractall(str(EXTRACT_DIR))
    return str(_find_predictor_root(EXTRACT_DIR))

def load_predictor() -> agmm.MultiModalPredictor:
    """Fetch and unpack the predictor archive, then load it with AutoGluon."""
    return agmm.MultiModalPredictor.load(_download_and_extract_predictor())


# Loaded once at import time; the inference callback reuses this instance.
PREDICTOR = load_predictor()

# Helpers
def pil_preprocess_preview(pil_img: Image.Image, target_size=PREVIEW_SIZE) -> Image.Image:
    """Return an RGB version of *pil_img* resized to exactly *target_size*.

    Uses bilinear resampling; the aspect ratio is not preserved.
    """
    as_rgb = pil_img.convert("RGB")
    return as_rgb.resize(target_size, resample=Image.BILINEAR)

def _class_probability(row: pd.Series, label: int) -> float:
    """Return the probability stored in *row* for class *label*.

    The column is looked up by label (first as the int, then as its string
    form). Only if neither exists does it fall back to position *label*, so
    the lookup never depends on pandas' deprecated positional fallback for
    ``Series[int]``.
    """
    for key in (label, str(label)):
        if key in row.index:
            return float(row[key])
    return float(row.iloc[label])


def run_predict_binary(predictor, pil_img: Image.Image):
    """Classify one image with *predictor*.

    The image is written to a temporary PNG because the predictor is given a
    DataFrame whose ``image`` column holds file paths. The temporary
    directory is removed afterwards on a best-effort basis, including when
    prediction raises.

    Returns:
        ``(pred_label, prob_dict)``: ``prob_dict`` maps each ``CLASS_LABELS``
        string to its probability, and ``pred_label`` is the entry of
        ``prob_dict`` with the highest probability (first one on a tie).
    """
    tmpd = pathlib.Path(tempfile.mkdtemp())
    try:
        tmp_path = tmpd / "input.png"
        pil_img.save(tmp_path)
        input_df = pd.DataFrame({"image": [str(tmp_path)]})
        probs_df = predictor.predict_proba(input_df)
    finally:
        # Best-effort cleanup, now also on the error path.
        shutil.rmtree(tmpd, ignore_errors=True)

    row = probs_df.iloc[0]
    prob_dict = {
        name: _class_probability(row, label) for label, name in CLASS_LABELS.items()
    }
    # Derive the label from the same numbers that are displayed, so the two
    # can never disagree and non-numeric class labels do not break int().
    pred_label = max(prob_dict, key=prob_dict.get)
    return pred_label, prob_dict

# Gradio callback
def infer_and_display(image: Image.Image):
    """Gradio callback: shrink oversized input, build a preview, and classify it.

    Args:
        image: The uploaded or selected image, or ``None`` when nothing was
            provided.

    Returns:
        A 4-tuple ``(original, preview, message, probabilities)`` in the order
        of the click outputs: the (possibly downscaled) input image, the
        ``PREVIEW_SIZE`` RGB preview, a ``"Prediction: ..."`` string, and the
        label -> probability dict from ``run_predict_binary``.
    """
    if image is None:
        return None, None, "No image provided.", {}

    max_side = 1024
    # ``Image.thumbnail`` never enlarges, so it would be a no-op when both
    # sides already fit in max_side; only encode to PNG (to measure its size)
    # when downscaling could actually happen.
    if max(image.size) > max_side:
        encoded = io.BytesIO()
        image.save(encoded, format="PNG")
        if encoded.getbuffer().nbytes > MAX_UPLOAD_BYTES:
            # Downscale a copy so the caller's image object is left untouched.
            image = image.copy()
            image.thumbnail((max_side, max_side))

    preview = pil_preprocess_preview(image, PREVIEW_SIZE)
    pred_label, probs = run_predict_binary(PREDICTOR, image)
    return image, preview, f"Prediction: {pred_label}", probs

# Build Gradio interface
with gr.Blocks() as demo:
    # Page title and a short usage note.
    gr.Markdown("# Stop Sign Detection — AutoGluon Predictor")
    gr.Markdown(
        "Upload an image or pick one of the examples. "
        "The app shows the original and preprocessed images, and predicts whether the image **has a stop sign**."
    )

    # The nesting and order of the `with` blocks below define the page layout.
    with gr.Row():
        # Left column: image input, the trigger button, and example images.
        with gr.Column(scale=1):
            image_in = gr.Image(type="pil", label="Upload an image", sources="upload")
            run_btn = gr.Button("Run inference")
            # No fn is passed, so picking an example only fills image_in;
            # inference still runs through run_btn.
            gr.Examples(EXAMPLE_IMAGES, inputs=[image_in], label="Example images", cache_examples=False)
        # Right column: the original and preview images returned by infer_and_display.
        with gr.Column(scale=1):
            gr.Markdown("**Original image**")
            orig_out = gr.Image(type="pil")
            gr.Markdown("**Preprocessed image (preview)**")
            pre_out = gr.Image(type="pil")

    # Below the row: prediction message and per-class probabilities.
    out_text = gr.Textbox(label="Prediction", interactive=False)
    proba_label = gr.Label(label="Class probabilities")

    # Outputs map positionally onto infer_and_display's 4-tuple return value.
    run_btn.click(fn=infer_and_display, inputs=[image_in], outputs=[orig_out, pre_out, out_text, proba_label])

# Start the Gradio app only when this file is run as a script, not on import.
if __name__ == "__main__":
    demo.launch()