Spaces:
Sleeping
Sleeping
File size: 4,108 Bytes
e2028f3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
# app.py
import os
import io
import zipfile
import shutil
import tempfile
import pathlib
import pandas as pd
from PIL import Image
import gradio as gr
from huggingface_hub import hf_hub_download
import autogluon.multimodal as agmm
# --- Model location on the Hugging Face Hub ---------------------------------
MODEL_REPO_ID = "its-zion-18/sign-image-autogluon-predictor"
ZIP_FILENAME = "autogluon_image_predictor_dir.zip"

# --- Local storage for the downloaded archive and its extracted contents ----
CACHE_DIR = pathlib.Path("hf_assets")
EXTRACT_DIR = CACHE_DIR.joinpath("predictor_native")

# --- Image handling ----------------------------------------------------------
# Size of the preview image produced by pil_preprocess_preview.
PREVIEW_SIZE = (224, 224)
# 20 MB. infer_and_display downsizes an image whose PNG encoding exceeds this.
MAX_UPLOAD_BYTES = 20 * 1024**2

# --- Bundled example images (paths are relative to the working directory) ---
ex1_path = 'IMG_0059.png'
ex2_path = 'IMG_0064.png'
ex3_path = 'IMG_8689.jpg'
EXAMPLE_IMAGES = [Image.open(path) for path in (ex1_path, ex2_path, ex3_path)]
ex1, ex2, ex3 = EXAMPLE_IMAGES

# Display strings keyed by the 0/1 keys that run_predict_binary reads.
CLASS_LABELS = {
    0: "Does not have stop sign",
    1: "Has stop sign",
}
# Download & load predictor
def _download_and_extract_predictor() -> str:
    """Fetch the zipped predictor from the Hugging Face Hub and unpack it.

    The archive ``ZIP_FILENAME`` from ``MODEL_REPO_ID`` is downloaded into
    ``CACHE_DIR`` and extracted into ``EXTRACT_DIR``, which is deleted and
    recreated on every call.

    Returns:
        The predictor directory as a string: the single extracted entry when
        the archive unpacks to exactly one directory, otherwise
        ``EXTRACT_DIR`` itself.
    """
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    archive_path = hf_hub_download(
        repo_id=MODEL_REPO_ID,
        filename=ZIP_FILENAME,
        repo_type="model",
        local_dir=str(CACHE_DIR),
        local_dir_use_symlinks=False,
    )

    # Always extract into an empty directory so files from an earlier
    # extraction cannot linger.
    if EXTRACT_DIR.exists():
        shutil.rmtree(EXTRACT_DIR)
    EXTRACT_DIR.mkdir(parents=True, exist_ok=True)

    with zipfile.ZipFile(archive_path, "r") as archive:
        archive.extractall(str(EXTRACT_DIR))

    # An archive that wraps everything in one folder: return that folder.
    entries = list(EXTRACT_DIR.iterdir())
    if len(entries) == 1 and entries[0].is_dir():
        return str(entries[0])
    return str(EXTRACT_DIR)
def load_predictor() -> agmm.MultiModalPredictor:
    """Download and unpack the predictor archive, then load it.

    Returns:
        The object returned by ``MultiModalPredictor.load`` for the directory
        produced by ``_download_and_extract_predictor``.
    """
    return agmm.MultiModalPredictor.load(_download_and_extract_predictor())
# Loaded once at import time (this triggers the download and extraction);
# infer_and_display passes this shared instance to run_predict_binary.
PREDICTOR = load_predictor()
# Helpers
def pil_preprocess_preview(pil_img: Image.Image, target_size=PREVIEW_SIZE) -> Image.Image:
    """Return an RGB copy of ``pil_img`` resized to ``target_size``.

    Args:
        pil_img: Source image; it is not modified.
        target_size: ``(width, height)`` of the result. Defaults to
            ``PREVIEW_SIZE``.

    Returns:
        A new image converted to RGB and resized with bilinear resampling.
    """
    rgb_img = pil_img.convert("RGB")
    return rgb_img.resize(target_size, Image.BILINEAR)
