Cryaaa closed this issue 3 months ago.
Hi @Cryaaa,
In the plugin, the finetuned model can be used by entering its filepath in the `custom weights path` field in the `Embedding Settings`. See the screenshot below:
[Screenshot: the `custom weights path` field in the `Embedding Settings`]
In the CLI, it can be passed with the `--checkpoint` option.
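For the CLI route, here is a minimal Python sketch that assembles the argument list for launching an annotator with a finetuned checkpoint. Only `--checkpoint` is confirmed in the reply above. The command name `micro_sam.annotator_2d` and the `--model_type` flag are assumptions to verify with `<command> --help`, and `build_annotator_command` is a hypothetical helper.

```python
from pathlib import Path


def build_annotator_command(checkpoint, model_type="vit_h",
                            command="micro_sam.annotator_2d", extra_args=()):
    """Return an argument list that launches a micro_sam annotator with a custom checkpoint.

    Assumptions: the command name and the --model_type flag are illustrative;
    only --checkpoint is confirmed in this thread. Check `<command> --help`.
    """
    path = Path(checkpoint)
    if not path.is_file():
        raise FileNotFoundError(f"checkpoint not found: {path}")
    # model_type must match the architecture that was finetuned (vit_h here);
    # otherwise the saved state dict will not fit the freshly built model.
    return [command, "--model_type", model_type, "--checkpoint", str(path), *extra_args]


# Example (not executed here). A list argument needs no shell quoting, so a
# path containing spaces, like ".../SAM Finetuning/.../best.pt", stays intact:
# import subprocess
# subprocess.run(build_annotator_command("best.pt"), check=True)
```

Passing a list to `subprocess.run` instead of a single shell string sidesteps quoting problems with Windows paths that contain spaces, such as the one in the traceback below.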
Let us know if you run into any other issues with the finetuning.
Thanks for the quick reply, I somehow assumed it would show up in the model dropdown.
I am now running into another problem, though. I'm specifying the path to the best model checkpoint created by the finetuning notebook, but when I try to compute the embeddings I get this error:
```
---------------------------------------------------------------------------
UnpicklingError Traceback (most recent call last)
File ~\mambaforge-pypy3\envs\micro_sam_env\Lib\site-packages\superqt\utils\_qthreading.py:613, in create_worker.<locals>.reraise(e=UnpicklingError('NEWOBJ class argument must be a type, not NoneType'))
612 def reraise(e):
--> 613 raise e
e = UnpicklingError('NEWOBJ class argument must be a type, not NoneType')
File ~\mambaforge-pypy3\envs\micro_sam_env\Lib\site-packages\superqt\utils\_qthreading.py:175, in WorkerBase.run(self=<napari._qt.qthreading.FunctionWorker object>)
173 warnings.filterwarnings("always")
174 warnings.showwarning = lambda *w: self.warned.emit(w)
--> 175 result = self.work()
self = <napari._qt.qthreading.FunctionWorker object at 0x000002C5A24B01F0>
176 if isinstance(result, Exception):
177 if isinstance(result, RuntimeError):
178 # The Worker object has likely been deleted.
179 # A deleted wrapped C/C++ object may result in a runtime
180 # error that will cause segfault if we try to do much other
181 # than simply notify the user.
File ~\mambaforge-pypy3\envs\micro_sam_env\Lib\site-packages\superqt\utils\_qthreading.py:354, in FunctionWorker.work(self=<napari._qt.qthreading.FunctionWorker object>)
353 def work(self) -> _R:
--> 354 return self._func(*self._args, **self._kwargs)
self._func = <function EmbeddingWidget.__call__.<locals>.compute_image_embedding at 0x000002C5A262D760>
self = <napari._qt.qthreading.FunctionWorker object at 0x000002C5A24B01F0>
self._args = ()
self._kwargs = {}
File ~\mambaforge-pypy3\envs\micro_sam_env\Lib\site-packages\micro_sam\sam_annotator\_widgets.py:1094, in EmbeddingWidget.__call__.<locals>.compute_image_embedding()
1091 pbar_signals.pbar_total.emit(total)
1092 pbar_signals.pbar_description.emit(description)
-> 1094 state.initialize_predictor(
state = AnnotatorState(image_embeddings=None, predictor=None, image_shape=(496, 676), embedding_path=None, data_signature=None, amg=None, amg_state=None, decoder=None, current_track_id=None, lineage=None, committed_lineages=None, widgets={...}, z_range=None)
image_data = array([[35.69919, 35.82557, ..., 35.77521, 35.5687 ],
[35.54334, 35.84689, ..., 35.92925, 35.62788],
...,
[38.29271, 38.98734, ..., 37.39376, 37.19536],
[38.99753, 38.71956, ..., 36.66956, 37.22879]])
self.model_type = 'vit_h'
save_path = None
ndim = 2
self = <micro_sam.sam_annotator._widgets.EmbeddingWidget object at 0x000002C5A22101F0>
self.device = 'auto'
self.custom_weights = 'D:/OneDrive/Documents/PhD Jesse/Embryonic_organoid_prediction/SAM Finetuning/fine_tuned_BF_segmenter_vit_h_1000_training/best.pt'
tile_shape = None
halo = None
self.prefer_decoder = True
pbar_signals = <micro_sam.sam_annotator._widgets.PBarSignals object at 0x000002C5A24B05E0>
1095 image_data, model_type=self.model_type, save_path=save_path, ndim=ndim,
1096 device=self.device, checkpoint_path=self.custom_weights, tile_shape=tile_shape, halo=halo,
1097 prefer_decoder=self.prefer_decoder, pbar_init=pbar_init,
1098 pbar_update=lambda update: pbar_signals.pbar_update.emit(update),
1099 )
1100 pbar_signals.pbar_stop.emit()
File ~\mambaforge-pypy3\envs\micro_sam_env\Lib\site-packages\micro_sam\sam_annotator\_state.py:87, in AnnotatorState.initialize_predictor(self=AnnotatorState(image_embeddings=None, predictor=...t 0x000002C5A242F100>})) -> None>}, z_range=None), image_data=array([[35.69919, 35.82557, ..., 35.77521, 35.56... [38.99753, 38.71956, ..., 36.66956, 37.22879]]), model_type='vit_h', ndim=2, save_path=None, device='auto', predictor=None, decoder=None, checkpoint_path='D:/OneDrive/Documents/PhD Jesse/Embryonic_organo...ne_tuned_BF_segmenter_vit_h_1000_training/best.pt', tile_shape=None, halo=None, precompute_amg_state=False, prefer_decoder=True, pbar_init=<function EmbeddingWidget.__call__.<locals>.compute_image_embedding.<locals>.pbar_init>, pbar_update=<function EmbeddingWidget.__call__.<locals>.compute_image_embedding.<locals>.<lambda>>)
85 # Initialize the model if necessary.
86 if predictor is None:
---> 87 self.predictor, state = util.get_sam_model(
self.predictor = None
self = AnnotatorState(image_embeddings=None, predictor=None, image_shape=(496, 676), embedding_path=None, data_signature=None, amg=None, amg_state=None, decoder=None, current_track_id=None, lineage=None, committed_lineages=None, widgets={...}, z_range=None)
util = <module 'micro_sam.util' from 'C:\\Users\\savill\\mambaforge-pypy3\\envs\\micro_sam_env\\Lib\\site-packages\\micro_sam\\util.py'>
device = 'auto'
model_type = 'vit_h'
checkpoint_path = 'D:/OneDrive/Documents/PhD Jesse/Embryonic_organoid_prediction/SAM Finetuning/fine_tuned_BF_segmenter_vit_h_1000_training/best.pt'
88 device=device, model_type=model_type,
89 checkpoint_path=checkpoint_path, return_state=True
90 )
91 if prefer_decoder and "decoder_state" in state:
92 self.decoder = get_decoder(
93 image_encoder=self.predictor.model.image_encoder,
94 decoder_state=state["decoder_state"],
95 device=device,
96 )
File ~\mambaforge-pypy3\envs\micro_sam_env\Lib\site-packages\micro_sam\util.py:348, in get_sam_model(model_type='vit_h', device='cuda', checkpoint_path='D:/OneDrive/Documents/PhD Jesse/Embryonic_organo...ne_tuned_BF_segmenter_vit_h_1000_training/best.pt', return_sam=False, return_state=True)
342 if abbreviated_model_type == "vit_t" and not VIT_T_SUPPORT:
343 raise RuntimeError(
344 "mobile_sam is required for the vit-tiny."
345 "You can install it via 'pip install git+https://github.com/ChaoningZhang/MobileSAM.git'"
346 )
--> 348 state, model_state = _load_checkpoint(checkpoint_path)
checkpoint_path = 'D:/OneDrive/Documents/PhD Jesse/Embryonic_organoid_prediction/SAM Finetuning/fine_tuned_BF_segmenter_vit_h_1000_training/best.pt'
349 sam = sam_model_registry[abbreviated_model_type]()
350 sam.load_state_dict(model_state)
File ~\mambaforge-pypy3\envs\micro_sam_env\Lib\site-packages\micro_sam\util.py:253, in _load_checkpoint(checkpoint_path='D:/OneDrive/Documents/PhD Jesse/Embryonic_organo...ne_tuned_BF_segmenter_vit_h_1000_training/best.pt')
250 custom_pickle = pickle
251 custom_pickle.Unpickler = _CustomUnpickler
--> 253 state = torch.load(checkpoint_path, map_location="cpu", pickle_module=custom_pickle)
custom_pickle = <module 'pickle' from 'C:\\Users\\savill\\mambaforge-pypy3\\envs\\micro_sam_env\\Lib\\pickle.py'>
checkpoint_path = 'D:/OneDrive/Documents/PhD Jesse/Embryonic_organoid_prediction/SAM Finetuning/fine_tuned_BF_segmenter_vit_h_1000_training/best.pt'
254 if "model_state" in state:
255 # Copy the model weights from torch_em's training format.
256 model_state = state["model_state"]
File ~\mambaforge-pypy3\envs\micro_sam_env\Lib\site-packages\torch\serialization.py:1025, in load(f='D:/OneDrive/Documents/PhD Jesse/Embryonic_organo...ne_tuned_BF_segmenter_vit_h_1000_training/best.pt', map_location='cpu', pickle_module=<module 'pickle' from 'C:\\Users\\savill\\mambaforge-pypy3\\envs\\micro_sam_env\\Lib\\pickle.py'>, weights_only=False, mmap=False, **pickle_load_args={'encoding': 'utf-8'})
1023 except RuntimeError as e:
1024 raise pickle.UnpicklingError(UNSAFE_MESSAGE + str(e)) from None
-> 1025 return _load(opened_zipfile,
opened_zipfile = <torch.PyTorchFileReader object at 0x000002C5A26A55F0>
map_location = 'cpu'
pickle_module = <module 'pickle' from 'C:\\Users\\savill\\mambaforge-pypy3\\envs\\micro_sam_env\\Lib\\pickle.py'>
overall_storage = None
pickle_load_args = {'encoding': 'utf-8'}
1026 map_location,
1027 pickle_module,
1028 overall_storage=overall_storage,
1029 **pickle_load_args)
1030 if mmap:
1031 f_name = "" if not isinstance(f, str) else f"{f}, "
File ~\mambaforge-pypy3\envs\micro_sam_env\Lib\site-packages\torch\serialization.py:1446, in _load(zip_file=<torch.PyTorchFileReader object>, map_location='cpu', pickle_module=<module 'pickle' from 'C:\\Users\\savill\\mambaforge-pypy3\\envs\\micro_sam_env\\Lib\\pickle.py'>, pickle_file='data.pkl', overall_storage=None, **pickle_load_args={'encoding': 'utf-8'})
1444 unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
1445 unpickler.persistent_load = persistent_load
-> 1446 result = unpickler.load()
unpickler = <torch.serialization._load.<locals>.UnpicklerWrapper object at 0x000002C58932D950>
1448 torch._utils._validate_loaded_sparse_tensors()
1449 torch._C._log_api_usage_metadata(
1450 "torch.load.metadata", {"serialization_id": zip_file.serialization_id()}
1451 )
UnpicklingError: NEWOBJ class argument must be a type, not NoneType
Invalid file selected. Please try again.
```
Do I need to train the model differently? I have made sure that the model type matches the one I used for finetuning (`vit_h`), and I get the same error even on the machine I used for training.
There is a bug in the notebook that can cause errors when unpickling the checkpoint (whether it happens seems somewhat random).
We will fix it and let you know as soon as this is done, see #676.
Sorry about that and thanks for the report.
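The traceback ends in `NEWOBJ class argument must be a type, not NoneType`, which means that while rebuilding an object, the class lookup produced `None` instead of a class. Since micro_sam loads checkpoints through a custom unpickler (`_CustomUnpickler` in the traceback), one plausible reading, offered here as an assumption rather than the confirmed cause behind #676, is that the checkpoint referenced a class the loading process could not resolve, for example one defined only inside the notebook. The stdlib-only sketch below disassembles the `data.pkl` inside a zip-format checkpoint with `pickletools` (nothing is unpickled), lists the classes it references, and flags those the current interpreter cannot import. It assumes pickle protocol 2 (the `torch.save` default), where class references appear as `GLOBAL` opcodes; both function names are hypothetical.

```python
import importlib
import pickletools
import zipfile


def read_checkpoint_globals(path):
    """List (module, name) pairs referenced by a zip-format torch checkpoint.

    The pickle is only disassembled, never executed. Assumes protocol <= 3
    (torch.save defaults to 2), where class references use the GLOBAL opcode.
    """
    with zipfile.ZipFile(path) as zf:
        pkl_names = [n for n in zf.namelist() if n.endswith("/data.pkl")]
        if not pkl_names:
            raise ValueError(f"{path}: no data.pkl found, not a zip-format torch checkpoint?")
        data = zf.read(pkl_names[0])
    found = []
    for opcode, arg, _pos in pickletools.genops(data):
        if opcode.name == "GLOBAL":
            # pickletools reports the GLOBAL argument as "module name".
            module, name = arg.split(" ", 1)
            if (module, name) not in found:
                found.append((module, name))
    return found


def unresolvable_globals(pairs):
    """Return the pairs that this interpreter cannot resolve to an object.

    Note: this imports the referenced modules, so run it only on trusted files.
    """
    bad = []
    for module, name in pairs:
        if module == "__main__":
            # Defined in a script or notebook session: only resolvable there.
            bad.append((module, name))
            continue
        try:
            obj = importlib.import_module(module)
            for part in name.split("."):
                obj = getattr(obj, part)
        except (ImportError, AttributeError):
            bad.append((module, name))
    return bad
```

For example, `unresolvable_globals(read_checkpoint_globals("best.pt"))`, run in the same environment that napari uses; an empty list means every class the checkpoint refers to can be imported there.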
Hi @Cryaaa,
We have added a fix to the finetuning notebook now. Could you re-run the finetuning using the updated notebook?
Apologies for the inconvenience. Let us know if this worked out for you!
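If you save checkpoints yourself when re-running the finetuning, one general precaution (an assumption, not necessarily what the notebook fix changed) is to make sure the saved object contains nothing whose class or function is defined in the notebook itself, because `__main__` in the napari process is a different module. The hypothetical helper below walks nested dicts, lists, and tuples and reports such entries; it does not inspect the attributes of arbitrary objects.

```python
def find_session_only_objects(obj, path="checkpoint"):
    """Return the paths of entries defined in __main__ (e.g. a notebook kernel).

    Instances inherit __module__ from their class, and classes and functions
    carry it directly, so one getattr covers all three. Only dicts, lists and
    tuples are traversed.
    """
    found = []
    if getattr(obj, "__module__", None) == "__main__":
        found.append(path)
    if isinstance(obj, dict):
        for key, value in obj.items():
            found.extend(find_session_only_objects(value, f"{path}[{key!r}]"))
    elif isinstance(obj, (list, tuple)):
        for i, value in enumerate(obj):
            found.extend(find_session_only_objects(value, f"{path}[{i}]"))
    return found
```

Calling it on the checkpoint dict right before `torch.save(checkpoint_dict, path)` should return an empty list; any reported path points at an entry that another process will most likely fail to unpickle.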
Wow, thanks for the quick fix! I'll give it a shot and let you know how it turns out as soon as possible.
Happy to report that this did the trick, and I can now load the model in napari. Thanks for the quick fix!
First of all, I must say that you made it super easy to train a finetuned model, and I've had success running it from a notebook, which is also how I finetuned the model. Right now I cannot find a way to use this self-tuned model in the plugin, and I was wondering whether I just missed the relevant part of the documentation or whether you could add a paragraph on how to do this. I want to have it available in the plugin so the rest of my lab can use it on their data with the image series annotation plugin. Thanks in advance!