Spaces:
Runtime error
Runtime error
| import pre_reqs | |
| import glob | |
| import os.path | |
| import cv2 | |
| import numpy as np | |
| import streamlit as st | |
| from mmcls.apis import init_model | |
| from mmcls.apis import inference_model_topk as inference_cls_model | |
| from mmdet.registry import VISUALIZERS | |
| # from mmcls.utils import register_all_modules as register_all_modules_cls | |
| from mmdet.apis import init_detector, inference_detector | |
| from mmdet.utils import register_all_modules as register_all_modules_det | |
| import pandas as pd | |
| from PIL import Image | |
| register_all_modules_det() | |
| st.set_page_config(page_title="π· A Folder Demo", page_icon="π·", layout='wide') | |
| st.markdown("# π· A Folder Demo") | |
| st.write( | |
| ":dog: Try uploading multi images to get the possible categories, objects." | |
| ) | |
| st.sidebar.header("A Folder Demo") | |
| my_upload = st.sidebar.file_uploader("Upload multi images", type=["png", "jpg", "jpeg"], accept_multiple_files=True) | |
| col1, col2 = st.columns(2) | |
| parent_folder = './' | |
| # @st.cache_resource | |
| def _init_model_return_results(imgs): | |
| cls_model = init_model(parent_folder + 'configs/resnet/resnet50_8xb32_in1k.py', | |
| 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth') | |
| imgs = [np.array(x) for x in imgs] | |
| results = {} | |
| for idx, img in enumerate(imgs): | |
| return_results = inference_cls_model(cls_model, img, 5) | |
| results[idx] = set(np.array(return_results["pred_class"])[return_results["pred_score"] > 0.35]) | |
| det_model = init_detector(parent_folder + 'configs/rtmdet/rtmdet-ins_s_8xb32-300e_coco.py', | |
| 'https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet-ins_s_8xb32-300e_coco/rtmdet-ins_s_8xb32-300e_coco_20221121_212604-fdc5d7ec.pth', | |
| device='cpu') | |
| dataset_meta = det_model.dataset_meta | |
| return_results = inference_detector(det_model, img) | |
| cls_names = dataset_meta['classes'] | |
| scores = return_results.pred_instances.scores.numpy()[:10] | |
| labels = np.array([cls_names[x.item()] for x in return_results.pred_instances.labels[:10]]) | |
| results[idx] |= set(labels[scores > 0.35]) | |
| class2idx = {} | |
| for k, v in results.items(): | |
| for sub_v in v: | |
| class2idx[sub_v] = class2idx.get(sub_v, []) + [k] | |
| return results, class2idx | |
| # @st.cache_data | |
| def _get_image(my_upload): | |
| if len(my_upload): | |
| img_files = my_upload | |
| if isinstance(img_files, str): | |
| img_files = [img_files] | |
| file_names = [os.path.basename(x.name).split('.')[0][:8] for x in img_files] | |
| else: | |
| img_files = glob.glob(parent_folder + "/images/*.jpg") + glob.glob(parent_folder + "/images/*.png") | |
| file_names = [os.path.basename(x).split('.')[0][:8] for x in img_files] | |
| return [Image.open(img_file).convert('RGB') for img_file in img_files], file_names | |
| def plot_canvas(imgs, results, file_names, class2idx): | |
| col1.write("Original Images :camera:") | |
| col2.write("Filtered Images :wrench:") | |
| tabs = col1.tabs(file_names) | |
| for idx, tab in enumerate(tabs): | |
| tab.image(imgs[idx], width=400) | |
| all_classes = set() | |
| for x in results.values(): | |
| all_classes |= x | |
| all_classes = list(all_classes) | |
| options = st.multiselect( | |
| 'Select the classes:', | |
| all_classes) | |
| select_idx = set(range(len(file_names))) | |
| for idx, op in enumerate(options): | |
| select_idx &= set(class2idx[op]) | |
| select_idx = np.array(list(select_idx)) | |
| if len(select_idx): | |
| names = np.array(file_names)[select_idx].tolist() | |
| tabs = col2.tabs(names) | |
| for idx, tab in enumerate(tabs): | |
| tabs[idx].image(imgs[select_idx[idx]], width=400) | |
| tabs[idx].write(', '.join(results[select_idx[idx]])) | |
| imgs, file_names = _get_image(my_upload) | |
| results, class2idx = _init_model_return_results(imgs) | |
| plot_canvas(imgs, results, file_names, class2idx) | |