def run_predict_binary(predictor, pil_img: Image.Image):
    """Classify a single image and return its label and class probabilities.

    The image is written to a temporary ``input.png`` and the predictor is
    given a one-row DataFrame whose ``image`` column holds that file path.
    The temporary directory is removed even when saving or prediction raises.

    Args:
        predictor: Object exposing ``predict_proba(DataFrame)``.
            NOTE(review): the result's first row is indexed with the keys 0
            and 1, so the probability columns are assumed to be labelled
            0 and 1 -- confirm against the trained predictor.
        pil_img: Image to classify.

    Returns:
        A ``(pred_label, prob_dict)`` tuple. ``pred_label`` is the
        ``CLASS_LABELS`` string of the highest-probability class, and
        ``prob_dict`` maps each ``CLASS_LABELS`` string to its probability
        as a float.
    """
    tmpd = pathlib.Path(tempfile.mkdtemp())
    try:
        tmp_path = tmpd / "input.png"
        pil_img.save(tmp_path)
        input_df = pd.DataFrame({"image": [str(tmp_path)]})
        probs_df = predictor.predict_proba(input_df)
    finally:
        # Best-effort cleanup that also runs on failure, so a failed save or
        # prediction no longer leaks the temporary directory.
        shutil.rmtree(tmpd, ignore_errors=True)

    row = probs_df.iloc[0]
    # Map to {label string: probability}
    prob_dict = {
        CLASS_LABELS[0]: float(row[0]),
        CLASS_LABELS[1]: float(row[1]),
    }
    # Pick the class with the higher probability.
    pred_label = CLASS_LABELS[int(row.idxmax())]
    return pred_label, prob_dict
# Gradio callback
def infer_and_display(image: Image.Image):
    """Gradio callback: run the predictor on ``image`` and build the outputs.

    Args:
        image: The uploaded image, or ``None`` when nothing was provided.
            An oversized image is shrunk in place (see below).

    Returns:
        A 4-tuple ``(image, preview, message, probabilities)``. When
        ``image`` is ``None`` this is
        ``(None, None, "No image provided.", {})``.
    """
    if image is None:
        return None, None, "No image provided.", {}

    # Measure the PNG-encoded size; anything above the limit is shrunk in
    # place so that neither side exceeds 1024 pixels.
    encoded = io.BytesIO()
    image.save(encoded, format="PNG")
    png_size = len(encoded.getvalue())
    if png_size > MAX_UPLOAD_BYTES:
        longest_side = 1024
        image.thumbnail((longest_side, longest_side))

    preview = pil_preprocess_preview(image, PREVIEW_SIZE)
    pred_label, probs = run_predict_binary(PREDICTOR, image)
    message = f"Prediction: {pred_label}"
    return image, preview, message, probs
# Build Gradio interface
# NOTE(review): the nesting below was reconstructed from a source whose
# indentation was lost -- confirm the widget placement against the live app.
with gr.Blocks() as demo:
    gr.Markdown("# Stop Sign Detection — AutoGluon Predictor")
    gr.Markdown(
        "Upload an image or pick one of the examples. "
        "The app shows the original and preprocessed images, and predicts whether the image **has a stop sign**."
    )

    with gr.Row():
        # Left column: input image, trigger button and example gallery.
        with gr.Column(scale=1):
            image_in = gr.Image(
                type="pil",
                label="Upload an image",
                sources="upload",
            )
            run_btn = gr.Button("Run inference")
            gr.Examples(
                EXAMPLE_IMAGES,
                inputs=[image_in],
                label="Example images",
                cache_examples=False,
            )

        # Right column: the four outputs of infer_and_display.
        with gr.Column(scale=1):
            gr.Markdown("**Original image**")
            orig_out = gr.Image(type="pil")
            gr.Markdown("**Preprocessed image (preview)**")
            pre_out = gr.Image(type="pil")
            out_text = gr.Textbox(label="Prediction", interactive=False)
            proba_label = gr.Label(label="Class probabilities")

    # Output order matches the tuple returned by infer_and_display.
    run_btn.click(
        fn=infer_and_display,
        inputs=[image_in],
        outputs=[orig_out, pre_out, out_text, proba_label],
    )
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
|