robustsam / RobustSAM

RobustSAM: Segment Anything Robustly on Degraded Images (CVPR 2024 Highlight)
https://robustsam.github.io/
MIT License

AutomaticMaskGenerator not working #1

Closed: imneonizer closed this issue 2 months ago

imneonizer commented 2 months ago

Here is the driver code:

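import torch
from robust_segment_anything import SamAutomaticMaskGenerator, sam_model_registry

device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the ViT-L RobustSAM model from the released checkpoint
sam = sam_model_registry["vit_l"](opt=None, checkpoint="robustsam_checkpoint.pth").to(device=device)
mask_generator = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=32,                # 32 x 32 grid of point prompts on the full image
    pred_iou_thresh=0.86,
    stability_score_thresh=0.5,
    crop_n_layers=1,                   # also run on one extra layer of image crops
    crop_n_points_downscale_factor=2,  # fewer points per side on the crop layer
    min_mask_region_area=100,
)

masks = mask_generator.generate(image)  # image is a numpy array with shape (2048, 2048, 3)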
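
These settings make generate() do considerably more work than a single prompt pass. The sketch below counts it, assuming RobustSAM's automatic_mask_generator.py keeps upstream SAM's crop and point-grid layout (crop layer i is split into 2**i x 2**i crops, each using points_per_side / crop_n_points_downscale_factor**i points per side) and upstream's default points_per_batch=64. Neither assumption is confirmed in this thread, and amg_workload is a hypothetical helper.

```python
import math


def amg_workload(points_per_side, crop_n_layers, downscale, points_per_batch=64):
    """Count crops, point prompts, and decoder batches for one generate() call.

    Assumes upstream SAM's layout: layer 0 is the whole image, crop layer i has
    2**i x 2**i crops, and each crop in layer i uses
    points_per_side / downscale**i points per side.
    """
    crops = prompts = batches = 0
    for layer in range(crop_n_layers + 1):
        n_crops = (2 ** layer) ** 2
        side = int(points_per_side / (downscale ** layer))
        per_crop = side * side
        crops += n_crops                      # one image-encoder pass per crop
        prompts += n_crops * per_crop
        # each batch of prompts is one predict_torch() call on the mask decoder
        batches += n_crops * math.ceil(per_crop / points_per_batch)
    return crops, prompts, batches


print(amg_workload(32, 1, 2))  # settings from the driver code -> (5, 2048, 32)
```

Under these assumptions the encoder runs five times and the mask decoder 32 times per image. Because the missing argument breaks every decoder call, the very first batch is enough to raise the error.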

Error:

  0%|          | 0/1 [00:01<?, ?it/s]
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/tmp/ipykernel_617/4150095890.py in <module>
     33 
     34 for (x1,y1,x2,y2) in tqdm.tqdm(calculate_slice_bboxes(image.shape[0], image.shape[1], 2048, 2048)):
---> 35     reconstructed_image[y1:y2, x1:x2] = process_window(image[y1:y2, x1:x2])

/tmp/ipykernel_617/1729186085.py in process_window(image)
      2     # generate masks
      3     final_mask = image.copy()
----> 4     masks = mask_generator.generate(image)
      5     masks = [x for x in masks if x['area'] > 1000]
      6     # masks = [x for x in masks if x['area'] > 100]

/opt/conda/lib/python3.8/site-packages/torch/utils/_contextlib.py in decorate_context(*args, **kwargs)
    113     def decorate_context(*args, **kwargs):
    114         with ctx_factory():
--> 115             return func(*args, **kwargs)
    116 
    117     return decorate_context

/RobustSAM/robust_segment_anything/automatic_mask_generator.py in generate(self, image)
    161 
    162         # Generate masks
--> 163         mask_data = self._generate_masks(image)
    164 
    165         # Filter small disconnected regions and holes in masks

/RobustSAM/robust_segment_anything/automatic_mask_generator.py in _generate_masks(self, image)
    205         data = MaskData()
    206         for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
--> 207             crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
    208             data.cat(crop_data)
    209 

/RobustSAM/robust_segment_anything/automatic_mask_generator.py in _process_crop(self, image, crop_box, crop_layer_idx, orig_size)
    246         count = 0
    247         for (points,) in batch_iterator(self.points_per_batch, points_for_image):
--> 248             batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
    249             # print('Second: ', mask_logits.shape)
    250 

/RobustSAM/robust_segment_anything/automatic_mask_generator.py in _process_batch(self, points, im_size, crop_box, orig_size)
    285         in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
    286         in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
--> 287         masks, iou_preds, _ = self.predictor.predict_torch(
    288             in_points[:, None, :],
    289             in_labels[:, None],

/RobustSAM/robust_segment_anything/predictor.py in predict_torch(self, point_coords, point_labels, boxes, mask_input, multimask_output, return_logits)
    228 
    229         # Predict masks
--> 230         low_res_masks, iou_predictions = self.model.mask_decoder(
    231             image_embeddings=self.features,
    232             image_pe=self.model.prompt_encoder.get_dense_pe(),

/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1499                 or _global_backward_pre_hooks or _global_backward_hooks
   1500                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501             return forward_call(*args, **kwargs)
   1502         # Do not call functions when jit is used
   1503         full_backward_hooks, non_full_backward_hooks = [], []

TypeError: forward() missing 1 required positional argument: 'encoder_features'
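
The traceback shows generate() being called from a tiling loop (calculate_slice_bboxes and process_window) whose source is not included in the issue; the progress bar reads 0/1, so the 2048 x 2048 image fits in a single window here. A hypothetical sketch of such a loop follows. slice_bboxes and process_tiled are illustrative names, not the user's actual helpers, and the non-overlapping tiling is an assumption.

```python
import numpy as np


def slice_bboxes(height, width, tile_h, tile_w):
    """Hypothetical tiler: non-overlapping (x1, y1, x2, y2) boxes covering the image.

    Edge tiles are clipped to the image, so they may be smaller than tile_h x tile_w.
    """
    boxes = []
    for y1 in range(0, height, tile_h):
        for x1 in range(0, width, tile_w):
            boxes.append((x1, y1, min(x1 + tile_w, width), min(y1 + tile_h, height)))
    return boxes


def process_tiled(image, process_window, tile=2048):
    """Run process_window on each tile and paste its output back in place.

    process_window must return an array with the same shape as its input window.
    """
    out = np.zeros_like(image)
    for (x1, y1, x2, y2) in slice_bboxes(image.shape[0], image.shape[1], tile, tile):
        out[y1:y2, x1:x2] = process_window(image[y1:y2, x1:x2])
    return out
```

A production pipeline would usually overlap neighbouring tiles so that objects cut by a tile edge can be merged afterwards; this sketch omits that.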

On further investigation, I found that https://github.com/robustsam/RobustSAM/blob/main/robust_segment_anything/predictor.py#L230-L236 does not pass encoder_features, which is a required parameter of the mask decoder (https://github.com/robustsam/RobustSAM/blob/main/robust_segment_anything/modeling/mask_decoder.py#L98C9-L117). Interestingly, that parameter is also marked with a #TODO.
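
The mismatch described above can be reproduced without torch. In this minimal sketch, ToyModule and ToyMaskDecoder are hypothetical stand-ins: calling the module forwards every argument to forward(), the way torch.nn.Module does in the _call_impl frame of the traceback, so a caller that omits a required argument gets a TypeError raised from forward() itself. The string values are placeholders for tensors.

```python
class ToyModule:
    """Hypothetical stand-in for torch.nn.Module: calling the object forwards to forward()."""

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


class ToyMaskDecoder(ToyModule):
    """Hypothetical decoder whose forward() requires encoder_features, like RobustSAM's."""

    def forward(self, image_embeddings, image_pe, encoder_features):
        # A real decoder would fuse these tensors; here we just echo them back.
        return image_embeddings, image_pe, encoder_features


decoder = ToyMaskDecoder()

# Old call pattern: encoder_features omitted, so forward() raises TypeError
try:
    decoder(image_embeddings="emb", image_pe="pe")
except TypeError as err:
    print(err)

# Fixed call pattern: the caller supplies the extra argument
print(decoder(image_embeddings="emb", image_pe="pe", encoder_features="feats"))
```

The exact wording of the printed message varies between Python versions, but it always names the missing encoder_features argument.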

robustsam commented 2 months ago

Thanks for your interest! predictor.py has just been updated; please try again. You can run automatic mask generation with the code below:

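import cv2
import numpy as np
from robust_segment_anything import SamAutomaticMaskGenerator, sam_model_registry

model = sam_model_registry["vit_l"](opt=None, checkpoint="robustsam_checkpoint.pth")  # as in the driver code above
mask_generator = SamAutomaticMaskGenerator(model)
image = cv2.cvtColor(cv2.imread('demo_images/blur.jpg'), cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; the generator expects RGB
masks = mask_generator.generate(image)

for i in range(len(masks)):  # save each mask as an 8-bit black-and-white image
    mask = (masks[i]['segmentation'] * 255).astype(np.uint8)
    cv2.imwrite('mask_{}.jpg'.format(str(i).zfill(2)), mask)
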
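
Each element returned by generate() is a dict; this thread relies on its 'segmentation' (boolean HxW mask) and 'area' keys. The following is a minimal sketch, assuming only those two keys, that drops small masks (mirroring the x['area'] > 1000 filter in the traceback) and flattens the rest into a single integer label map. masks_to_label_map is a hypothetical helper, not part of RobustSAM.

```python
import numpy as np


def masks_to_label_map(masks, shape, min_area=1000):
    """Combine generator output into one int32 label map (0 = background).

    Masks with area <= min_area are dropped. Larger masks are painted first,
    so smaller masks drawn later stay visible where they overlap.
    """
    kept = [m for m in masks if m['area'] > min_area]
    kept.sort(key=lambda m: m['area'], reverse=True)
    labels = np.zeros(shape, dtype=np.int32)
    for label, m in enumerate(kept, start=1):
        labels[m['segmentation']] = label
    return labels
```

For the blur.jpg example, masks_to_label_map(masks, image.shape[:2]) would give one array that can be saved or colour-mapped in a single step instead of one file per mask.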
imneonizer commented 2 months ago

Thanks, the issue is resolved now. However, I found that the model is not very useful for my use case, field delineation on Sentinel-2 images at 10 m resolution, because those images are too blurry. Do you plan to release code for fine-tuning the model on such use cases?