Spaces:
Runtime error
Runtime error
import cv2
import numpy as np
from PIL import Image  # Pillow's import name is case-sensitive: `PIL`, not `pil`
import streamlit as st
import tensorflow as tf
from tensorflow.keras.models import load_model  # NOTE: shadowed by the local load_model() defined below
# Most of this code has been adapted from Datature's prediction script:
# https://github.com/datature/resources/blob/main/scripts/bounding_box/prediction.py
from streamlit.errors import StreamlitAPIException

# `deprecation.showfileUploaderEncoding` only exists in old Streamlit releases.
# Newer releases have removed it, and `st.set_option` then raises
# StreamlitAPIException, which would stop the app before anything renders.
# Setting it is therefore best-effort.
try:
    st.set_option('deprecation.showfileUploaderEncoding', False)
except StreamlitAPIException:
    pass
def load_model(model_dir='./saved_model'):
    """Load the exported TensorFlow SavedModel used for detection.

    Note: this definition shadows ``tensorflow.keras.models.load_model``
    imported at the top of the file; the script below calls this one.

    Args:
        model_dir: directory containing the SavedModel. Defaults to
            ``./saved_model``, the path this app has always used.

    Returns:
        The object returned by ``tf.saved_model.load(model_dir)``.
    """
    return tf.saved_model.load(model_dir)
def load_label_map(label_map_path):
    """
    Reads label map in the format of .pbtxt and parse into dictionary

    Each ``item { ... }`` block is expected to contain an ``id: <int>`` line
    and a ``name: "<label>"`` line, in either order. Other keys (for example
    ``display_name``) are ignored. An item missing either field is skipped.

    Args:
        label_map_path: the file path to the label_map
    Returns:
        dictionary with the format of {label_index: {'id': label_index, 'name': label_name}}
    """
    label_map = {}
    entry = {}
    with open(label_map_path, "r", encoding="utf-8") as label_file:
        for raw_line in label_file:
            line = raw_line.strip()
            if line.startswith("}"):
                # End of an item block: discard any incomplete entry so its
                # fields cannot leak into the next item.
                entry = {}
                continue
            # Split on the FIRST colon only, so names containing ':' survive.
            key, sep, value = line.partition(":")
            if not sep:
                continue  # e.g. "item {" or blank lines
            key = key.strip()
            # Compare whole keys: a substring test ("id" in line) would also
            # match names such as "humid" and crash in int().
            if key == "id":
                entry["id"] = int(value)
            elif key == "name":
                entry["name"] = value.strip().strip('"')
            else:
                continue
            if "id" in entry and "name" in entry:
                label_map[entry["id"]] = {"id": entry["id"], "name": entry["name"]}
                entry = {}
    return label_map
def predict_class(image, model):
    """Classify a single image with a model exposing a Keras-style ``predict``.

    The image is converted to float32, resized to 150x150 and wrapped in a
    batch of one before being handed to ``model.predict``.

    NOTE(review): nothing in the visible code calls this helper, and the
    detector returned by ``load_model()`` comes from ``tf.saved_model.load``,
    which may not provide ``.predict``. It looks like a leftover from a
    classifier template; confirm before relying on it.
    """
    as_float = tf.cast(image, tf.float32)
    resized = tf.image.resize(as_float, [150, 150])
    batch = np.asarray(resized)[np.newaxis, ...]
    return model.predict(batch)
