Kenny Santanu
Add initial implementation of image segmentation app with SAM2 model and necessary files
61aae43
| import gradio as gr | |
| import numpy as np | |
| import cv2 | |
| from PIL import Image | |
| from ultralytics import SAM | |
| class ImageSegmentationApp: | |
| def __init__(self) -> None: | |
| """Initialize the segmentation app and load the SAM2 model with fallback.""" | |
| try: | |
| # Attempt to load the SAM2 model weights | |
| self.model = SAM("sam2.1_t.pt") | |
| self.model_available = True # Model loaded successfully | |
| except Exception as e: | |
| # If loading fails, set model as unavailable and print error | |
| print(f"Failed to load SAM2 model: {e}") | |
| self.model = None | |
| self.model_available = False | |
| def process_segmentation( | |
| self, | |
| image_editor: dict, | |
| replacement_image: Image.Image | |
| ) -> list[Image.Image | None] | None: | |
| """ | |
| Process the segmentation and replacement using the drawn mask and SAM2 model. | |
| Returns [drawn_mask, sam_mask, result_image, markdown_message]. | |
| """ | |
| # Check if both images are provided | |
| if image_editor["background"] is None or replacement_image is None: | |
| return [None, None, None, "**β Error:** Please upload both images."] | |
| try: | |
| # Extract the original image and the user-drawn mask | |
| original_image = image_editor["background"] | |
| drawn_mask = image_editor["layers"][0] | |
| # Use the alpha channel of the mask as the binary mask | |
| drawn_mask = drawn_mask.split()[-1] | |
| drawn_mask_np = np.array(drawn_mask) | |
| # Find contours in the mask to determine segmentation points | |
| points = [] | |
| contours, _ = cv2.findContours(drawn_mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) | |
| for contour in contours: | |
| M = cv2.moments(contour) | |
| if M["m00"] != 0: | |
| # Use centroid of contour as a point | |
| cx = float(M["m10"] / M["m00"]) | |
| cy = float(M["m01"] / M["m00"]) | |
| points.append([cx, cy]) | |
| else: | |
| # Fallback: use the first point if the area is zero | |
| x, y = contour[0][0] | |
| points.append([float(x), float(y)]) | |
| # If no points are found, return original image and a message indicating no mask was drawn | |
| if not points: | |
| return [None, None, original_image, "**β Error:** No mask drawn. Please draw a mask on the original image."] | |
| # If the SAM2 model is unavailable, use the drawn mask directly | |
| if not self.model_available or not self.model: | |
| sam_mask = drawn_mask | |
| model_message = "**β οΈ Warning:** SAM2 model unavailable, using drawn mask as mask." | |
| else: | |
| # Run the SAM2 model to refine the mask | |
| results = self.model( | |
| source=original_image, | |
| points=[points], | |
| ) | |
| # Extract the mask from the model output | |
| result_numpy_arr = results[0].masks.data.numpy() | |
| sam_mask_arr = np.squeeze(result_numpy_arr) | |
| sam_mask_arr = (sam_mask_arr * 255).astype(np.uint8) # Convert bool to uint8 | |
| sam_mask = Image.fromarray(sam_mask_arr) | |
| model_message = "**β Success:** Segmentation completed with SAM2." | |
| # Resize the replacement image to match the original image size | |
| replacement_image = replacement_image.resize(original_image.size) | |
| # Composite the replacement image onto the original using the mask | |
| result_image = Image.composite(replacement_image, original_image, sam_mask) | |
| return [drawn_mask, sam_mask, result_image, model_message] | |
| except Exception as e: | |
| # Catch and report any errors during segmentation | |
| print(f"Segmentation error: {e}") | |
| return [None, None, None, f"**β Error:** Segmentation error: {e}"] | |
| def create_interface(self) -> gr.Blocks: | |
| """Create and return the Gradio interface""" | |
| with gr.Blocks(title="SAM2 Image Segmentation & Replacement", theme=gr.themes.Soft(), css=".center-status-message {text-align: center;}") as demo: | |
| # App title and instructions | |
| gr.Markdown( | |
| f""" | |
| # π¨ SAM2 Image Segmentation & Replacement | |
| Upload an original image and a replacement image, then draw a rough mask on the original image. | |
| **Instructions:** | |
| 1. Upload your original image | |
| 2. Upload your replacement image | |
| 3. Draw a mask on the original image by painting over the area you want to replace | |
| 4. Click "Process Segmentation" to see the result | |
| """ | |
| ) | |
| gr.Markdown("### πΈ Upload Images") | |
| with gr.Row(): | |
| with gr.Column(): | |
| # ImageMask for original image and mask drawing | |
| image_editor = gr.ImageMask( | |
| label="Original Image", | |
| type="pil", | |
| height=400 | |
| ) | |
| with gr.Column(): | |
| # Upload for replacement image | |
| replacement_image = gr.Image( | |
| label="Replacement Image", | |
| type="pil", | |
| height=400 | |
| ) | |
| with gr.Row(): | |
| # Button to trigger segmentation | |
| process_btn = gr.Button("π Process Segmentation", variant="primary", size="lg") | |
| with gr.Row(): | |
| # Status message for feedback | |
| status_message = gr.Markdown(value="", elem_id="status_message", elem_classes=["center-status-message"]) | |
| with gr.Row(): | |
| # Display the drawn mask, SAM2 mask, and result image | |
| drawn_mask = gr.Image( | |
| label="Drawn Mask", | |
| type="pil", | |
| height=400 | |
| ) | |
| result_mask = gr.Image( | |
| label="SAM2 Mask", | |
| type="pil", | |
| height=400 | |
| ) | |
| result_image = gr.Image( | |
| label="Result", | |
| type="pil", | |
| height=400 | |
| ) | |
| with gr.Row(): | |
| # Display copywrite information | |
| gr.Markdown( | |
| value="Β© 2025 Kenny Santanu. All rights reserved.", | |
| elem_classes=["center-status-message"] | |
| ) | |
| # Connect button click to segmentation function | |
| process_btn.click( | |
| fn=self.process_segmentation, | |
| inputs=[image_editor, replacement_image], | |
| outputs=[drawn_mask, result_mask, result_image, status_message] | |
| ) | |
| return demo | |
| def main() -> None: | |
| """Main function to run the application""" | |
| # Instantiate the app | |
| app = ImageSegmentationApp() | |
| # Create the Gradio interface | |
| demo = app.create_interface() | |
| # Launch the interface (web server) | |
| demo.launch( | |
| show_api=False | |
| ) | |
| # Run the app if this script is executed directly | |
| if __name__ == "__main__": | |
| main() | |