royerlab / napari-segment-anything

Segment Anything Model (SAM) native Qt UI
Apache License 2.0
189 stars 17 forks source link

Support for float32 #27

Open phisanti opened 3 months ago

phisanti commented 3 months ago

I am working on a MacBook with an M2 chip, and I get the following error when using the plugin:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
File ~/miniconda3/envs/img_analysis/lib/python3.11/site-packages/psygnal/_signal.py:1048, in SignalInstance._run_emit_loop(self=<SignalInstance 'changed' on PushButton(value=False, annotation=None, name='')>, args=(False,))
   1046     with Signal._emitting(self):
   1047         # allow receiver to query sender with Signal.current_emitter()
-> 1048         self._run_emit_loop_inner()
        self = <SignalInstance 'changed' on PushButton(value=False, annotation=None, name='')>
        self._run_emit_loop_inner = <bound method SignalInstance._run_emit_loop_immediate of <SignalInstance 'changed' on PushButton(value=False, annotation=None, name='')>>
   1049 except RecursionError as e:

File ~/miniconda3/envs/img_analysis/lib/python3.11/site-packages/psygnal/_signal.py:1067, in SignalInstance._run_emit_loop_immediate(self=<SignalInstance 'changed' on PushButton(value=False, annotation=None, name='')>)
   1066 self._caller = caller
-> 1067 caller.cb(args)
        args = (False,)
        caller = <WeakMethod on napari_segment_anything._widget.SAMWidget._on_auto_run>

File ~/miniconda3/envs/img_analysis/lib/python3.11/site-packages/psygnal/_weak_callback.py:453, in WeakMethod.cb(self=<WeakMethod on napari_segment_anything._widget.SAMWidget._on_auto_run>, args=())
    452     args = args[: self._max_args]
--> 453 func(obj, *self._args, *args, **self._kwargs)
        obj = <Container ()>
        func = <function SAMWidget._on_auto_run at 0x30c869b20>
        args = ()
        self = <WeakMethod on napari_segment_anything._widget.SAMWidget._on_auto_run>
        self._args = ()
        self._kwargs = {}

File ~/miniconda3/envs/img_analysis/lib/python3.11/site-packages/napari_segment_anything/_widget.py:192, in SAMWidget._on_auto_run(self=<Container ()>)
    191 mask_gen = SamAutomaticMaskGenerator(self._sam)
--> 192 preds = mask_gen.generate(self._image)
        mask_gen = <segment_anything.automatic_mask_generator.SamAutomaticMaskGenerator object at 0x34318d410>
        self._image = <class 'numpy.ndarray'> (2400, 2400, 3) uint8
        self = <Container ()>
    194 labels = self._labels_layer.data

File ~/miniconda3/envs/img_analysis/lib/python3.11/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args=(<segment_anything.automatic_mask_generator.SamAutomaticMaskGenerator object>, <class 'numpy.ndarray'> (2400, 2400, 3) uint8), **kwargs={})
    114 with ctx_factory():
--> 115     return func(*args, **kwargs)
        func = <function SamAutomaticMaskGenerator.generate at 0x30c8691c0>
        args = (<segment_anything.automatic_mask_generator.SamAutomaticMaskGenerator object at 0x34318d410>, <class 'numpy.ndarray'> (2400, 2400, 3) uint8)
        kwargs = {}

File ~/miniconda3/envs/img_analysis/lib/python3.11/site-packages/segment_anything/automatic_mask_generator.py:163, in SamAutomaticMaskGenerator.generate(self=<segment_anything.automatic_mask_generator.SamAutomaticMaskGenerator object>, image=<class 'numpy.ndarray'> (2400, 2400, 3) uint8)
    162 # Generate masks
--> 163 mask_data = self._generate_masks(image)
        image = <class 'numpy.ndarray'> (2400, 2400, 3) uint8
        self = <segment_anything.automatic_mask_generator.SamAutomaticMaskGenerator object at 0x34318d410>
    165 # Filter small disconnected regions and holes in masks

File ~/miniconda3/envs/img_analysis/lib/python3.11/site-packages/segment_anything/automatic_mask_generator.py:206, in SamAutomaticMaskGenerator._generate_masks(self=<segment_anything.automatic_mask_generator.SamAutomaticMaskGenerator object>, image=<class 'numpy.ndarray'> (2400, 2400, 3) uint8)
    205 for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
--> 206     crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
        orig_size = (2400, 2400)
        crop_box = [0, 0, 2400, 2400]
        layer_idx = 0
        image = <class 'numpy.ndarray'> (2400, 2400, 3) uint8
        self = <segment_anything.automatic_mask_generator.SamAutomaticMaskGenerator object at 0x34318d410>
    207     data.cat(crop_data)

