computational-cell-analytics / micro-sam

Segment Anything for Microscopy
https://computational-cell-analytics.github.io/micro-sam/
MIT License
327 stars, 40 forks

Struggling to use finetuned model in napari #675

Closed: Cryaaa closed this issue 1 month ago

Cryaaa commented 1 month ago

First of all, I must say that you made it super easy to finetune a model. I've had success running it from a notebook, which is also how I did the finetuning. Right now I cannot find a way to use this self-tuned model in the plugin, and I was wondering whether I just missed the relevant part of the documentation or whether you could add a paragraph on how to do this. I want to have it available in the plugin so the rest of my lab can use it on their data with the image series annotation plugin. Thanks in advance!

constantinpape commented 1 month ago

Hi @Cryaaa, in the plugin the finetuned model can be used by entering its file path in the custom weights path field in the Embedding Settings. See the screenshot below:

[Screenshot: the custom weights path field in the Embedding Settings]

In the CLI, it can be passed with the --checkpoint option.

Let us know if you run into any other issues with the finetuning.
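Since the finetuning was run from a notebook, the same checkpoint can also be loaded in Python. The traceback later in this thread shows that the plugin calls `micro_sam.util.get_sam_model(model_type=..., checkpoint_path=..., return_state=True)` and unpacks a `(predictor, state)` pair, so a minimal sketch could wrap that call. The helper name `load_finetuned_predictor` is hypothetical, and the early file check is our own addition, not part of micro_sam.

```python
from pathlib import Path


def load_finetuned_predictor(checkpoint_path, model_type="vit_h"):
    """Hypothetical helper: load a finetuned checkpoint via micro_sam.util.get_sam_model,
    the same function the napari plugin calls when computing embeddings."""
    path = Path(checkpoint_path)
    # Fail early with a clear message instead of a deep traceback.
    if not path.is_file():
        raise FileNotFoundError(f"No checkpoint file at {path}")
    # Imported lazily so this helper can be defined without micro_sam installed.
    from micro_sam.util import get_sam_model

    # model_type must match the architecture that was finetuned (vit_h in this thread).
    predictor, state = get_sam_model(
        model_type=model_type, checkpoint_path=str(path), return_state=True
    )
    return predictor, state
```

Usage would look like `predictor, state = load_finetuned_predictor("path/to/best.pt", model_type="vit_h")`; in the plugin, the equivalent is pasting that same path into the custom weights path field.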

Cryaaa commented 1 month ago

Thanks for the quick reply! For some reason I assumed it would show up in the model dropdown.

I am now running into another problem, though. I'm specifying the path to the best model checkpoint created by the finetuning notebook, but when I try to compute the embeddings I get this error:

---------------------------------------------------------------------------
UnpicklingError                           Traceback (most recent call last)
File ~\mambaforge-pypy3\envs\micro_sam_env\Lib\site-packages\superqt\utils\_qthreading.py:613, in create_worker.<locals>.reraise(e=UnpicklingError('NEWOBJ class argument must be a type, not NoneType'))
    612 def reraise(e):
--> 613     raise e
        e = UnpicklingError('NEWOBJ class argument must be a type, not NoneType')

File ~\mambaforge-pypy3\envs\micro_sam_env\Lib\site-packages\superqt\utils\_qthreading.py:175, in WorkerBase.run(self=<napari._qt.qthreading.FunctionWorker object>)
    173     warnings.filterwarnings("always")
    174     warnings.showwarning = lambda *w: self.warned.emit(w)
--> 175     result = self.work()
        self = <napari._qt.qthreading.FunctionWorker object at 0x000002C5A24B01F0>
    176 if isinstance(result, Exception):
    177     if isinstance(result, RuntimeError):
    178         # The Worker object has likely been deleted.
    179         # A deleted wrapped C/C++ object may result in a runtime
    180         # error that will cause segfault if we try to do much other
    181         # than simply notify the user.

File ~\mambaforge-pypy3\envs\micro_sam_env\Lib\site-packages\superqt\utils\_qthreading.py:354, in FunctionWorker.work(self=<napari._qt.qthreading.FunctionWorker object>)
    353 def work(self) -> _R:
--> 354     return self._func(*self._args, **self._kwargs)
        self._func = <function EmbeddingWidget.__call__.<locals>.compute_image_embedding at 0x000002C5A262D760>
        self = <napari._qt.qthreading.FunctionWorker object at 0x000002C5A24B01F0>
        self._args = ()
        self._kwargs = {}

File ~\mambaforge-pypy3\envs\micro_sam_env\Lib\site-packages\micro_sam\sam_annotator\_widgets.py:1094, in EmbeddingWidget.__call__.<locals>.compute_image_embedding()
   1091     pbar_signals.pbar_total.emit(total)
   1092     pbar_signals.pbar_description.emit(description)
-> 1094 state.initialize_predictor(
        state = AnnotatorState(image_embeddings=None, predictor=None, image_shape=(496, 676), embedding_path=None, data_signature=None, amg=None, amg_state=None, decoder=None, current_track_id=None, lineage=None, committed_lineages=None, widgets={'embeddings': <micro_sam.sam_annotator._widgets.EmbeddingWidget object at 0x000002C5A22101F0>, 'prompts': <Container ()>, 'segment': <FunctionGui segment(viewer: napari.viewer.Viewer = Viewer(...), batched: bool = False) -> None>, 'autosegment': <micro_sam.sam_annotator._widgets.AutoSegmentWidget object at 0x000002C5A2213BE0>, 'commit': <FunctionGui commit(viewer: napari.viewer.Viewer = Viewer(...), layer: str = 'current_object', preserve_committed: bool = True, commit_path: Optional[pathlib.Path] = None) -> None>, 'clear': <FunctionGui clear(viewer: napari.viewer.Viewer = Viewer(...)) -> None>}, z_range=None)
        image_data = array([[35.69919, 35.82557, ..., 35.77521, 35.5687 ],
       [35.54334, 35.84689, ..., 35.92925, 35.62788],
       ...,
       [38.29271, 38.98734, ..., 37.39376, 37.19536],
       [38.99753, 38.71956, ..., 36.66956, 37.22879]])
        self.model_type = 'vit_h'
        save_path = None
        ndim = 2
        self = <micro_sam.sam_annotator._widgets.EmbeddingWidget object at 0x000002C5A22101F0>
        self.device = 'auto'
        self.custom_weights = 'D:/OneDrive/Documents/PhD Jesse/Embryonic_organoid_prediction/SAM Finetuning/fine_tuned_BF_segmenter_vit_h_1000_training/best.pt'
        tile_shape = None
        halo = None
        self.prefer_decoder = True
        pbar_signals = <micro_sam.sam_annotator._widgets.PBarSignals object at 0x000002C5A24B05E0>
   1095     image_data, model_type=self.model_type, save_path=save_path, ndim=ndim,
   1096     device=self.device, checkpoint_path=self.custom_weights, tile_shape=tile_shape, halo=halo,
   1097     prefer_decoder=self.prefer_decoder, pbar_init=pbar_init,
   1098     pbar_update=lambda update: pbar_signals.pbar_update.emit(update),
   1099 )
   1100 pbar_signals.pbar_stop.emit()

File ~\mambaforge-pypy3\envs\micro_sam_env\Lib\site-packages\micro_sam\sam_annotator\_state.py:87, in AnnotatorState.initialize_predictor(self=AnnotatorState(image_embeddings=None, predictor=...t 0x000002C5A242F100>})) -> None>}, z_range=None), image_data=array([[35.69919, 35.82557, ..., 35.77521, 35.56...  [38.99753, 38.71956, ..., 36.66956, 37.22879]]), model_type='vit_h', ndim=2, save_path=None, device='auto', predictor=None, decoder=None, checkpoint_path='D:/OneDrive/Documents/PhD Jesse/Embryonic_organo...ne_tuned_BF_segmenter_vit_h_1000_training/best.pt', tile_shape=None, halo=None, precompute_amg_state=False, prefer_decoder=True, pbar_init=<function EmbeddingWidget.__call__.<locals>.compute_image_embedding.<locals>.pbar_init>, pbar_update=<function EmbeddingWidget.__call__.<locals>.compute_image_embedding.<locals>.<lambda>>)
     85 # Initialize the model if necessary.
     86 if predictor is None:
---> 87     self.predictor, state = util.get_sam_model(
        self.predictor = None
        self = AnnotatorState(image_embeddings=None, predictor=None, image_shape=(496, 676), embedding_path=None, data_signature=None, amg=None, amg_state=None, decoder=None, current_track_id=None, lineage=None, committed_lineages=None, widgets={'embeddings': <micro_sam.sam_annotator._widgets.EmbeddingWidget object at 0x000002C5A22101F0>, 'prompts': <Container ()>, 'segment': <FunctionGui segment(viewer: napari.viewer.Viewer = Viewer(...), batched: bool = False) -> None>, 'autosegment': <micro_sam.sam_annotator._widgets.AutoSegmentWidget object at 0x000002C5A2213BE0>, 'commit': <FunctionGui commit(viewer: napari.viewer.Viewer = Viewer(...), layer: str = 'current_object', preserve_committed: bool = True, commit_path: Optional[pathlib.Path] = None) -> None>, 'clear': <FunctionGui clear(viewer: napari.viewer.Viewer = Viewer(...)) -> None>}, z_range=None)
        util = <module 'micro_sam.util' from 'C:\\Users\\savill\\mambaforge-pypy3\\envs\\micro_sam_env\\Lib\\site-packages\\micro_sam\\util.py'>
        device = 'auto'
        model_type = 'vit_h'
        checkpoint_path = 'D:/OneDrive/Documents/PhD Jesse/Embryonic_organoid_prediction/SAM Finetuning/fine_tuned_BF_segmenter_vit_h_1000_training/best.pt'
     88         device=device, model_type=model_type,
     89         checkpoint_path=checkpoint_path, return_state=True
     90     )
     91     if prefer_decoder and "decoder_state" in state:
     92         self.decoder = get_decoder(
     93             image_encoder=self.predictor.model.image_encoder,
     94             decoder_state=state["decoder_state"],
     95             device=device,
     96         )

File ~\mambaforge-pypy3\envs\micro_sam_env\Lib\site-packages\micro_sam\util.py:348, in get_sam_model(model_type='vit_h', device='cuda', checkpoint_path='D:/OneDrive/Documents/PhD Jesse/Embryonic_organo...ne_tuned_BF_segmenter_vit_h_1000_training/best.pt', return_sam=False, return_state=True)
    342 if abbreviated_model_type == "vit_t" and not VIT_T_SUPPORT:
    343     raise RuntimeError(
    344         "mobile_sam is required for the vit-tiny."
    345         "You can install it via 'pip install git+https://github.com/ChaoningZhang/MobileSAM.git'"
    346     )
--> 348 state, model_state = _load_checkpoint(checkpoint_path)
        checkpoint_path = 'D:/OneDrive/Documents/PhD Jesse/Embryonic_organoid_prediction/SAM Finetuning/fine_tuned_BF_segmenter_vit_h_1000_training/best.pt'
    349 sam = sam_model_registry[abbreviated_model_type]()
    350 sam.load_state_dict(model_state)

File ~\mambaforge-pypy3\envs\micro_sam_env\Lib\site-packages\micro_sam\util.py:253, in _load_checkpoint(checkpoint_path='D:/OneDrive/Documents/PhD Jesse/Embryonic_organo...ne_tuned_BF_segmenter_vit_h_1000_training/best.pt')
    250 custom_pickle = pickle
    251 custom_pickle.Unpickler = _CustomUnpickler
--> 253 state = torch.load(checkpoint_path, map_location="cpu", pickle_module=custom_pickle)
        custom_pickle = <module 'pickle' from 'C:\\Users\\savill\\mambaforge-pypy3\\envs\\micro_sam_env\\Lib\\pickle.py'>
        checkpoint_path = 'D:/OneDrive/Documents/PhD Jesse/Embryonic_organoid_prediction/SAM Finetuning/fine_tuned_BF_segmenter_vit_h_1000_training/best.pt'
    254 if "model_state" in state:
    255     # Copy the model weights from torch_em's training format.
    256     model_state = state["model_state"]

File ~\mambaforge-pypy3\envs\micro_sam_env\Lib\site-packages\torch\serialization.py:1025, in load(f='D:/OneDrive/Documents/PhD Jesse/Embryonic_organo...ne_tuned_BF_segmenter_vit_h_1000_training/best.pt', map_location='cpu', pickle_module=<module 'pickle' from 'C:\\Users\\savill\\mambaforge-pypy3\\envs\\micro_sam_env\\Lib\\pickle.py'>, weights_only=False, mmap=False, **pickle_load_args={'encoding': 'utf-8'})
   1023             except RuntimeError as e:
   1024                 raise pickle.UnpicklingError(UNSAFE_MESSAGE + str(e)) from None
-> 1025         return _load(opened_zipfile,
        opened_zipfile = <torch.PyTorchFileReader object at 0x000002C5A26A55F0>
        map_location = 'cpu'
        pickle_module = <module 'pickle' from 'C:\\Users\\savill\\mambaforge-pypy3\\envs\\micro_sam_env\\Lib\\pickle.py'>
        overall_storage = None
        pickle_load_args = {'encoding': 'utf-8'}
   1026                      map_location,
   1027                      pickle_module,
   1028                      overall_storage=overall_storage,
   1029                      **pickle_load_args)
   1030 if mmap:
   1031     f_name = "" if not isinstance(f, str) else f"{f}, "

File ~\mambaforge-pypy3\envs\micro_sam_env\Lib\site-packages\torch\serialization.py:1446, in _load(zip_file=<torch.PyTorchFileReader object>, map_location='cpu', pickle_module=<module 'pickle' from 'C:\\Users\\savill\\mambaforge-pypy3\\envs\\micro_sam_env\\Lib\\pickle.py'>, pickle_file='data.pkl', overall_storage=None, **pickle_load_args={'encoding': 'utf-8'})
   1444 unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
   1445 unpickler.persistent_load = persistent_load
-> 1446 result = unpickler.load()
        unpickler = <torch.serialization._load.<locals>.UnpicklerWrapper object at 0x000002C58932D950>
   1448 torch._utils._validate_loaded_sparse_tensors()
   1449 torch._C._log_api_usage_metadata(
   1450     "torch.load.metadata", {"serialization_id": zip_file.serialization_id()}
   1451 )

UnpicklingError: NEWOBJ class argument must be a type, not NoneType
Invalid file selected. Please try again.

Do I need to train the model differently? I have made sure that the model type matches the one I used for finetuning (vit_h), and I get the same error even on the machine I used for training.

constantinpape commented 1 month ago

There is a bug in the notebook that can cause errors when unpickling the checkpoint (whether it happens seems somewhat stochastic).

We will fix it and let you know as soon as this is done; see #676.

Sorry about that and thanks for the report.

anwai98 commented 1 month ago

Hi @Cryaaa,

We have added a fix to the finetuning notebook now. Could you re-run the finetuning using the updated notebook?

Apologies for the inconvenience. Let us know if this worked out for you!
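To check a re-trained checkpoint before pasting it into the plugin, one option is to list the classes and functions its pickle refers to and flag those that cannot be imported in a fresh Python process. This is a stdlib sketch under two assumptions: the file is a zip-format `torch.save` archive containing a `data.pkl` (the traceback shows `PyTorchFileReader` and `pickle_file='data.pkl'`), and it was written with pickle protocol 2, which we believe is `torch.save`'s default; protocol-4 `STACK_GLOBAL` references are not handled. All function names are hypothetical. The pickle is only disassembled, never executed, but the modules it names are imported, so run it only on files you trust.

```python
import importlib
import pickletools
import zipfile


def pickled_globals(pickle_bytes):
    """Return unique (module, name) pairs referenced by GLOBAL opcodes (protocol <= 3)."""
    refs = []
    for opcode, arg, _pos in pickletools.genops(pickle_bytes):
        if opcode.name == "GLOBAL":
            module, name = arg.split(" ", 1)
            refs.append((module, name))
    return list(dict.fromkeys(refs))  # de-duplicate, keep first-seen order


def resolves(module, name):
    """True if module.name can be imported in the current interpreter."""
    try:
        obj = importlib.import_module(module)
        for part in name.split("."):
            obj = getattr(obj, part)
    except (ImportError, AttributeError, ValueError):
        return False
    return True


def unresolvable_globals(checkpoint):
    """List globals in a zip-format torch checkpoint that cannot be found here."""
    with zipfile.ZipFile(checkpoint) as archive:
        names = [n for n in archive.namelist() if n == "data.pkl" or n.endswith("/data.pkl")]
        if not names:
            raise ValueError("no data.pkl found; is this a zip-format torch checkpoint?")
        pickle_bytes = archive.read(names[0])
    return [ref for ref in pickled_globals(pickle_bytes) if not resolves(*ref)]
```

Run it outside the notebook (for example `print(unresolvable_globals("best.pt"))` in a new interpreter); inside the notebook, classes defined there would still resolve and hide the problem. An empty list suggests the loader can find every referenced object.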

Cryaaa commented 1 month ago

Wow, thanks for the quick fix! I'll give it a shot and let you know how it turns out ASAP.

Cryaaa commented 1 month ago

Happy to report that this did the trick: I can now load the model in napari. Thanks for the quick fix!