Kenny Santanu commited on
Commit
61aae43
Β·
1 Parent(s): 47dfb37

Add initial implementation of image segmentation app with SAM2 model and necessary files

Browse files
Files changed (4) hide show
  1. .gitignore +207 -0
  2. LICENSE +21 -0
  3. app.py +172 -0
  4. requirements.txt +3 -0
.gitignore ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[codz]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py.cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # UV
98
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ #uv.lock
102
+
103
+ # poetry
104
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
106
+ # commonly ignored for libraries.
107
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108
+ #poetry.lock
109
+ #poetry.toml
110
+
111
+ # pdm
112
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
113
+ # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
114
+ # https://pdm-project.org/en/latest/usage/project/#working-with-version-control
115
+ #pdm.lock
116
+ #pdm.toml
117
+ .pdm-python
118
+ .pdm-build/
119
+
120
+ # pixi
121
+ # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
122
+ #pixi.lock
123
+ # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
124
+ # in the .venv directory. It is recommended not to include this directory in version control.
125
+ .pixi
126
+
127
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
128
+ __pypackages__/
129
+
130
+ # Celery stuff
131
+ celerybeat-schedule
132
+ celerybeat.pid
133
+
134
+ # SageMath parsed files
135
+ *.sage.py
136
+
137
+ # Environments
138
+ .env
139
+ .envrc
140
+ .venv
141
+ env/
142
+ venv/
143
+ ENV/
144
+ env.bak/
145
+ venv.bak/
146
+
147
+ # Spyder project settings
148
+ .spyderproject
149
+ .spyproject
150
+
151
+ # Rope project settings
152
+ .ropeproject
153
+
154
+ # mkdocs documentation
155
+ /site
156
+
157
+ # mypy
158
+ .mypy_cache/
159
+ .dmypy.json
160
+ dmypy.json
161
+
162
+ # Pyre type checker
163
+ .pyre/
164
+
165
+ # pytype static type analyzer
166
+ .pytype/
167
+
168
+ # Cython debug symbols
169
+ cython_debug/
170
+
171
+ # PyCharm
172
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
173
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
174
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
175
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
176
+ #.idea/
177
+
178
+ # Abstra
179
+ # Abstra is an AI-powered process automation framework.
180
+ # Ignore directories containing user credentials, local state, and settings.
181
+ # Learn more at https://abstra.io/docs
182
+ .abstra/
183
+
184
+ # Visual Studio Code
185
+ # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
186
+ # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
187
+ # and can be added to the global gitignore or merged into this file. However, if you prefer,
188
+ # you could uncomment the following to ignore the entire vscode folder
189
+ # .vscode/
190
+
191
+ # Ruff stuff:
192
+ .ruff_cache/
193
+
194
+ # PyPI configuration file
195
+ .pypirc
196
+
197
+ # Cursor
198
+ # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
199
+ # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
200
+ # refer to https://docs.cursor.com/context/ignore-files
201
+ .cursorignore
202
+ .cursorindexingignore
203
+
204
+ # Marimo
205
+ marimo/_static/
206
+ marimo/_lsp/
207
+ __marimo__/
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Kenny Santanu
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
app.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import cv2
4
+ from PIL import Image
5
+ from ultralytics import SAM
6
+
7
+ class ImageSegmentationApp:
8
+ def __init__(self) -> None:
9
+ """Initialize the segmentation app and load the SAM2 model with fallback."""
10
+ try:
11
+ # Attempt to load the SAM2 model weights
12
+ self.model = SAM("sam2.1_t.pt")
13
+ self.model_available = True # Model loaded successfully
14
+ except Exception as e:
15
+ # If loading fails, set model as unavailable and print error
16
+ print(f"Failed to load SAM2 model: {e}")
17
+ self.model = None
18
+ self.model_available = False
19
+
20
+ def process_segmentation(
21
+ self,
22
+ image_editor: dict,
23
+ replacement_image: Image.Image
24
+ ) -> list[Image.Image | None] | None:
25
+ """
26
+ Process the segmentation and replacement using the drawn mask and SAM2 model.
27
+ Returns [drawn_mask, sam_mask, result_image, markdown_message].
28
+ """
29
+ # Check if both images are provided
30
+ if image_editor["background"] is None or replacement_image is None:
31
+ return [None, None, None, "**❌ Error:** Please upload both images."]
32
+ try:
33
+ # Extract the original image and the user-drawn mask
34
+ original_image = image_editor["background"]
35
+ drawn_mask = image_editor["layers"][0]
36
+
37
+ # Use the alpha channel of the mask as the binary mask
38
+ drawn_mask = drawn_mask.split()[-1]
39
+ drawn_mask_np = np.array(drawn_mask)
40
+
41
+ # Find contours in the mask to determine segmentation points
42
+ points = []
43
+ contours, _ = cv2.findContours(drawn_mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
44
+ for contour in contours:
45
+ M = cv2.moments(contour)
46
+ if M["m00"] != 0:
47
+ # Use centroid of contour as a point
48
+ cx = float(M["m10"] / M["m00"])
49
+ cy = float(M["m01"] / M["m00"])
50
+ points.append([cx, cy])
51
+ else:
52
+ # Fallback: use the first point if the area is zero
53
+ x, y = contour[0][0]
54
+ points.append([float(x), float(y)])
55
+
56
+ # If no points are found, return original image and a message indicating no mask was drawn
57
+ if not points:
58
+ return [None, None, original_image, "**❌ Error:** No mask drawn. Please draw a mask on the original image."]
59
+
60
+ # If the SAM2 model is unavailable, use the drawn mask directly
61
+ if not self.model_available or not self.model:
62
+ sam_mask = drawn_mask
63
+ model_message = "**⚠️ Warning:** SAM2 model unavailable, using drawn mask as mask."
64
+ else:
65
+ # Run the SAM2 model to refine the mask
66
+ results = self.model(
67
+ source=original_image,
68
+ points=[points],
69
+ )
70
+ # Extract the mask from the model output
71
+ result_numpy_arr = results[0].masks.data.numpy()
72
+ sam_mask_arr = np.squeeze(result_numpy_arr)
73
+ sam_mask_arr = (sam_mask_arr * 255).astype(np.uint8) # Convert bool to uint8
74
+ sam_mask = Image.fromarray(sam_mask_arr)
75
+ model_message = "**βœ… Success:** Segmentation completed with SAM2."
76
+
77
+ # Resize the replacement image to match the original image size
78
+ replacement_image = replacement_image.resize(original_image.size)
79
+ # Composite the replacement image onto the original using the mask
80
+ result_image = Image.composite(replacement_image, original_image, sam_mask)
81
+
82
+ return [drawn_mask, sam_mask, result_image, model_message]
83
+ except Exception as e:
84
+ # Catch and report any errors during segmentation
85
+ print(f"Segmentation error: {e}")
86
+ return [None, None, None, f"**❌ Error:** Segmentation error: {e}"]
87
+
88
+ def create_interface(self) -> gr.Blocks:
89
+ """Create and return the Gradio interface"""
90
+ with gr.Blocks(title="SAM2 Image Segmentation & Replacement", theme=gr.themes.Soft(), css=".center-status-message {text-align: center;}") as demo:
91
+ # App title and instructions
92
+ gr.Markdown(
93
+ f"""
94
+ # 🎨 SAM2 Image Segmentation & Replacement
95
+
96
+ Upload an original image and a replacement image, then draw a rough mask on the original image.
97
+
98
+ **Instructions:**
99
+ 1. Upload your original image
100
+ 2. Upload your replacement image
101
+ 3. Draw a mask on the original image by painting over the area you want to replace
102
+ 4. Click "Process Segmentation" to see the result
103
+ """
104
+ )
105
+ gr.Markdown("### πŸ“Έ Upload Images")
106
+ with gr.Row():
107
+ with gr.Column():
108
+ # ImageMask for original image and mask drawing
109
+ image_editor = gr.ImageMask(
110
+ label="Original Image",
111
+ type="pil",
112
+ height=400
113
+ )
114
+ with gr.Column():
115
+ # Upload for replacement image
116
+ replacement_image = gr.Image(
117
+ label="Replacement Image",
118
+ type="pil",
119
+ height=400
120
+ )
121
+ with gr.Row():
122
+ # Button to trigger segmentation
123
+ process_btn = gr.Button("πŸš€ Process Segmentation", variant="primary", size="lg")
124
+ with gr.Row():
125
+ # Status message for feedback
126
+ status_message = gr.Markdown(value="", elem_id="status_message", elem_classes=["center-status-message"])
127
+ with gr.Row():
128
+ # Display the drawn mask, SAM2 mask, and result image
129
+ drawn_mask = gr.Image(
130
+ label="Drawn Mask",
131
+ type="pil",
132
+ height=400
133
+ )
134
+ result_mask = gr.Image(
135
+ label="SAM2 Mask",
136
+ type="pil",
137
+ height=400
138
+ )
139
+ result_image = gr.Image(
140
+ label="Result",
141
+ type="pil",
142
+ height=400
143
+ )
144
+ with gr.Row():
145
+ # Display copywrite information
146
+ gr.Markdown(
147
+ value="Β© 2025 Kenny Santanu. All rights reserved.",
148
+ elem_classes=["center-status-message"]
149
+ )
150
+
151
+ # Connect button click to segmentation function
152
+ process_btn.click(
153
+ fn=self.process_segmentation,
154
+ inputs=[image_editor, replacement_image],
155
+ outputs=[drawn_mask, result_mask, result_image, status_message]
156
+ )
157
+ return demo
158
+
159
+ def main() -> None:
160
+ """Main function to run the application"""
161
+ # Instantiate the app
162
+ app = ImageSegmentationApp()
163
+ # Create the Gradio interface
164
+ demo = app.create_interface()
165
+ # Launch the interface (web server)
166
+ demo.launch(
167
+ show_api=False
168
+ )
169
+
170
+ # Run the app if this script is executed directly
171
+ if __name__ == "__main__":
172
+ main()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ gradio>=5.37.0
2
+ opencv-python>=4.12.0.88
3
+ ultralytics>=8.3.167