File ~/miniconda3/envs/img_analysis/lib/python3.11/site-packages/segment_anything/automatic_mask_generator.py:245, in SamAutomaticMaskGenerator._process_crop(self=<segment_anything.automatic_mask_generator.SamAutomaticMaskGenerator object>, image=<class 'numpy.ndarray'> (2400, 2400, 3) uint8, crop_box=[0, 0, 2400, 2400], crop_layer_idx=0, orig_size=(2400, 2400))
    244 for (points,) in batch_iterator(self.points_per_batch, points_for_image):
--> 245     batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
        crop_box = [0, 0, 2400, 2400]
        cropped_im_size = (2400, 2400)
        points = <class 'numpy.ndarray'> (64, 2) float64
        self = <segment_anything.automatic_mask_generator.SamAutomaticMaskGenerator object at 0x34318d410>
        orig_size = (2400, 2400)
    246     data.cat(batch_data)

File ~/miniconda3/envs/img_analysis/lib/python3.11/site-packages/segment_anything/automatic_mask_generator.py:277, in SamAutomaticMaskGenerator._process_batch(self=<segment_anything.automatic_mask_generator.SamAutomaticMaskGenerator object>, points=<class 'numpy.ndarray'> (64, 2) float64, im_size=(2400, 2400), crop_box=[0, 0, 2400, 2400], orig_size=(2400, 2400))
    276 transformed_points = self.predictor.transform.apply_coords(points, im_size)
--> 277 in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
        transformed_points = <class 'numpy.ndarray'> (64, 2) float64
        self.predictor = <segment_anything.predictor.SamPredictor object at 0x34318d890>
        self = <segment_anything.automatic_mask_generator.SamAutomaticMaskGenerator object at 0x34318d410>
    278 in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)

TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

The above exception was the direct cause of the following exception:

EmitLoopError                             Traceback (most recent call last)
File ~/miniconda3/envs/img_analysis/lib/python3.11/site-packages/magicgui/widgets/bases/_value_widget.py:71, in ValueWidget._on_value_change(self=PushButton(value=False, annotation=None, name=''), value=False)
     69 if value is self.null_value and not self._nullable:
     70     return
---> 71 self.changed.emit(value)
        value = False
        self.changed = <SignalInstance 'changed' on PushButton(value=False, annotation=None, name='')>
        self = PushButton(value=False, annotation=None, name='')

File ~/miniconda3/envs/img_analysis/lib/python3.11/site-packages/psygnal/_signal.py:1025, in SignalInstance.emit(self=<SignalInstance 'changed' on PushButton(value=False, annotation=None, name='')>, check_nargs=False, check_types=False, *args=(False,))
   1021     from ._group import EmissionInfo
   1023     SignalInstance._debug_hook(EmissionInfo(self, args))
-> 1025 self._run_emit_loop(args)
        self = <SignalInstance 'changed' on PushButton(value=False, annotation=None, name='')>
        args = (False,)

File ~/miniconda3/envs/img_analysis/lib/python3.11/site-packages/psygnal/_signal.py:1055, in SignalInstance._run_emit_loop(self=<SignalInstance 'changed' on PushButton(value=False, annotation=None, name='')>, args=(False,))
   1050     raise RecursionError(
   1051         f"RecursionError when "
   1052         f"emitting signal {self.name!r} with args {args}"
   1053     ) from e
   1054 except Exception as e:
-> 1055     raise EmitLoopError(
        self = <SignalInstance 'changed' on PushButton(value=False, annotation=None, name='')>
        self._args = ()
        self._caller = None
   1056         cb=self._caller, args=self._args, exc=e, signal=self
   1057     ) from e
   1058 finally:
   1059     self._emit_queue.clear()

EmitLoopError: 
While emitting signal 'magicgui.widgets.PushButton.changed', an error occurred in callback 'napari_segment_anything._widget.SAMWidget._on_auto_run'.
The args passed to the callback were: (False,)
This is not a bug in psygnal.  See 'TypeError' above for details.

I think this error might be solved if float32 were supported. Please let me know whether this would be possible.

JoOkuma commented 3 months ago

Hi @phisanti,

This is a limitation of segment-anything from Meta. There is an open PR that adds Apple Silicon support, facebookresearch/segment-anything#122, but it hasn't been merged yet.

You can install the patched version (the `mps-support` branch of that fork) from git with:

pip install git+https://github.com/DrSleep/segment-anything@mps-support

Let me know if it works.