pytorch-labs / segment-anything-fast

A batched offline inference oriented version of segment-anything
Apache License 2.0
1.21k stars, 72 forks

Trouble running mask generation on M1 Max #129

Open: mary-mark opened this issue 2 months ago

mary-mark commented 2 months ago

Hi there,

I am trying to run SamAutomaticMaskGenerator with device = 'mps', but I get the following error: RuntimeError: Attempted to set the storage of a tensor on device "cpu" to a storage on different device "mps:0". This is no longer allowed; the devices must match.

I also tried device = 'cpu', but the same error appears. Is it possible to run this on an M1 Max, and if so, how can I fix this error?
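For reference, a device is usually picked with a CUDA, then MPS, then CPU fallback. The sketch below assumes a hypothetical `pick_device` helper; it only chooses a device and, as the rest of this thread shows, does not by itself make segment-anything-fast's GPU-specific code paths work on MPS.

```python
import torch

def pick_device() -> torch.device:
    """Hypothetical helper: prefer CUDA, then Apple's MPS backend, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # True on Apple-silicon Macs running an MPS-enabled PyTorch build.
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

device = pick_device()
print(device)
```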

cpuhrsch commented 2 months ago

Hello @mary-mark, the code here has been specialized for GPUs, so I haven't tried it on MPS yet. You will likely need to make some changes to get this to work. Do you have the full stack trace of the error?

mary-mark commented 2 months ago

Thanks for the reply @cpuhrsch. Here is the error: [Screenshot 2024-09-04 at 7.00.40 PM]

cpuhrsch commented 2 months ago

Oh interesting. Can you try setting pin_memory=False in automatic_mask_generator.py:291?
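Pinned (page-locked) host memory exists to speed up host-to-CUDA copies, which is why pinning a tensor that is headed for MPS or CPU is a plausible source of the device-mismatch error. A minimal sketch of guarding that step, assuming a hypothetical helper `points_to_device` (the actual code at automatic_mask_generator.py:291 is not shown in this thread):

```python
import numpy as np
import torch

def points_to_device(points: np.ndarray, device: torch.device) -> torch.Tensor:
    """Hypothetical helper: move prompt points to `device`, pinning only for CUDA."""
    t = torch.as_tensor(points)
    if device.type == "cuda":
        # Pinned host memory allows an asynchronous (non_blocking) host-to-GPU copy.
        return t.pin_memory().to(device, non_blocking=True)
    # On MPS or CPU, skip pin_memory() and do a plain copy.
    return t.to(device)

pts = points_to_device(np.zeros((4, 2), dtype=np.float32), torch.device("cpu"))
```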

mary-mark commented 2 months ago

That seemed to work, but now I get the MPS float64 error:

TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

I've tried casting to float32 at line 288 of automatic_mask_generator.py (in_points = torch.as_tensor(transformed_points.astype(np.float32))), since that worked for me with segment_anything, but then I get the following error:

RuntimeError: unknown device type for autocast in get_autocast_dispatch_key_from_device_type

Is there another place where I can specify to use float32?
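NumPy arrays default to float64, and `torch.as_tensor` keeps the input dtype, so any array built without an explicit dtype will trip the MPS float64 check once it is moved to the device. A minimal sketch of forcing float32 at the NumPy-to-torch boundary (the sample values are made up):

```python
import numpy as np
import torch

# NumPy creates float64 arrays by default.
transformed_points = np.array([[12.5, 40.0], [300.0, 7.25]])
assert transformed_points.dtype == np.float64

# Option 1: cast the NumPy array first (the change described above for line 288).
a = torch.as_tensor(transformed_points.astype(np.float32))
# Option 2: let torch cast during conversion.
b = torch.as_tensor(transformed_points, dtype=torch.float32)

print(a.dtype, b.dtype)  # torch.float32 torch.float32
```

This cast only addresses the dtype check; the follow-up autocast error comes from the nested-tensor attention path, as the trace below shows.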

cpuhrsch commented 2 months ago

OK, thanks for trying. Hm, could you post the full stack trace again, please, and send a branch with your changes?

mary-mark commented 2 months ago

This is the trace for the float64 error: [Screenshot 2024-09-10 at 3.09.02 PM]

And this is the trace after casting to float32, which is the version on this fork: https://github.com/mary-mark/segment-anything-fast.git

RuntimeError                              Traceback (most recent call last)
Cell In[6], line 1
----> 1 masks = mask_generator.generate(image)

File ~/miniconda3/envs/seg_any_fast/lib/python3.12/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    113 @functools.wraps(func)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)

File ~/miniconda3/envs/seg_any_fast/lib/python3.12/site-packages/segment_anything_fast/automatic_mask_generator.py:170, in SamAutomaticMaskGenerator.generate(self, image)
    145 """
    146 Generates masks for the given image.
    147 
   (...)
    166          the mask, given in XYWH format.
    167 """
    169 # Generate masks
--> 170 mask_data = self._generate_masks(image)
    172 # Filter small disconnected regions and holes in masks
    173 if self.min_mask_region_area > 0:

File ~/miniconda3/envs/seg_any_fast/lib/python3.12/site-packages/segment_anything_fast/automatic_mask_generator.py:213, in SamAutomaticMaskGenerator._generate_masks(self, image)
    211 data = MaskData()
    212 for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
--> 213     crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
    214     data.cat(crop_data)
    216 # Remove duplicate masks between crops

File ~/miniconda3/envs/seg_any_fast/lib/python3.12/site-packages/segment_anything_fast/automatic_mask_generator.py:255, in SamAutomaticMaskGenerator._process_crop(self, image, crop_box, crop_layer_idx, orig_size)
    253 for i in range(0, len(all_points), process_batch_size):
    254     some_points = all_points[i:i+process_batch_size]
--> 255     batch_data = self._process_batch(some_points, cropped_im_size, crop_box, orig_size)
    256     data.cat(batch_data)
    257 data["rles"] = mask_to_rle_pytorch_2(data["masks"])

File ~/miniconda3/envs/seg_any_fast/lib/python3.12/site-packages/segment_anything_fast/automatic_mask_generator.py:298, in SamAutomaticMaskGenerator._process_batch(self, all_points, im_size, crop_box, orig_size)
    296 self.predictor.input_sizes = [self.predictor.input_size for _ in range(len(nt_in_points))]
    297 self.predictor.original_sizes = [self.predictor.original_size for _ in range(len(nt_in_points))]
--> 298 nt_masks, nt_iou_preds, _ = self.predictor.predict_torch(
    299     point_coords=nt_in_points,
    300     point_labels=nt_in_labels,
    301     multimask_output=True,
    302     return_logits=True,
    303 )
    305 data = MaskData()
    306 for masks, iou_preds, points in zip(nt_masks.unbind(), nt_iou_preds.unbind(), all_points):

File ~/miniconda3/envs/seg_any_fast/lib/python3.12/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    113 @functools.wraps(func)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)

File ~/miniconda3/envs/seg_any_fast/lib/python3.12/site-packages/segment_anything_fast/predictor.py:230, in SamPredictor.predict_torch(self, point_coords, point_labels, boxes, mask_input, multimask_output, return_logits)
    223 sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
    224     points=points,
    225     boxes=boxes,
    226     masks=mask_input,
    227 )
    229 # Predict masks
--> 230 low_res_masks, iou_predictions = self.model.mask_decoder(
    231     image_embeddings=self.features,
    232     image_pe=self.model.prompt_encoder.get_dense_pe(),
    233     sparse_prompt_embeddings=sparse_embeddings,
    234     dense_prompt_embeddings=dense_embeddings,
    235     multimask_output=multimask_output,
    236 )
    238 if low_res_masks.is_nested:
    239     masks = []

File ~/miniconda3/envs/seg_any_fast/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/seg_any_fast/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/miniconda3/envs/seg_any_fast/lib/python3.12/site-packages/segment_anything_fast/modeling/mask_decoder.py:99, in MaskDecoder.forward(self, image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, multimask_output)
     97     assert dense_prompt_embeddings.is_nested
     98     assert multimask_output
---> 99     masks, iou_pred = self.predict_masks_nested(
    100         image_embeddings=image_embeddings,
    101         image_pe=image_pe,
    102         sparse_prompt_embeddings=sparse_prompt_embeddings.to(self_dtype),
    103         dense_prompt_embeddings=dense_prompt_embeddings.to(self_dtype),
    104     )
    105     return masks, iou_pred
    106 else:

File ~/miniconda3/envs/seg_any_fast/lib/python3.12/site-packages/segment_anything_fast/modeling/mask_decoder.py:188, in MaskDecoder.predict_masks_nested(self, image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings)
    185 h, w = src.shape[-2:]
    187 # Run the transformer
--> 188 hs, src = self.transformer(src, pos_src, tokens)
    189 iou_token_out = hs[..., 0, :]
    190 mask_tokens_out = hs[..., 1 : (1 + self.num_mask_tokens), :]

File ~/miniconda3/envs/seg_any_fast/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/seg_any_fast/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/miniconda3/envs/seg_any_fast/lib/python3.12/site-packages/segment_anything_fast/modeling/transformer.py:91, in TwoWayTransformer.forward(self, image_embedding, image_pe, point_embedding)
     89 # Apply transformer blocks and final layernorm
     90 for layer in self.layers:
---> 91     queries, keys = layer(
     92         queries=queries,
     93         keys=keys,
     94         query_pe=point_embedding,
     95         key_pe=image_pe,
     96     )
     98 # Apply the final attention layer from the points to the image
     99 q = queries + point_embedding

File ~/miniconda3/envs/seg_any_fast/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/seg_any_fast/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/miniconda3/envs/seg_any_fast/lib/python3.12/site-packages/segment_anything_fast/modeling/transformer.py:155, in TwoWayAttentionBlock.forward(self, queries, keys, query_pe, key_pe)
    150 def forward(
    151     self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
    152 ) -> Tuple[Tensor, Tensor]:
    153     # Self attention block
    154     if self.skip_first_layer_pe:
--> 155         queries = self.self_attn(q=queries, k=queries, v=queries)
    156     else:
    157         q = queries + query_pe

File ~/miniconda3/envs/seg_any_fast/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
   1734     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735 else:
-> 1736     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/seg_any_fast/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
   1742 # If we don't have any hooks, we want to skip the rest of the logic in
   1743 # this function, and just call forward.
   1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1745         or _global_backward_pre_hooks or _global_backward_hooks
   1746         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747     return forward_call(*args, **kwargs)
   1749 result = None
   1750 called_always_called_hooks = set()

File ~/miniconda3/envs/seg_any_fast/lib/python3.12/site-packages/segment_anything_fast/modeling/transformer.py:227, in Attention.forward(self, q, k, v)
    224 v = self._separate_heads(v, self.num_heads)
    226 # Attention
--> 227 out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    229 # Get output
    230 out = self._recombine_heads(out)

File ~/miniconda3/envs/seg_any_fast/lib/python3.12/site-packages/torch/nested/_internal/nested_tensor.py:310, in NestedTensor.__torch_function__(cls, func, types, args, kwargs)
    308 with maybe_enable_thunkify():
    309     try:
--> 310         return jagged_torch_function(func, *args, **kwargs)
    311     except NotImplementedError:
    312         pass

File ~/miniconda3/envs/seg_any_fast/lib/python3.12/site-packages/torch/nested/_internal/ops.py:315, in jagged_torch_function(func, *args, **kwargs)
    311 def jagged_torch_function(func, *args, **kwargs):
    312     # SDPA has special kernels that handle nested tensors.
    313     # Dispatch to the correct implementation here
    314     if func is torch._C._nn.scaled_dot_product_attention:
--> 315         return jagged_scaled_dot_product_attention(*args, **kwargs)
    317     if func.__name__ == "apply_":
    318         func(args[0]._values, *args[1:], **kwargs)

File ~/miniconda3/envs/seg_any_fast/lib/python3.12/site-packages/torch/nested/_internal/sdpa.py:688, in jagged_scaled_dot_product_attention(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa)
    678 def jagged_scaled_dot_product_attention(
    679     query: torch.Tensor,
    680     key: torch.Tensor,
   (...)
    686     enable_gqa=False,
    687 ):
--> 688     query, key, value, attn_mask = _autocast(query, key, value, attn_mask)
    689     _validate_sdpa_input(query, key, value, attn_mask, dropout_p, is_causal, scale)
    690     # for mypy, ugh

File ~/miniconda3/envs/seg_any_fast/lib/python3.12/site-packages/torch/nested/_internal/sdpa.py:660, in _autocast(query, key, value, attn_mask)
    658 device_type = query.device.type
    659 # meta device is not supported by autocast, so break early for it
--> 660 if _is_computing_meta_flops(query) or not torch.is_autocast_enabled(device_type):
    661     return query, key, value, attn_mask
    663 def cvt(x):

RuntimeError: unknown device type for autocast in get_autocast_dispatch_key_from_device_type

cpuhrsch commented 2 months ago

OK, I see. Right, so this would require NestedTensors (here, their scaled_dot_product_attention path) to work on M1 Max. You'd have to turn off more and more of segment-anything-fast's performance features to get this running on M1 Max, and I'm not sure the project would still be useful to you after turning all of that off. This repository really focuses on GPUs.
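The trace ends inside PyTorch's jagged NestedTensor implementation of `scaled_dot_product_attention`, so one of the features you would have to turn off is the nested (ragged) batching itself. One possible substitution, sketched here under the assumption that you are willing to edit the attention call yourself, is to pad the variable-length sequences into a dense tensor and mask out the padding. `padded_self_attention` is a hypothetical helper, not part of segment-anything-fast; it covers only the self-attention case (q = k = v, as in the failing `self_attn` call) and gives up the speed benefits NestedTensors provide.

```python
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

def padded_self_attention(seqs: list[torch.Tensor]) -> list[torch.Tensor]:
    """Hypothetical dense fallback for self-attention over a ragged batch.

    Each element of `seqs` has shape (L_i, E); outputs have the same shapes.
    """
    lengths = torch.tensor([s.shape[0] for s in seqs])
    x = pad_sequence(seqs, batch_first=True)  # (B, L_max, E), zero-padded
    l_max = x.shape[1]
    # key_mask[b, j] is True where key j is a real (non-padding) token of sequence b.
    key_mask = torch.arange(l_max)[None, :] < lengths[:, None]  # (B, L_max)
    # Boolean attn_mask: True means "may attend". (B, 1, L_max) broadcasts over queries.
    out = F.scaled_dot_product_attention(x, x, x, attn_mask=key_mask[:, None, :])
    # Drop the output rows that belong to padded query positions.
    return [out[b, :n] for b, n in enumerate(lengths.tolist())]

torch.manual_seed(0)
batch = [torch.randn(n, 8) for n in (3, 5, 1)]
outs = padded_self_attention(batch)
print([tuple(o.shape) for o in outs])  # [(3, 8), (5, 8), (1, 8)]
```

Every sequence has at least one real token, so no attention row is fully masked and the softmax never produces NaNs.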