Spaces:
Sleeping
Sleeping
| # app.py | |
| import os | |
| import io | |
| import zipfile | |
| import shutil | |
| import tempfile | |
| import pathlib | |
| import pandas as pd | |
| from PIL import Image | |
| import gradio as gr | |
| from huggingface_hub import hf_hub_download | |
| import autogluon.multimodal as agmm | |
# Hugging Face Hub model repo holding the zipped AutoGluon predictor.
MODEL_REPO_ID = "its-zion-18/sign-image-autogluon-predictor"
# Name of the predictor archive inside MODEL_REPO_ID.
ZIP_FILENAME = "autogluon_image_predictor_dir.zip"
# Local download directory (relative to the current working directory).
CACHE_DIR = pathlib.Path("hf_assets")
# Directory the predictor archive is unpacked into.
EXTRACT_DIR = CACHE_DIR / "predictor_native"
# (width, height) of the preview image shown in the UI, as passed to PIL's resize().
PREVIEW_SIZE = (224, 224)
# Images whose PNG re-encoding is larger than this (20 MiB) are downscaled before inference.
MAX_UPLOAD_BYTES = 20 * 1024 * 1024

# Example images are assumed to live next to this script. Resolve them relative to the
# script's directory (not the CWD) so the app also starts when launched from elsewhere.
APP_DIR = pathlib.Path(__file__).resolve().parent
ex1_path = str(APP_DIR / "IMG_0059.png")
ex2_path = str(APP_DIR / "IMG_0064.png")
ex3_path = str(APP_DIR / "IMG_8689.jpg")
ex1 = Image.open(ex1_path)
ex2 = Image.open(ex2_path)
ex3 = Image.open(ex3_path)
EXAMPLE_IMAGES = [ex1, ex2, ex3]

# Predictor class id -> human-readable label shown in the UI
# (assumes the model's classes are 0 and 1 -- confirm against the trained predictor).
CLASS_LABELS = {0: "Does not have stop sign", 1: "Has stop sign"}
# Download & load predictor
def _download_and_extract_predictor() -> str:
    """Download the zipped AutoGluon predictor from the Hub and unpack it.

    ``ZIP_FILENAME`` is fetched from ``MODEL_REPO_ID`` into ``CACHE_DIR`` and
    extracted into a freshly recreated ``EXTRACT_DIR`` (any previous extraction
    is deleted first).

    Returns:
        Path (as ``str``) of the predictor root directory. This is the single
        top-level folder inside the archive if there is exactly one, otherwise
        ``EXTRACT_DIR`` itself.
    """
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    # `local_dir_use_symlinks` is deprecated (and ignored) by recent
    # huggingface_hub releases, so it is no longer passed.
    local_zip = hf_hub_download(
        repo_id=MODEL_REPO_ID,
        filename=ZIP_FILENAME,
        repo_type="model",
        local_dir=str(CACHE_DIR),
    )

    # Start from an empty directory so files from an older archive cannot linger.
    if EXTRACT_DIR.exists():
        shutil.rmtree(EXTRACT_DIR)
    EXTRACT_DIR.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(local_zip, "r") as zf:
        zf.extractall(str(EXTRACT_DIR))

    # Zips created on macOS often carry a `__MACOSX` metadata folder and/or a
    # `.DS_Store` file. Ignore them so a single real top-level folder is still detected.
    ignored_names = {"__MACOSX", ".DS_Store"}
    contents = [p for p in EXTRACT_DIR.iterdir() if p.name not in ignored_names]
    if len(contents) == 1 and contents[0].is_dir():
        predictor_root = contents[0]
    else:
        predictor_root = EXTRACT_DIR
    return str(predictor_root)
def load_predictor() -> agmm.MultiModalPredictor:
    """Fetch the predictor archive from the Hub and restore it with AutoGluon.

    Returns:
        The ``MultiModalPredictor`` loaded from the extracted predictor directory.
    """
    return agmm.MultiModalPredictor.load(_download_and_extract_predictor())


# Loaded once at import time; the Gradio callback reuses this instance for every request.
PREDICTOR = load_predictor()
# Helpers
def pil_preprocess_preview(pil_img: Image.Image, target_size=PREVIEW_SIZE) -> Image.Image:
    """Build the display preview: an RGB copy of *pil_img* resized to *target_size*.

    ``target_size`` is a ``(width, height)`` pair as PIL's ``resize`` expects.
    The image is stretched to exactly that size with bilinear resampling, so the
    aspect ratio is not preserved.
    """
    rgb_copy = pil_img.convert("RGB")
    resized = rgb_copy.resize(target_size, Image.BILINEAR)
    return resized
