mary-mark opened 2 months ago

Hi there,

I am trying to run SamAutomaticMaskGenerator using device = 'mps', but I am getting the following error:

RuntimeError: Attempted to set the storage of a tensor on device "cpu" to a storage on different device "mps:0". This is no longer allowed; the devices must match.

I also tried device = 'cpu', but the same error appears. Is it possible to run this on an M1 Max, and if so, how could I fix this error?
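For reference, this is roughly the setup I'm running (a sketch only: the model type and checkpoint path are placeholders, and I'm assuming segment-anything-fast mirrors upstream's sam_model_registry API):

```python
import torch
from segment_anything_fast import sam_model_registry, SamAutomaticMaskGenerator

# Illustrative setup; "vit_b" and the checkpoint filename are placeholders.
device = "mps" if torch.backends.mps.is_available() else "cpu"
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
sam.to(device=device)

mask_generator = SamAutomaticMaskGenerator(sam)
masks = mask_generator.generate(image)  # image: HxWx3 uint8 numpy array; fails as above
```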
Hello @mary-mark - the code here has been specialized for GPUs, so I haven't tried it for mps yet. It's likely that you need to make some changes to have this work. Do you have the full stacktrace of the error?
Thanks for the reply @cpuhrsch. Here is the error:
Oh interesting. Can you try setting pin_memory=False in automatic_mask_generator.py:291?
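Something along these lines (a sketch only; the exact call at that line may differ, and `points_list` is just a stand-in for whatever the tensor is built from):

```python
# Fragment inside automatic_mask_generator.py. Pinned host memory only
# speeds up CPU->CUDA copies, so disable it when targeting mps or cpu.
nt_in_points = torch.nested.nested_tensor(
    points_list,       # stand-in name for the list the tensor is built from
    pin_memory=False,  # was True; pinning assumes a CUDA device
)
```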
That seemed to work, but now I get the MPS float64 error:
TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
I've tried casting to float32 in automatic_mask_generator.py line 288 (`in_points = torch.as_tensor(transformed_points.astype(np.float32))`), as that worked for me when using segment_anything, but I get the following error:
RuntimeError: unknown device type for autocast in get_autocast_dispatch_key_from_device_type
Is there another place where I can specify to use float32?
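Since numpy defaults to float64 for float arrays, every numpy-to-torch conversion is a place where float64 can sneak back in. The kind of cast I'm applying at each boundary looks like this (a sketch; `as_mps_tensor` is just a hypothetical helper name, not code from the repo):

```python
import numpy as np
import torch

def as_mps_tensor(a: np.ndarray, device: str = "mps") -> torch.Tensor:
    # Hypothetical helper: the MPS backend has no float64 support, so
    # downcast at every numpy -> torch boundary before moving to device.
    if a.dtype == np.float64:
        a = a.astype(np.float32)
    return torch.as_tensor(a, device=device)
```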
Ok thanks for trying. Hm, could you post the full stack trace again please and send a branch with your changes?
This is the trace for the float64 error:

And this is the trace for when I cast to float32, which is the version on this fork: https://github.com/mary-mark/segment-anything-fast.git
RuntimeError Traceback (most recent call last)
Cell In[6], line 1
----> 1 masks = mask_generator.generate(image)
File ~/miniconda3/envs/seg_any_fast/lib/python3.12/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
113 @functools.wraps(func)
114 def decorate_context(*args, **kwargs):
115 with ctx_factory():
--> 116 return func(*args, **kwargs)
File ~/miniconda3/envs/seg_any_fast/lib/python3.12/site-packages/segment_anything_fast/automatic_mask_generator.py:170, in SamAutomaticMaskGenerator.generate(self, image)
145 """
146 Generates masks for the given image.
147
(...)
166 the mask, given in XYWH format.
167 """
169 # Generate masks
--> 170 mask_data = self._generate_masks(image)
172 # Filter small disconnected regions and holes in masks
173 if self.min_mask_region_area > 0:
File ~/miniconda3/envs/seg_any_fast/lib/python3.12/site-packages/segment_anything_fast/automatic_mask_generator.py:213, in SamAutomaticMaskGenerator._generate_masks(self, image)
211 data = MaskData()
212 for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
--> 213 crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
214 data.cat(crop_data)
216 # Remove duplicate masks between crops
File ~/miniconda3/envs/seg_any_fast/lib/python3.12/site-packages/segment_anything_fast/automatic_mask_generator.py:255, in SamAutomaticMaskGenerator._process_crop(self, image, crop_box, crop_layer_idx, orig_size)
253 for i in range(0, len(all_points), process_batch_size):
254 some_points = all_points[i:i+process_batch_size]
--> 255 batch_data = self._process_batch(some_points, cropped_im_size, crop_box, orig_size)
256 data.cat(batch_data)
257 data["rles"] = mask_to_rle_pytorch_2(data["masks"])
File ~/miniconda3/envs/seg_any_fast/lib/python3.12/site-packages/segment_anything_fast/automatic_mask_generator.py:298, in SamAutomaticMaskGenerator._process_batch(self, all_points, im_size, crop_box, orig_size)
296 self.predictor.input_sizes = [self.predictor.input_size for _ in range(len(nt_in_points))]
297 self.predictor.original_sizes = [self.predictor.original_size for _ in range(len(nt_in_points))]
--> 298 nt_masks, nt_iou_preds, _ = self.predictor.predict_torch(
299 point_coords=nt_in_points,
300 point_labels=nt_in_labels,
301 multimask_output=True,
302 return_logits=True,
303 )
305 data = MaskData()
306 for masks, iou_preds, points in zip(nt_masks.unbind(), nt_iou_preds.unbind(), all_points):
File ~/miniconda3/envs/seg_any_fast/lib/python3.12/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
113 @functools.wraps(func)
114 def decorate_context(*args, **kwargs):
115 with ctx_factory():
--> 116 return func(*args, **kwargs)
File ~/miniconda3/envs/seg_any_fast/lib/python3.12/site-packages/segment_anything_fast/predictor.py:230, in SamPredictor.predict_torch(self, point_coords, point_labels, boxes, mask_input, multimask_output, return_logits)
223 sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
224 points=points,
225 boxes=boxes,
226 masks=mask_input,
227 )
229 # Predict masks
--> 230 low_res_masks, iou_predictions = self.model.mask_decoder(
231 image_embeddings=self.features,
232 image_pe=self.model.prompt_encoder.get_dense_pe(),
233 sparse_prompt_embeddings=sparse_embeddings,
234 dense_prompt_embeddings=dense_embeddings,
235 multimask_output=multimask_output,
236 )
238 if low_res_masks.is_nested:
239 masks = []
File ~/miniconda3/envs/seg_any_fast/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/miniconda3/envs/seg_any_fast/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/miniconda3/envs/seg_any_fast/lib/python3.12/site-packages/segment_anything_fast/modeling/mask_decoder.py:99, in MaskDecoder.forward(self, image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, multimask_output)
97 assert dense_prompt_embeddings.is_nested
98 assert multimask_output
---> 99 masks, iou_pred = self.predict_masks_nested(
100 image_embeddings=image_embeddings,
101 image_pe=image_pe,
102 sparse_prompt_embeddings=sparse_prompt_embeddings.to(self_dtype),
103 dense_prompt_embeddings=dense_prompt_embeddings.to(self_dtype),
104 )
105 return masks, iou_pred
106 else:
File ~/miniconda3/envs/seg_any_fast/lib/python3.12/site-packages/segment_anything_fast/modeling/mask_decoder.py:188, in MaskDecoder.predict_masks_nested(self, image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings)
185 h, w = src.shape[-2:]
187 # Run the transformer
--> 188 hs, src = self.transformer(src, pos_src, tokens)
189 iou_token_out = hs[..., 0, :]
190 mask_tokens_out = hs[..., 1 : (1 + self.num_mask_tokens), :]
File ~/miniconda3/envs/seg_any_fast/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/miniconda3/envs/seg_any_fast/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/miniconda3/envs/seg_any_fast/lib/python3.12/site-packages/segment_anything_fast/modeling/transformer.py:91, in TwoWayTransformer.forward(self, image_embedding, image_pe, point_embedding)
89 # Apply transformer blocks and final layernorm
90 for layer in self.layers:
---> 91 queries, keys = layer(
92 queries=queries,
93 keys=keys,
94 query_pe=point_embedding,
95 key_pe=image_pe,
96 )
98 # Apply the final attention layer from the points to the image
99 q = queries + point_embedding
File ~/miniconda3/envs/seg_any_fast/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/miniconda3/envs/seg_any_fast/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/miniconda3/envs/seg_any_fast/lib/python3.12/site-packages/segment_anything_fast/modeling/transformer.py:155, in TwoWayAttentionBlock.forward(self, queries, keys, query_pe, key_pe)
150 def forward(
151 self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
152 ) -> Tuple[Tensor, Tensor]:
153 # Self attention block
154 if self.skip_first_layer_pe:
--> 155 queries = self.self_attn(q=queries, k=queries, v=queries)
156 else:
157 q = queries + query_pe
File ~/miniconda3/envs/seg_any_fast/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/miniconda3/envs/seg_any_fast/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/miniconda3/envs/seg_any_fast/lib/python3.12/site-packages/segment_anything_fast/modeling/transformer.py:227, in Attention.forward(self, q, k, v)
224 v = self._separate_heads(v, self.num_heads)
226 # Attention
--> 227 out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
229 # Get output
230 out = self._recombine_heads(out)
File ~/miniconda3/envs/seg_any_fast/lib/python3.12/site-packages/torch/nested/_internal/nested_tensor.py:310, in NestedTensor.__torch_function__(cls, func, types, args, kwargs)
308 with maybe_enable_thunkify():
309 try:
--> 310 return jagged_torch_function(func, *args, **kwargs)
311 except NotImplementedError:
312 pass
File ~/miniconda3/envs/seg_any_fast/lib/python3.12/site-packages/torch/nested/_internal/ops.py:315, in jagged_torch_function(func, *args, **kwargs)
311 def jagged_torch_function(func, *args, **kwargs):
312 # SDPA has special kernels that handle nested tensors.
313 # Dispatch to the correct implementation here
314 if func is torch._C._nn.scaled_dot_product_attention:
--> 315 return jagged_scaled_dot_product_attention(*args, **kwargs)
317 if func.__name__ == "apply_":
318 func(args[0]._values, *args[1:], **kwargs)
File ~/miniconda3/envs/seg_any_fast/lib/python3.12/site-packages/torch/nested/_internal/sdpa.py:688, in jagged_scaled_dot_product_attention(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa)
678 def jagged_scaled_dot_product_attention(
679 query: torch.Tensor,
680 key: torch.Tensor,
(...)
686 enable_gqa=False,
687 ):
--> 688 query, key, value, attn_mask = _autocast(query, key, value, attn_mask)
689 _validate_sdpa_input(query, key, value, attn_mask, dropout_p, is_causal, scale)
690 # for mypy, ugh
File ~/miniconda3/envs/seg_any_fast/lib/python3.12/site-packages/torch/nested/_internal/sdpa.py:660, in _autocast(query, key, value, attn_mask)
658 device_type = query.device.type
659 # meta device is not supported by autocast, so break early for it
--> 660 if _is_computing_meta_flops(query) or not torch.is_autocast_enabled(device_type):
661 return query, key, value, attn_mask
663 def cvt(x):
RuntimeError: unknown device type for autocast in get_autocast_dispatch_key_from_device_type
Ok, I see. Right, so this would require NestedTensors to work on MPS. You'd have to turn off more and more of segment-anything-fast's performance features to get it running on an M1 Max, and I'm not sure the project will still be useful to you after turning all of that off. We really focus on GPUs in this repository.
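If you mainly need MPS support, one pragmatic option is to fall back to the upstream segment_anything package on non-CUDA devices, since you mentioned it already worked for you there with the float32 cast. Roughly (a sketch, assuming segment-anything-fast keeps upstream's registry API; the model type and checkpoint path are illustrative):

```python
import torch

if torch.cuda.is_available():
    # segment-anything-fast's nested-tensor / autocast paths assume CUDA.
    from segment_anything_fast import sam_model_registry, SamAutomaticMaskGenerator
    device = "cuda"
else:
    # The reference implementation runs on mps/cpu once inputs are float32.
    from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
    device = "mps" if torch.backends.mps.is_available() else "cpu"

sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")  # path illustrative
sam.to(device=device)
mask_generator = SamAutomaticMaskGenerator(sam)
```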