prithivMLmods committed · verified
Commit 28be05f · 1 Parent(s): 452e070

Update app.py

Files changed (1):
  1. app.py +103 -21
app.py CHANGED
@@ -7,16 +7,16 @@ import numpy as np
 import torch
 import matplotlib
 import matplotlib.pyplot as plt
-from PIL import Image
+from PIL import Image, ImageDraw
 from typing import Iterable
 from gradio.themes import Soft
 from gradio.themes.utils import colors, fonts, sizes
 from transformers import (
     Sam3Model, Sam3Processor,
-    Sam3VideoModel, Sam3VideoProcessor
+    Sam3VideoModel, Sam3VideoProcessor,
+    Sam3TrackerModel, Sam3TrackerProcessor
 )
 
-# --- THEME CONFIGURATION ---
 colors.steel_blue = colors.Color(
     name="steel_blue",
     c50="#EBF3F8",
@@ -79,21 +79,25 @@ class CustomBlueTheme(Soft):
 
 app_theme = CustomBlueTheme()
 
-# --- GLOBAL MODEL LOADING ---
 device = "cuda" if torch.cuda.is_available() else "cpu"
 print(f"🖥️ Using compute device: {device}")
 
 print("⏳ Loading SAM3 Models permanently into memory...")
 
 try:
-    # 1. Load Image Segmentation Model
-    print(" ... Loading Image Model")
+    # 1. Load Image Segmentation Model (Text)
+    print(" ... Loading Image Text Model")
     IMG_MODEL = Sam3Model.from_pretrained("facebook/sam3").to(device)
     IMG_PROCESSOR = Sam3Processor.from_pretrained("facebook/sam3")
 
-    # 2. Load Video Segmentation Model
-    # Using bfloat16 for video to optimize VRAM usage while keeping speed
+    # 2. Load Image Tracker Model (Click)
+    print(" ... Loading Image Tracker Model")
+    TRK_MODEL = Sam3TrackerModel.from_pretrained("facebook/sam3").to(device)
+    TRK_PROCESSOR = Sam3TrackerProcessor.from_pretrained("facebook/sam3")
+
+    # 3. Load Video Segmentation Model
     print(" ... Loading Video Model")
+    # Using bfloat16 for video to optimize VRAM
     VID_MODEL = Sam3VideoModel.from_pretrained("facebook/sam3").to(device, dtype=torch.bfloat16)
     VID_PROCESSOR = Sam3VideoProcessor.from_pretrained("facebook/sam3")
 
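Note on the bfloat16 choice above: only the video model is cast down, which roughly halves its weight memory relative to float32. A minimal standalone sketch of the trade-off, assuming a CUDA device is available (the printed figures are illustrative, not measured):

```python
import torch
from transformers import Sam3VideoModel

# Sketch: load the video model with bfloat16 weights to reduce VRAM.
# Assumes CUDA; bfloat16 on CPU depends on backend support.
vid_model = Sam3VideoModel.from_pretrained("facebook/sam3").to("cuda", dtype=torch.bfloat16)

# bfloat16 stores 2 bytes per parameter vs. 4 for float32.
n_params = sum(p.numel() for p in vid_model.parameters())
print(f"~{n_params * 2 / 1e9:.1f} GB of weights in bf16 "
      f"(vs ~{n_params * 4 / 1e9:.1f} GB in fp32)")
```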
@@ -102,8 +106,10 @@ try:
 except Exception as e:
     print(f"❌ CRITICAL ERROR LOADING MODELS: {e}")
     IMG_MODEL = None
-    VID_MODEL = None
     IMG_PROCESSOR = None
+    TRK_MODEL = None
+    TRK_PROCESSOR = None
+    VID_MODEL = None
     VID_PROCESSOR = None
 
 
@@ -152,21 +158,31 @@ def apply_mask_overlay(base_image, mask_data, opacity=0.5):
 
     return Image.alpha_composite(base_image, composite_layer).convert("RGB")
 
-
-# --- GPU INFERENCE FUNCTIONS ---
+def draw_points_on_image(image, points):
+    """Draws red dots on the image to indicate click locations."""
+    if isinstance(image, np.ndarray):
+        image = Image.fromarray(image)
+
+    draw_img = image.copy()
+    draw = ImageDraw.Draw(draw_img)
+
+    for pt in points:
+        x, y = pt
+        r = 6 # Radius of point
+        draw.ellipse((x-r, y-r, x+r, y+r), fill="red", outline="white", width=2)
+
+    return draw_img
 
 @spaces.GPU
 def run_image_segmentation(source_img, text_query, conf_thresh=0.5):
     if IMG_MODEL is None or IMG_PROCESSOR is None:
-        raise gr.Error("Models failed to load on startup. Check logs.")
+        raise gr.Error("Models failed to load on startup.")
 
     if source_img is None or not text_query:
         raise gr.Error("Please provide an image and a text prompt.")
 
     try:
         pil_image = source_img.convert("RGB")
-
-        # Models are already on device, just move inputs
         model_inputs = IMG_PROCESSOR(images=pil_image, text=text_query, return_tensors="pt").to(device)
 
         with torch.no_grad():
@@ -179,7 +195,6 @@ def run_image_segmentation(source_img, text_query, conf_thresh=0.5):
             target_sizes=model_inputs.get("original_sizes").tolist()
         )[0]
 
-        # Use AnnotatedImage format
         annotation_list = []
         raw_masks = processed_results['masks'].cpu().numpy()
         raw_scores = processed_results['scores'].cpu().numpy()
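The loop that consumes `annotation_list`, `raw_masks`, and `raw_scores` falls outside this hunk. For context, a sketch of what such a filter typically looks like when feeding `gr.AnnotatedImage`, which accepts `(image, [(mask, label), ...])`; the helper name and label format here are assumptions, not the app's exact code:

```python
import numpy as np

def build_annotations(raw_masks, raw_scores, text_query, conf_thresh=0.5):
    # gr.AnnotatedImage expects (base_image, [(mask, label), ...]) where each
    # mask is an array matching the base image's height and width.
    annotation_list = []
    for idx, (mask, score) in enumerate(zip(raw_masks, raw_scores)):
        if score < conf_thresh:
            continue  # drop low-confidence detections, mirroring the UI slider
        annotation_list.append((mask, f"{text_query} ({score:.2f})"))
    return annotation_list
```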
@@ -193,6 +208,50 @@ def run_image_segmentation(source_img, text_query, conf_thresh=0.5):
     except Exception as e:
         raise gr.Error(f"Error during image processing: {e}")
 
+@spaces.GPU
+def run_image_click_gpu(input_image, x, y, points_state, labels_state):
+    if TRK_MODEL is None or TRK_PROCESSOR is None:
+        raise gr.Error("Tracker Model failed to load.")
+
+    if input_image is None: return input_image, [], []
+    if points_state is None: points_state = []; labels_state = []
+
+    # Append new point
+    points_state.append([x, y])
+    labels_state.append(1) # 1 indicates a positive click (foreground)
+
+    try:
+        # Prepare inputs format: [Batch, Point_Group, Point_Idx, Coord]
+        input_points = [[points_state]]
+        input_labels = [[labels_state]]
+
+        inputs = TRK_PROCESSOR(images=input_image, input_points=input_points, input_labels=input_labels, return_tensors="pt").to(device)
+
+        with torch.no_grad():
+            # multimask_output=True usually helps with ambiguity, but let's default to best mask for simplicity here
+            outputs = TRK_MODEL(**inputs, multimask_output=False)
+
+        # Post process
+        masks = TRK_PROCESSOR.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"], binarize=True)[0]
+
+        # Overlay mask
+        # masks shape is [1, 1, H, W] for single object tracking
+        final_img = apply_mask_overlay(input_image, masks[0])
+
+        # Draw the visual points on top
+        final_img = draw_points_on_image(final_img, points_state)
+
+        return final_img, points_state, labels_state
+
+    except Exception as e:
+        print(f"Tracker Error: {e}")
+        return input_image, points_state, labels_state
+
+def image_click_handler(image, evt: gr.SelectData, points_state, labels_state):
+    # Wrapper to handle the Gradio select event
+    x, y = evt.index
+    return run_image_click_gpu(image, x, y, points_state, labels_state)
+
 def calc_timeout_duration(vid_file, *args):
     return args[-1] if args else 60
 
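The `[[points_state]]` / `[[labels_state]]` nesting above follows the `[Batch, Point_Group, Point_Idx, Coord]` layout named in the code comment: one image, one point group (a single tracked object), and N accumulated clicks. A standalone sketch of how the structure grows per click:

```python
# Layout per the comment above: [Batch, Point_Group, Point_Idx, Coord].
points_state, labels_state = [], []

for x, y in [(120, 80), (200, 150)]:  # two example foreground clicks
    points_state.append([x, y])
    labels_state.append(1)            # 1 = positive (foreground) click

input_points = [[points_state]]       # batch of 1 image, 1 point group
input_labels = [[labels_state]]
print(input_points)                   # [[[[120, 80], [200, 150]]]]
```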
@@ -219,7 +278,6 @@ def run_video_segmentation(source_vid, text_query, frame_limit, time_limit):
             counter += 1
         video_cap.release()
 
-        # VID_MODEL is already on device in bfloat16
         session = VID_PROCESSOR.init_video_session(video=video_frames, inference_device=device, dtype=torch.bfloat16)
         session = VID_PROCESSOR.add_text_prompt(inference_session=session, text=text_query)
 
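The `counter` / `video_cap.release()` context above belongs to an OpenCV frame-reading loop that caps the frame count before the session is initialized. A minimal sketch of that pattern (the helper name and the BGR→RGB conversion are assumptions; OpenCV decodes frames as BGR):

```python
import cv2

def load_frames(video_path, frame_limit=120):
    """Read up to frame_limit frames, converting BGR -> RGB for the model."""
    video_cap = cv2.VideoCapture(video_path)
    video_frames, counter = [], 0
    while counter < frame_limit:
        ok, frame = video_cap.read()
        if not ok:
            break  # end of stream
        video_frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        counter += 1
    video_cap.release()
    return video_frames
```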
@@ -246,16 +304,15 @@ def run_video_segmentation(source_vid, text_query, frame_limit, time_limit):
     except Exception as e:
         return None, f"Error during video processing: {str(e)}"
 
-# --- GUI ---
 custom_css="""
 #col-container { margin: 0 auto; max-width: 1100px; }
 #main-title h1 { font-size: 2.1em !important; }
 """
 
-with gr.Blocks(css=custom_css, theme=app_theme) as main_interface:
+with gr.Blocks(css=custom_css, theme=app_theme) as demo:
     with gr.Column(elem_id="col-container"):
         gr.Markdown("# **SAM3: Segment Anything Model 3**", elem_id="main-title")
-        gr.Markdown("Segment objects in image or video using **SAM3** (Segment Anything Model 3) with text prompts.")
+        gr.Markdown("Segment objects in image or video using **SAM3** with Text Prompts or Interactive Clicks.")
 
         with gr.Tabs():
             with gr.Tab("Image Segmentation"):
@@ -287,7 +344,7 @@ with gr.Blocks(css=custom_css, theme=app_theme) as main_interface:
                     inputs=[image_input, txt_prompt_img, conf_slider],
                     outputs=[image_result]
                 )
-
+
             with gr.Tab("Video Segmentation"):
                 with gr.Row():
                     with gr.Column():
@@ -320,6 +377,31 @@ with gr.Blocks(css=custom_css, theme=app_theme) as main_interface:
                     inputs=[video_input, txt_prompt_vid, frame_limiter, time_limiter],
                     outputs=[video_result, process_status]
                 )
+
+            with gr.Tab("Image Click Segmentation"):
+                with gr.Row():
+                    with gr.Column(scale=1):
+                        img_click_input = gr.Image(type="pil", label="Input Image (Click points)", interactive=True, height=450)
+
+                        with gr.Row():
+                            img_click_clear = gr.Button("Clear Points & Reset", variant="secondary")
+
+                        st_click_points = gr.State([])
+                        st_click_labels = gr.State([])
+
+                    with gr.Column(scale=1):
+                        img_click_output = gr.Image(type="pil", label="Result Preview", height=450, interactive=False)
+
+                img_click_input.select(
+                    image_click_handler,
+                    inputs=[img_click_input, st_click_points, st_click_labels],
+                    outputs=[img_click_output, st_click_points, st_click_labels]
+                )
+
+                img_click_clear.click(
+                    lambda: (None, [], []),
+                    outputs=[img_click_output, st_click_points, st_click_labels]
+                )
 
 if __name__ == "__main__":
-    main_interface.launch(ssr_mode=False, mcp_server=True, show_error=True)
+    demo.launch(ssr_mode=False, mcp_server=True, show_error=True)
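The new tab's wiring relies on two Gradio features: `gr.Image.select`, which injects the click position through `gr.SelectData.index`, and `gr.State`, which round-trips the accumulated points between events. A stripped-down sketch of the same pattern without the SAM3 models:

```python
import gradio as gr

def on_click(image, evt: gr.SelectData, points):
    # evt.index carries the (x, y) pixel position of the click on the image
    points = (points or []) + [list(evt.index)]
    return f"{len(points)} point(s): {points}", points

with gr.Blocks() as sketch:
    img = gr.Image(type="pil", interactive=True)
    status = gr.Textbox(label="Clicks")
    pts = gr.State([])  # per-session state, reset by the button below
    img.select(on_click, inputs=[img, pts], outputs=[status, pts])
    gr.Button("Clear").click(lambda: ("", []), outputs=[status, pts])

if __name__ == "__main__":
    sketch.launch()
```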
 