def run_predict_binary(predictor, pil_img: Image.Image):
    """Classify a single image with the binary stop-sign predictor.

    The image is written as ``input.png`` into a fresh temporary directory. It is
    then passed to ``predictor.predict_proba`` via a one-row DataFrame whose
    ``image`` column holds that file path.

    Args:
        predictor: Object exposing ``predict_proba(DataFrame)``. It returns a
            DataFrame with one probability column per class. The columns are
            presumably labelled 0 and 1 to match ``CLASS_LABELS`` -- confirm
            against the trained predictor.
        pil_img: Image to classify.

    Returns:
        Tuple ``(pred_label, prob_dict)``. ``pred_label`` is the ``CLASS_LABELS``
        text of the most probable class. ``prob_dict`` maps each class's label
        text to its probability.
    """
    tmpd = pathlib.Path(tempfile.mkdtemp())
    try:
        tmp_path = tmpd / "input.png"
        pil_img.save(tmp_path)
        input_df = pd.DataFrame({"image": [str(tmp_path)]})
        probs_df = predictor.predict_proba(input_df)
    finally:
        # Best-effort cleanup that also runs when saving or prediction raises,
        # so a failed request no longer leaks its temporary directory.
        shutil.rmtree(tmpd, ignore_errors=True)

    row = probs_df.iloc[0]
    # Map to {label string: probability}, keyed by column *label* rather than
    # position. int(...) accepts integer and numeric-string class labels alike,
    # and avoids pandas' deprecated positional fallback for Series[int].
    prob_dict = {CLASS_LABELS[int(label)]: float(prob) for label, prob in row.items()}
    # Pick the class with the highest probability.
    pred_label = CLASS_LABELS[int(row.idxmax())]
    return pred_label, prob_dict
# Gradio callback
def infer_and_display(image: Image.Image):
    """Gradio click handler: build the preview and classify the given image.

    Returns a 4-tuple in the order of the UI outputs: (image, RGB preview
    resized to ``PREVIEW_SIZE``, ``"Prediction: <label>"`` text,
    ``{label: probability}`` dict). With no image it returns
    ``(None, None, "No image provided.", {})``.

    If the image re-encoded as PNG exceeds ``MAX_UPLOAD_BYTES``, it is first
    shrunk in place (aspect ratio kept) to fit within 1024x1024 pixels.
    """
    if image is None:
        return None, None, "No image provided.", {}

    # Use the PNG-encoded byte size as the measure of how heavy the image is.
    png_buffer = io.BytesIO()
    image.save(png_buffer, format="PNG")
    encoded_size = len(png_buffer.getvalue())
    if encoded_size > MAX_UPLOAD_BYTES:
        longest_side = 1024
        # thumbnail() only ever shrinks, and it modifies `image` in place.
        image.thumbnail((longest_side, longest_side))

    preview_img = pil_preprocess_preview(image, PREVIEW_SIZE)
    label_text, class_probs = run_predict_binary(PREDICTOR, image)
    return image, preview_img, f"Prediction: {label_text}", class_probs
# Build Gradio interface
with gr.Blocks() as demo:
    # Title: the original "β" was a mis-decoded UTF-8 em dash.
    gr.Markdown("# Stop Sign Detection — AutoGluon Predictor")
    gr.Markdown(
        "Upload an image or pick one of the examples. "
        "The app shows the original and preprocessed images, and predicts whether the image **has a stop sign**."
    )
    with gr.Row():
        # Left column: input image, trigger button and clickable examples.
        with gr.Column(scale=1):
            image_in = gr.Image(type="pil", label="Upload an image", sources="upload")
            run_btn = gr.Button("Run inference")
            gr.Examples(EXAMPLE_IMAGES, inputs=[image_in], label="Example images", cache_examples=False)
        # Right column: outputs, in the same order infer_and_display returns its values.
        with gr.Column(scale=1):
            gr.Markdown("**Original image**")
            orig_out = gr.Image(type="pil")
            gr.Markdown("**Preprocessed image (preview)**")
            pre_out = gr.Image(type="pil")
            out_text = gr.Textbox(label="Prediction", interactive=False)
            proba_label = gr.Label(label="Class probabilities")
    # Inference only runs when the button is clicked, not on upload.
    run_btn.click(fn=infer_and_display, inputs=[image_in], outputs=[orig_out, pre_out, out_text, proba_label])

if __name__ == "__main__":
    demo.launch()