def plot_boxes_on_img(color_map, classes, bboxes, image_origi, origi_shape,
                      label_map=None, box_scores=None):
    """Draw labelled bounding boxes onto ``image_origi`` in place and return it.

    Args:
        color_map: mapping from class id to a drawing color. Classes missing
            from it are drawn in white instead of raising KeyError.
        classes: class id for each box (indexed in parallel with ``bboxes``).
        bboxes: boxes as normalized ``[ymin, xmin, ymax, xmax]``; y values are
            scaled by ``origi_shape[0]`` and x values by ``origi_shape[1]``.
        image_origi: image array that the cv2 drawing calls modify in place.
        origi_shape: shape of ``image_origi`` as ``(height, width, ...)``.
        label_map: ``{class_id: {"name": ...}}``. Defaults to the
            module-level ``category_index`` for backward compatibility.
        box_scores: score for each box. Defaults to the module-level
            ``scores`` for backward compatibility.

    Returns:
        ``image_origi`` with boxes, label backgrounds and label text drawn.
    """
    # Older callers pass only the first five arguments and rely on these
    # script-level globals, so keep them as the fallback.
    if label_map is None:
        label_map = category_index
    if box_scores is None:
        box_scores = scores

    height, width = origi_shape[0], origi_shape[1]
    fallback_color = (255, 255, 255)

    for idx, each_bbox in enumerate(bboxes):
        cls = classes[idx]
        color = color_map.get(cls, fallback_color)

        x_left = int(each_bbox[1] * width)
        y_top = int(each_bbox[0] * height)
        x_right = int(each_bbox[3] * width)
        y_bottom = each_bbox[2] * height  # kept as float; offsets added before int()

        ## Draw bounding box
        cv2.rectangle(image_origi, (x_left, y_top), (x_right, int(y_bottom)), color, 2)
        ## Draw label background: a filled 15 px strip just below the box
        cv2.rectangle(image_origi, (x_left, int(y_bottom)), (x_right, int(y_bottom + 15)), color, -1)
        ## Insert label class & score
        label_text = "Class: {}, Score: {}".format(
            str(label_map[cls]["name"]),
            str(round(box_scores[idx], 2)),
        )
        cv2.putText(
            image_origi,
            label_text,
            (x_left, int(y_bottom + 10)),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.3,
            (0, 0, 0),
            1,
            cv2.LINE_AA,
        )
    return image_origi
# Webpage code starts here
st.title('Banana ripeness detection 🍌')
st.text('made by XXX')  # TODO change with your name
st.markdown('## Find out if a banana is too ripe!')
# Streamlit re-runs this whole script on every interaction (e.g. each upload).
# Reuse the loaded model across reruns when `st.cache_resource` is available
# (Streamlit >= 1.18); older versions fall back to loading it each time.
_cache_resource = getattr(st, "cache_resource", None)
with st.spinner('Model is being loaded...'):
    if _cache_resource is not None:
        model = _cache_resource(load_model)()
    else:
        model = load_model()
# ask user to upload an image
file = st.file_uploader("Upload an image of a banana", type=["jpg", "png"])
if file is None:
    st.text('Waiting for upload...')
else:
    st.text('Running inference...')
    # open image
    test_image = Image.open(file).convert("RGB")
    origi_shape = np.asarray(test_image).shape
    # resize image to the detector's input shape
    default_shape = 320
    image_resized = np.array(test_image.resize((default_shape, default_shape)))
    ## Load label map
    category_index = load_label_map("./label_map.pbtxt")
    # color of each label. check label_map.pbtxt to check the index for each class
    # TODO Add more colors if there are more classes
    color_map = {
        2: [255, 0, 0],  # overripe -> red
        1: [0, 255, 0],  # ripe -> green
    }
    ## The model expects a batch of images: convert to a tensor and add a leading axis.
    input_tensor = tf.convert_to_tensor(image_resized)[tf.newaxis, ...]
    ## Feed image into model and obtain output
    detections_output = model(input_tensor)
    # `num_detections` may carry a batch dimension; `.item()` extracts the single
    # value whether it is a scalar or a 1-element array.
    num_detections = int(np.asarray(detections_output.pop("num_detections")).item())
    detections = {key: value[0, :num_detections].numpy() for key, value in detections_output.items()}
    detections["num_detections"] = num_detections
    ## Filter out predictions below threshold
    # if threshold is higher, there will be fewer predictions
    # TODO change this number to see how the predictions change
    confidence_threshold = 0.5
    keep = detections["detection_scores"] > confidence_threshold
    ## Extract predicted bounding boxes
    bboxes = detections["detection_boxes"][keep]
    if len(bboxes) == 0:
        # there are no predicted boxes
        st.error('No boxes predicted')
    else:
        st.success('Boxes predicted')
        # NOTE: plot_boxes_on_img reads `category_index` and `scores` as
        # module-level globals, so these names must stay as they are.
        classes = detections["detection_classes"][keep].astype(np.int64)
        scores = detections["detection_scores"][keep]
        # Boxes are normalized, so draw directly on a writable full-resolution
        # copy of the upload instead of an upsampled 320x320 image.
        image_origi = np.array(test_image)
        image_origi = plot_boxes_on_img(color_map, classes, bboxes, image_origi, origi_shape)
        # show image in web page
        st.image(Image.fromarray(image_origi), caption="Image with predictions", width=400)
        st.markdown("### Predicted boxes")
        for cls, score in zip(classes, scores):
            # str() keeps numpy float formatting identical to the previous output.
            st.markdown(f"* Class: {str(category_index[cls]['name'])}, confidence score: {str(round(score, 2))}")