Spaces:
Runtime error
Runtime error
| import pre_reqs | |
| import cv2 | |
| import numpy as np | |
| import streamlit as st | |
| from mmcls.apis import init_model | |
| from mmcls.apis import inference_model_topk as inference_cls_model | |
| from mmdet.registry import VISUALIZERS | |
| # from mmcls.utils import register_all_modules as register_all_modules_cls | |
| from mmdet.apis import init_detector, inference_detector | |
| from mmdet.utils import register_all_modules as register_all_modules_det | |
| import pandas as pd | |
| from PIL import Image | |
| st.set_page_config(page_title="π₯ An Image Demo", page_icon="π₯", layout='wide') | |
| st.markdown("# π₯ An Image Demo") | |
| st.write( | |
| ":dog: Try uploading an image to get the possible categories, objects." | |
| ) | |
| st.sidebar.header("An Image Demo") | |
| my_upload = st.sidebar.file_uploader("Upload an image", type=["png", "jpg", "jpeg"]) | |
| col1, col2, col3 = st.columns(3) | |
| model_option = st.radio( | |
| "What\'s your inference model", | |
| ('cls', 'det')) | |
| parent_folder = './' | |
| topk = st.slider('Return top-k predictions', 1, 10, 3) | |
| # @st.cache_resource | |
| def _init_model(model_option): | |
| if model_option == 'cls': | |
| # init model | |
| model = init_model(parent_folder + 'configs/resnet/resnet50_8xb32_in1k.py', | |
| 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth') | |
| visualizer = None | |
| elif model_option == 'det': | |
| # register_all_modules_det() | |
| model = init_detector(parent_folder + 'configs/rtmdet/rtmdet-ins_s_8xb32-300e_coco.py', | |
| 'https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet-ins_s_8xb32-300e_coco/rtmdet-ins_s_8xb32-300e_coco_20221121_212604-fdc5d7ec.pth', device='cpu') | |
| visualizer = VISUALIZERS.build(model.cfg.visualizer) | |
| visualizer.dataset_meta = model.dataset_meta | |
| else: | |
| model = None | |
| visualizer = None | |
| return model, visualizer | |
| # @st.cache_data | |
| def _get_image(my_upload=my_upload): | |
| if my_upload is not None: | |
| img_file = my_upload | |
| else: | |
| img_file = parent_folder + "images/zebra.jpg" | |
| return Image.open(img_file).convert('RGB') | |
| # @st.cache_resource | |
| def _inference_model(img, model, visualizer, model_option): | |
| img = np.array(img) | |
| if model_option == 'cls': | |
| return_results = inference_cls_model(model, img, 10) | |
| vis_img = img | |
| elif model_option == 'det': | |
| vis_img = img.copy() | |
| results = inference_detector(model, img) | |
| # import pdb | |
| # pdb.set_trace() | |
| b, h, w = results.pred_instances.masks.shape | |
| vis_img = cv2.resize(vis_img, (w, h)) | |
| visualizer.add_datasample( | |
| name='result', | |
| image=vis_img, | |
| data_sample=results, | |
| draw_gt=False, | |
| show=False) | |
| vis_img = visualizer.get_image() | |
| cls_names = visualizer.dataset_meta['classes'] | |
| return_results = {'scores': results.pred_instances.scores[:10].numpy(), | |
| 'bboxes': results.pred_instances.bboxes[:10].numpy(), | |
| 'labels': [cls_names[x.item()] for x in results.pred_instances.labels[:10]] | |
| } | |
| return return_results, vis_img | |
| def plot_canvas(img, vis_img, results, model_option): | |
| col1.write("Original Image :camera:") | |
| col1.image(img) | |
| col2.write("Visualization:wrench:") | |
| col3.write("Metainfo:wrench:") | |
| if model_option == 'cls': | |
| col2.image(vis_img) | |
| df = pd.DataFrame({ | |
| 'category': results["pred_class"][:topk], | |
| 'probability': [f"{x:.2f}" for x in results["pred_score"]][:topk] | |
| }, index=None) | |
| col3.dataframe(df) | |
| elif model_option == 'det': | |
| # vis_idx = st.slider('Show a prediction', 1, 10, 3, disabled=True) | |
| col2.image(vis_img) | |
| df = pd.DataFrame({ | |
| 'category': results["labels"][:topk], | |
| 'probability': [f"{x:.2f}" for x in results["scores"]][:topk], | |
| 'box': [list(map(lambda t: f"{t:.2f}", list(x))) for x in results["bboxes"][:topk]] | |
| }, index=None) | |
| col3.dataframe(df) | |
| model, visualizer = _init_model(model_option) | |
| img = _get_image(my_upload) | |
| results, vis_img = _inference_model(img, model, visualizer, model_option) | |
| plot_canvas(img, vis_img, results, model_option) | |