File size: 7,457 Bytes
61aae43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
import gradio as gr
import numpy as np
import cv2
from PIL import Image
from ultralytics import SAM

class ImageSegmentationApp:
    """Gradio app that replaces a user-marked region of an image.

    The user paints a rough mask over the original image. The painted regions
    are reduced to point prompts, SAM2 refines them into a mask, and the mask
    is used to composite a replacement image onto the original. When the SAM2
    weights cannot be loaded, the painted mask is used directly.
    """

    def __init__(self) -> None:
        """Initialize the segmentation app and load the SAM2 model with fallback."""
        try:
            # Attempt to load the SAM2 model weights
            self.model = SAM("sam2.1_t.pt")
            self.model_available = True  # Model loaded successfully
        except Exception as e:
            # Deliberate best-effort: the app stays usable with the drawn mask only
            print(f"Failed to load SAM2 model: {e}")
            self.model = None
            self.model_available = False

    @staticmethod
    def _prompt_points(mask_np: np.ndarray) -> list[list[float]]:
        """Return one [x, y] prompt point per external contour of the mask."""
        contours, _ = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        points = []
        for contour in contours:
            moments = cv2.moments(contour)
            if moments["m00"] != 0:
                # Use the centroid of the contour as the prompt point.
                # NOTE: for strongly concave strokes the centroid can fall
                # outside the painted area.
                cx = float(moments["m10"] / moments["m00"])
                cy = float(moments["m01"] / moments["m00"])
                points.append([cx, cy])
            else:
                # Zero-area contour (e.g. a single dot): use its first vertex
                x, y = contour[0][0]
                points.append([float(x), float(y)])
        return points

    @staticmethod
    def _merge_masks(mask_arr: np.ndarray) -> np.ndarray:
        """Collapse a (..., H, W) stack of masks into one 2-D uint8 mask (0/255).

        A pixel is set when it is non-zero in any mask of the stack, so the
        result is always 2-D even when several masks are returned.
        """
        mask_arr = np.asarray(mask_arr)
        if mask_arr.ndim > 2:
            merged = mask_arr.reshape(-1, *mask_arr.shape[-2:]).any(axis=0)
        else:
            merged = mask_arr.astype(bool)
        return np.where(merged, 255, 0).astype(np.uint8)

    def _sam_mask(
        self,
        original_image: Image.Image,
        points: list[list[float]]
    ) -> Image.Image | None:
        """Run SAM2 with the point prompts; return the mask image, or None if SAM2 produced none."""
        results = self.model(
            source=original_image,
            points=[points],
        )
        masks = results[0].masks if results else None
        if masks is None:
            return None
        # .cpu() keeps the conversion working when inference ran on a GPU
        mask_arr = masks.data.cpu().numpy()
        return Image.fromarray(self._merge_masks(mask_arr))

    def process_segmentation(
        self,
        image_editor: dict,
        replacement_image: Image.Image
    ) -> list[Image.Image | str | None]:
        """
        Process the segmentation and replacement using the drawn mask and SAM2 model.

        Args:
            image_editor: Editor value with a "background" image and a "layers"
                list whose first entry carries the painted mask in its last
                (alpha) channel.
            replacement_image: Image to paste into the masked region.

        Returns:
            [drawn_mask, sam_mask, result_image, markdown_message]. On error the
            image slots are None (except the original image when no mask was
            drawn) and the message describes the problem. Exceptions raised
            during segmentation are caught and reported in the message.
        """
        # Check if both images are provided (the editor value itself may be missing)
        original_image = image_editor.get("background") if image_editor else None
        if original_image is None or replacement_image is None:
            return [None, None, None, "**❌ Error:** Please upload both images."]

        no_mask_response = [
            None,
            None,
            original_image,
            "**❌ Error:** No mask drawn. Please draw a mask on the original image.",
        ]
        try:
            # An editor value without any painted layer means nothing was drawn
            layers = image_editor.get("layers") or []
            if not layers or layers[0] is None:
                return no_mask_response

            # Use the alpha channel of the mask as the binary mask
            drawn_mask = layers[0].split()[-1]

            # Find contours in the mask to determine segmentation points
            points = self._prompt_points(np.array(drawn_mask))
            if not points:
                return no_mask_response

            sam_mask = None
            if self.model_available and self.model is not None:
                # Run the SAM2 model to refine the mask
                sam_mask = self._sam_mask(original_image, points)
                if sam_mask is None:
                    model_message = "**⚠️ Warning:** SAM2 returned no mask, using drawn mask as mask."
                else:
                    model_message = "**✅ Success:** Segmentation completed with SAM2."
            else:
                model_message = "**⚠️ Warning:** SAM2 model unavailable, using drawn mask as mask."

            # Fall back to the drawn mask whenever SAM2 gave nothing usable
            if sam_mask is None:
                sam_mask = drawn_mask

            # Image.composite requires the mask to match the image size
            if sam_mask.size != original_image.size:
                sam_mask = sam_mask.resize(original_image.size, Image.NEAREST)

            # Resize the replacement image to match the original image size
            replacement_image = replacement_image.resize(original_image.size)
            # Composite the replacement image onto the original using the mask
            result_image = Image.composite(replacement_image, original_image, sam_mask)

            return [drawn_mask, sam_mask, result_image, model_message]
        except Exception as e:
            # UI boundary: report any failure to the user instead of crashing the handler
            print(f"Segmentation error: {e}")
            return [None, None, None, f"**❌ Error:** Segmentation error: {e}"]

    def create_interface(self) -> gr.Blocks:
        """Create and return the Gradio interface.

        The layout has an instructions header, the mask editor next to the
        replacement-image upload, a process button, a status line, and three
        output images (drawn mask, SAM2 mask, result). The button is wired to
        `process_segmentation`.
        """
        with gr.Blocks(title="SAM2 Image Segmentation & Replacement", theme=gr.themes.Soft(), css=".center-status-message {text-align: center;}") as demo:
            # App title and instructions
            gr.Markdown(
                """
                # 🎨 SAM2 Image Segmentation & Replacement
                
                Upload an original image and a replacement image, then draw a rough mask on the original image.
                
                **Instructions:**
                1. Upload your original image
                2. Upload your replacement image  
                3. Draw a mask on the original image by painting over the area you want to replace
                4. Click "Process Segmentation" to see the result
                """
            )
            gr.Markdown("### 📸 Upload Images")
            with gr.Row():
                with gr.Column():
                    # ImageMask for original image and mask drawing
                    image_editor = gr.ImageMask(
                        label="Original Image",
                        type="pil",
                        height=400
                    )
                with gr.Column():
                    # Upload for replacement image
                    replacement_image = gr.Image(
                        label="Replacement Image",
                        type="pil",
                        height=400
                    )
            with gr.Row():
                # Button to trigger segmentation
                process_btn = gr.Button("🚀 Process Segmentation", variant="primary", size="lg")
            with gr.Row():
                # Status message for feedback
                status_message = gr.Markdown(value="", elem_id="status_message", elem_classes=["center-status-message"])
            with gr.Row():
                # Display the drawn mask, SAM2 mask, and result image
                drawn_mask = gr.Image(
                    label="Drawn Mask",
                    type="pil",
                    height=400
                )
                result_mask = gr.Image(
                    label="SAM2 Mask",
                    type="pil",
                    height=400
                )
                result_image = gr.Image(
                    label="Result",
                    type="pil",
                    height=400
                )
            with gr.Row():
                # Display copyright information
                gr.Markdown(
                    value="© 2025 Kenny Santanu. All rights reserved.",
                    elem_classes=["center-status-message"]
                )

            # Connect button click to segmentation function
            process_btn.click(
                fn=self.process_segmentation,
                inputs=[image_editor, replacement_image],
                outputs=[drawn_mask, result_mask, result_image, status_message]
            )
        return demo

def main() -> None:
    """Build the segmentation app's Gradio interface and launch it.

    Instantiates `ImageSegmentationApp`, creates its interface, and launches
    it with `show_api=False`.
    """
    ImageSegmentationApp().create_interface().launch(show_api=False)

# Entry-point guard: run main() only when this file is executed as a script,
# so that importing the module does not call it.
if __name__ == "__main__":
    main()