Hi @rodrigo-pena,
Could you share with us the entire error trace from the finetuning notebook?
(I think I have a suspicion on what's happening - torch-em expects the inputs to be channels first - however, it would be good to verify this by looking at the trace.)
In addition, could you attach a screenshot of the finetuning GUI widget in napari, where you initialize the parameters for finetuning?
Hi @anwai98, here's the code cell I run:
batch_size = 1 # the training batch size
patch_shape = (512, 512) # the size of patches for training
# Train an additional convolutional decoder for end-to-end automatic instance segmentation
train_instance_segmentation = True
# Label transform is used to convert the ground-truth labels to the desired instances for finetuning Segment Anything.
# Or to learn the foreground and distances to the object centers and object boundaries for automatic segmentation.
if train_instance_segmentation:
    # Computes the distance transform for objects to jointly perform the additional decoder-based automatic
    # instance segmentation (AIS) and finetune Segment Anything.
    label_transform = PerObjectDistanceTransform(
        distances=True,
        boundary_distances=True,
        directed_distances=False,
        foreground=True,
        instances=True,
        min_size=25,
    )
else:
    # Ensures the individual object instances to finetune the classical Segment Anything.
    label_transform = torch_em.transform.label.connected_components

train_loader = torch_em.default_segmentation_loader(
    raw_paths=image_dir,
    raw_key=raw_key,
    label_paths=segmentation_dir,
    label_key=label_key,
    patch_shape=patch_shape,
    batch_size=batch_size,
    ndim=2,
    is_seg_dataset=True,
    rois=train_roi,
    label_transform=label_transform,
    shuffle=True,
    raw_transform=sam_training.identity,
)
val_loader = torch_em.default_segmentation_loader(
    raw_paths=image_dir,
    raw_key=raw_key,
    label_paths=segmentation_dir,
    label_key=label_key,
    patch_shape=patch_shape,
    batch_size=batch_size,
    ndim=2,
    is_seg_dataset=True,
    rois=val_roi,
    label_transform=label_transform,
    shuffle=True,
    raw_transform=sam_training.identity,
)
And here's the corresponding full error traceback:
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
Cell In[6], line 24
20 else:
21 # Ensures the individual object instances to finetune the classical Segment Anything.
22 label_transform = torch_em.transform.label.connected_components
---> 24 train_loader = torch_em.default_segmentation_loader(
25 raw_paths=image_dir,
26 raw_key=raw_key,
27 label_paths=segmentation_dir,
28 label_key=label_key,
29 patch_shape=patch_shape,
30 batch_size=batch_size,
31 ndim=2,
32 is_seg_dataset=True,
33 rois=train_roi,
34 label_transform=label_transform,
35 shuffle=True,
36 raw_transform=sam_training.identity,
37 )
38 val_loader = torch_em.default_segmentation_loader(
39 raw_paths=image_dir,
40 raw_key=raw_key,
(...)
50 raw_transform=sam_training.identity,
51 )
File [~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/torch_em/segmentation.py:212](https://ood-ubuntu.scicore.unibas.ch/node/sgi01.deploy.int/39288/~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/torch_em/segmentation.py#line=211), in default_segmentation_loader(raw_paths, raw_key, label_paths, label_key, batch_size, patch_shape, label_transform, label_transform2, raw_transform, transform, dtype, label_dtype, rois, n_samples, sampler, ndim, is_seg_dataset, with_channels, with_label_channels, verify_paths, **loader_kwargs)
189 def default_segmentation_loader(
190 raw_paths,
191 raw_key,
(...)
210 **loader_kwargs,
211 ):
--> 212 ds = default_segmentation_dataset(
213 raw_paths=raw_paths,
214 raw_key=raw_key,
215 label_paths=label_paths,
216 label_key=label_key,
217 patch_shape=patch_shape,
218 label_transform=label_transform,
219 label_transform2=label_transform2,
220 raw_transform=raw_transform,
221 transform=transform,
222 dtype=dtype,
223 label_dtype=label_dtype,
224 rois=rois,
225 n_samples=n_samples,
226 sampler=sampler,
227 ndim=ndim,
228 is_seg_dataset=is_seg_dataset,
229 with_channels=with_channels,
230 with_label_channels=with_label_channels,
231 verify_paths=verify_paths,
232 )
233 return get_data_loader(ds, batch_size=batch_size, **loader_kwargs)
File [~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/torch_em/segmentation.py:274](https://ood-ubuntu.scicore.unibas.ch/node/sgi01.deploy.int/39288/~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/torch_em/segmentation.py#line=273), in default_segmentation_dataset(raw_paths, raw_key, label_paths, label_key, patch_shape, label_transform, label_transform2, raw_transform, transform, dtype, label_dtype, rois, n_samples, sampler, ndim, is_seg_dataset, with_channels, with_label_channels, verify_paths)
269 transform = _get_default_transform(
270 raw_paths if isinstance(raw_paths, str) else raw_paths[0], raw_key, is_seg_dataset, ndim
271 )
273 if is_seg_dataset:
--> 274 ds = _load_segmentation_dataset(
275 raw_paths,
276 raw_key,
277 label_paths,
278 label_key,
279 patch_shape=patch_shape,
280 raw_transform=raw_transform,
281 label_transform=label_transform,
282 label_transform2=label_transform2,
283 transform=transform,
284 rois=rois,
285 n_samples=n_samples,
286 sampler=sampler,
287 ndim=ndim,
288 dtype=dtype,
289 label_dtype=label_dtype,
290 with_channels=with_channels,
291 with_label_channels=with_label_channels,
292 )
293 else:
294 ds = _load_image_collection_dataset(
295 raw_paths,
296 raw_key,
(...)
308 label_dtype=label_dtype,
309 )
File [~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/torch_em/segmentation.py:102](https://ood-ubuntu.scicore.unibas.ch/node/sgi01.deploy.int/39288/~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/torch_em/segmentation.py#line=101), in _load_segmentation_dataset(raw_paths, raw_key, label_paths, label_key, **kwargs)
100 if isinstance(rois, tuple):
101 assert all(isinstance(roi, slice) for roi in rois)
--> 102 ds = SegmentationDataset(raw_paths, raw_key, label_paths, label_key, roi=rois, **kwargs)
103 else:
104 assert len(raw_paths) > 0
File [~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/torch_em/data/segmentation_dataset.py:66](https://ood-ubuntu.scicore.unibas.ch/node/sgi01.deploy.int/39288/~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/torch_em/data/segmentation_dataset.py#line=65), in SegmentationDataset.__init__(self, raw_path, raw_key, label_path, label_key, patch_shape, raw_transform, label_transform, label_transform2, transform, roi, dtype, label_dtype, n_samples, sampler, ndim, with_channels, with_label_channels)
64 shape_raw = self.raw.shape[1:] if self._with_channels else self.raw.shape
65 shape_label = self.labels.shape[1:] if self._with_label_channels else self.labels.shape
---> 66 assert shape_raw == shape_label, f"{shape_raw}, {shape_label}"
68 self.shape = shape_raw
69 self.roi = roi
AssertionError: (17, 1040, 1392, 3), (17, 1040, 1392)
I'm also adding the screenshot from the GUI and the error I get when I try to fine-tune the vit_l_lm model via it (I'm censoring parts of the paths to avoid disclosing locations in my HPC system). The error appears right after I click the "Start Training" button.
Thanks for sharing the trace and the GUI screenshot.
Re: notebook error trace: Could you also provide information about the paths you provide as inputs (it would help me understand the nature of the inputs you are trying to pass), i.e. the raw_key and label_key you are trying to pass to torch_em.default_segmentation_loader (I assume that image_dir and segmentation_dir are the paths displayed in the GUI screenshot, /.../train_images and /.../train_labels respectively)? From the trace it looks like your raw inputs have shape (17, 1040, 1392, 3) and the respective labels are of shape (17, 1040, 1392). Could you confirm if this is the case?
Re: GUI finetuning screenshot: I think one thing you are missing is passing values to Image data key and Label data key, which would search the inputs for you and collect them for creating the dataloaders. (I think this question would probably be answered once I get answers to the aforementioned questions.)
Here's the cell where I define raw_key and label_key:
# Load images from multiple files in folder via pattern (here: all tif files)
raw_key, label_key = "*.tif", "*.tif"
# Alternative: if you have tif stacks you can just set raw_key and label_key to None
# raw_key, label_key = None, None
# The 'roi' argument can be used to subselect parts of the data.
# Here, we use it to select the first 85% of images for the training split and the other frames for the validation split.
n_images = len(image_paths)
n_train = int(0.85 * n_images)
train_roi = np.s_[:n_train, :, :]
val_roi = np.s_[n_train:, :, :]
Indeed, image_dir points to /.../train_images and label_dir to /.../train_labels. And all images in there are .tif files:
DATA_FOLDER = os.path.join(HOME, 'data', 'myocount-validation-data')
image_dir = os.path.join(DATA_FOLDER, 'train_images')
segmentation_dir = os.path.join(DATA_FOLDER, 'train_labels')
Regarding your second bullet point, I don't know if we can interpret the images as 3d. They do have 3 channels (one for the nuclei, one for the cytoplasm, and a dummy one), but each channel carries very different information. So these images are different from what one would get in, e.g., z-stacks, which I would consider 3d data.
As for the GUI screenshot, thanks for catching the missing ".tif" keys. However, I tried again (this time properly setting ".tif" under "Image data key" and "Label data key") and the error remains the same
Okay, I can reproduce the error and I think I know what's causing this. It's the way we are fetching the multi-channel inputs, which are getting stacked together and hence causing the issue. Let's fix the issues one-by-one.
For the notebook-based finetuning on your provided data:
OPTION 1:
We are currently using the torch_em-supported SegmentationDataset. This expects the inputs to be channels-first (i.e. the input images currently have the structure (1040, 1392, 3); we need to change them to (3, 1040, 1392)). To do this, you could do the following:
import os
from glob import glob
from pathlib import Path

import imageio.v3 as imageio

image_paths = sorted(glob(os.path.join(image_dir, "*.tif")))
for image_path in image_paths:
    pxpath = Path(image_path)
    target_dir = os.path.join(pxpath.parent, "image_preprocessed_dir")
    os.makedirs(target_dir, exist_ok=True)  # make sure the output directory exists
    target_path = os.path.join(target_dir, pxpath.name)
    image = imageio.imread(image_path)
    image = image.transpose(2, 0, 1)  # move the channel axis to the front
    imageio.imwrite(target_path, image)
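As a quick sanity check (an untested sketch, reusing target_path from the loop above), you could re-read one of the rewritten files and confirm that the channel axis has moved to the front:
print(imageio.imread(target_path).shape)  # expected output: (3, 1040, 1392)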
Next, we need to update the structure in which we pass the inputs to the dataloader:
image_dir = ...  # NOTE: this needs to point to "image_preprocessed_dir" now

# Fetching all inputs in respective directories
image_paths = sorted(glob(os.path.join(image_dir, "*.tif")))
label_paths = sorted(glob(os.path.join(labels_dir, "*.tif")))

# To make valid splits for the inputs
n_images = len(image_paths)
n_train = int(0.85 * n_images)

# Let's split the input paths
train_image_paths, train_label_paths = image_paths[:n_train], label_paths[:n_train]
val_image_paths, val_label_paths = image_paths[n_train:], label_paths[n_train:]

# We create our dataloaders
train_loader = torch_em.default_segmentation_loader(
    raw_paths=train_image_paths,
    raw_key=None,
    label_paths=train_label_paths,
    label_key=None,
    batch_size=1,
    patch_shape=(512, 512),
    ndim=2,
    is_seg_dataset=True,
    with_channels=True,
)
val_loader = ... # same snippet as above with "val"-related paths
OPTION 2:
We can choose to leave the inputs as they are and use another torch_em-supported dataset (ImageCollectionDataset) to fetch the inputs:
image_dir = ...  # NOTE: this needs to point to the old image directory here

# Fetching all inputs in respective directories
image_paths = sorted(glob(os.path.join(image_dir, "*.tif")))
label_paths = sorted(glob(os.path.join(labels_dir, "*.tif")))

# To make valid splits for the inputs
n_images = len(image_paths)
n_train = int(0.85 * n_images)

# Let's split the input paths
train_image_paths, train_label_paths = image_paths[:n_train], label_paths[:n_train]
val_image_paths, val_label_paths = image_paths[n_train:], label_paths[n_train:]

# We create our dataloaders
train_loader = torch_em.default_segmentation_loader(
    raw_paths=train_image_paths,
    raw_key=None,
    label_paths=train_label_paths,
    label_key=None,
    batch_size=1,
    patch_shape=(512, 512),
    ndim=2,
    is_seg_dataset=False,
    n_samples=50,  # This oversamples the inputs, i.e. reruns the dataset object over the inputs (once it has gone through all the images) in order to fetch more patches
)
val_loader = ... # same snippet as above with "val"-related paths
PS. I haven't tested the code written here, but in theory it should work. Let us know how it goes.
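A quick way to verify either loader before starting a training run (a sketch; train_loader is whichever of the two loaders created above):
x, y = next(iter(train_loader))
print(x.shape)  # the raw patch, e.g. torch.Size([1, 3, 512, 512])
print(y.shape)  # the corresponding (transformed) labels for the same patch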
Ah and regarding the issue with GUI-based finetuning, I see that the only check we do in the code is to verify whether the provided path exists or not.
Could you double check two things:
a) Whether the paths for both the label and image directory exist, via your terminal (just to be doubly sure): python -c 'import os; print(os.path.exists("<PATHS_FROM_GUI>"))'
b) Could you also report the following: python -c 'import micro_sam; print(micro_sam.__version__)'
Thanks for the options, I'll try them tomorrow and report the results here. Does the napari plugin "Annotator 2D" use a different dataloader under the hood? After all, I was able to compute embeddings and run the pre-trained vit_l_lm model on it to assist with the annotations.
About the points to double check re: fine-tuning GUI:
a) Returns True
b) Returns 1.0.1
Does the napari plugin "Annotator 2D" use a different dataloader under the hood?
The "Annotator 2d" just takes the input path for the image and opens it up the array in a layer. It does not use the dataloader schema.
After all, I was able to compute embeddings and run the pre-trained vit_l_lm model on it to assist with the annotations?
Yes, it's because of the aforementioned reason. The dataloaders are a bit specific to the finetuning schema, but with the recommended options to adapt, they should work now. (I will look into increasing support for different image types in our dataloader tutorial notebook)
About the points to double check re: fine-tuning GUI: a) Returns True b) Returns 1.0.1
Thanks for checking this out. Okay, it seems that the directories are visible to the model. Hmm, that's strange. Can you confirm if you have made the installation from source / conda-forge?
I have created a separate mamba environment for micro-sam and installed the package in it via mamba using the conda-forge channel (plus pytorch with GPU capability), according to these instructions.
Hi @anwai98 , I've tried Options 1 and 2 that you gave me above, and both seem to work in creating the dataloaders. But it is still not clear for me why in one options I have specify the flag with_channels=True
and in the other n_samples=50
. Are these documented somewhere?
Regardless, there are some differences in the created dataloaders: when using torch_em.util.debug.check_loader to display some images, the loader from Option 1 raises a warning ("Multi-channel input data is not yet supported, will only show channel 0"), while the one from Option 2 displays images without warning.
Moving on, when I run the following training code
# All hyperparameters for training.
n_objects_per_batch = 5 # the number of objects per batch that will be sampled
device = "cuda" if torch.cuda.is_available() else "cpu" # the device/GPU used for training
n_epochs = 10 # how long we train (in epochs)
# The model_type determines which base model is used to initialize the weights that are finetuned.
# We use vit_b here because it can be trained faster. Note that vit_h usually yields higher quality results.
model_type = "vit_l_lm"
# The name of the checkpoint. The checkpoints will be stored in './checkpoints/<checkpoint_name>'
checkpoint_name = "vit_l_lm_finetuned"
# Run training
sam_training.train_sam(
    name=checkpoint_name,
    save_root=os.path.join(HOME, "models"),
    model_type=model_type,
    checkpoint_path=os.path.join(HOME, "models", "vit_l_lm", "vit_l.pt"),
    train_loader=train_loader,
    val_loader=val_loader,
    n_epochs=n_epochs,
    n_objects_per_batch=n_objects_per_batch,
    with_segmentation_decoder=train_instance_segmentation,
    device=device,
)
I get the error traceback
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[29], line 2
1 # Run training
----> 2 sam_training.train_sam(
3 name=checkpoint_name,
4 save_root=os.path.join(HOME, "models"),
5 model_type=model_type,
6 checkpoint_path=os.path.join(HOME, "models", "vit_l_lm", "vit_l.pt"),
7 train_loader=train_loader,
8 val_loader=val_loader,
9 n_epochs=n_epochs,
10 n_objects_per_batch=n_objects_per_batch,
11 with_segmentation_decoder=train_instance_segmentation,
12 device=device,
13 )
File [~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/micro_sam/training/training.py:184](https://ood-ubuntu.scicore.unibas.ch/node/sgi01.deploy.int/35541/lab/tree/myotube-characterization/notebooks/~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/micro_sam/training/training.py#line=183), in train_sam(name, model_type, train_loader, val_loader, n_epochs, early_stopping, n_objects_per_batch, checkpoint_path, with_segmentation_decoder, freeze, device, lr, n_sub_iteration, save_root, mask_prob, n_iterations, scheduler_class, scheduler_kwargs, save_every_kth_epoch, pbar_signals)
128 def train_sam(
129 name: str,
130 model_type: str,
(...)
148 pbar_signals: Optional[QObject] = None,
149 ) -> None:
150 """Run training for a SAM model.
151
152 Args:
(...)
182 pbar_signals: Controls for napari progress bar.
183 """
--> 184 _check_loader(train_loader, with_segmentation_decoder)
185 _check_loader(val_loader, with_segmentation_decoder)
187 device = get_device(device)
File [~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/micro_sam/training/training.py:80](https://ood-ubuntu.scicore.unibas.ch/node/sgi01.deploy.int/35541/lab/tree/myotube-characterization/notebooks/~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/micro_sam/training/training.py#line=79), in _check_loader(loader, with_segmentation_decoder)
74 if n_channels_y != 4:
75 raise ValueError(
76 "Invalid number of channels in the target data from the data loader. "
77 "Expect 4 channel for training with an instance segmentation decoder, "
78 f"but got {n_channels_y} channels."
79 )
---> 80 check_instance_channel(y[:, 0])
82 targets_min, targets_max = y[:, 1:].min(), y[:, 1:].max()
83 if targets_min < 0 or targets_min > 1:
File [~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/micro_sam/training/training.py:64](https://ood-ubuntu.scicore.unibas.ch/node/sgi01.deploy.int/35541/lab/tree/myotube-characterization/notebooks/~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/micro_sam/training/training.py#line=63), in _check_loader.<locals>.check_instance_channel(instance_channel)
60 raise ValueError(
61 "The target channel with the instance segmentation must not have negative values."
62 )
63 if len(unique_vals) == 1:
---> 64 raise ValueError(
65 "The target channel with the instance segmentation must have at least one instance."
66 )
67 if not torch.allclose(unique_vals, unique_vals.round(), atol=1e-7):
68 raise ValueError(
69 "All values in the target channel with the instance segmentation must be integer."
70 )
ValueError: The target channel with the instance segmentation must have at least one instance.
I think this might be because some of my annotation masks are all zeros, meaning that no object should be segmented in that image. These all-zero masks were obtained by "skipping" the image in "Image Series Annotator" (i.e., pressing the Next Image button). If the error is because of that, does it mean that I have to manually discard from my dataset all images whose corresponding segmentation masks are all zeros?
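For reference, this is roughly how I would drop those pairs if needed (an untested sketch, reusing imageio from the Option 1 snippet and the sorted path lists from my earlier cells):
import numpy as np

valid_image_paths, valid_label_paths = [], []
for img_path, lbl_path in zip(image_paths, label_paths):
    labels = imageio.imread(lbl_path)
    if np.any(labels > 0):  # keep only pairs with at least one annotated instance
        valid_image_paths.append(img_path)
        valid_label_paths.append(lbl_path)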
Thanks for confirming that the dataloaders now work.
But it is still not clear to me why in one option I have to specify the flag with_channels=True and in the other n_samples=50. Are these documented somewhere?
It comes down to the heuristic of how we create and handle dataloaders in torch-em (if you look closely, is_seg_dataset has different values for the two options). Option 1 uses SegmentationDataset, which fetches valid patches (say, of the desired patch shape (512, 512)) multiple times from an image larger than the desired patch shape. Option 2 uses ImageCollectionDataset, which fetches only one randomly cropped patch per image (so if you have 3 training images, you will have 3 patches extracted in total) - which is why we increase n_samples to a higher integer, so that we run the dataset object more times over the entire data to get more valid patches.
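A quick way to see this difference in practice (a sketch; train_loader is whichever of the two loaders from the options above):
# The dataset length tells you how many patches one epoch will draw.
# For the ImageCollectionDataset (Option 2) it equals the number of images,
# unless n_samples is set, in which case it equals n_samples.
print(len(train_loader.dataset))  # e.g. 50 with n_samples=50
print(len(train_loader))          # batches per epoch = dataset length / batch size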
We have introductory documentation for torch-em datasets here: https://constantinpape.github.io/torch-em/ (we are looking into extending this in the near future). We would be happy to answer any questions from your side on this.
Regardless, there are some differences in the created dataloaders: when using torch_em.util.debug.check_loader to display some images, the loader from Option 1 raises a warning ("Multi-channel input data is not yet supported, will only show channel 0"), while the one from Option 2 displays images without warning.
My suspicion is that this happens because matplotlib does not support channels-first plotting of inputs. I can reproduce this in both cases (it's because the inputs come out as channels-first tensors in either of the two recommended options), but as long as your images visually look fine, the warning is not relevant.
You can check if your loader outputs from both look as expected:
loader = ...
inputs = next(iter(loader))
print(inputs[0].shape) # the inputs should look like: torch.Size([1, 3, 512, 512])
If the error is because of that, does it mean that I have to manually discard from my dataset all images whose corresponding segmentation masks are all zeros?
Yes, exactly. The error notifies you that the provided labels do not have any valid objects to perform finetuning with. You can use a sampler for this case, supported by torch-em:
# the provided code with its default values checks if your labels have at least one valid object to segment
from torch_em.data.sampler import MinInstanceSampler
loader = torch_em.default_segmentation_loader(
..., # all other arguments
sampler=MinInstanceSampler(),
)
Let us know if this makes the finetuning work.
Update: I have trimmed the dataset to contain only image-mask pairs with at least one instance, and I've used the MinInstanceSampler for the Option 2 dataloader.
However, the sam_training.train_sam(...) call now gives me the following new error traceback:
[~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/torch/optim/lr_scheduler.py:60](<redacted_url>/notebooks/miniconda3/envs/micro-sam/lib/python3.12/site-packages/torch/optim/lr_scheduler.py#line=59): UserWarning: The verbose parameter is deprecated. Please use get_last_lr() to access the learning rate.
warnings.warn(
[~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/torch_em/trainer/default_trainer.py:76](<redacted_url>/notebooks/miniconda3/envs/micro-sam/lib/python3.12/site-packages/torch_em/trainer/default_trainer.py#line=75): FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
self.scaler = amp.GradScaler() if mixed_precision else None
[~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/torch_em/util/util.py:177](<redacted_url>/notebooks/miniconda3/envs/micro-sam/lib/python3.12/site-packages/torch_em/util/util.py#line=176): UserWarning: Constructor arguments for <class 'micro_sam.training.trainable_sam.TrainableSAM'> cannot be deduced.
For this object, empty constructor arguments will be used.
The trainer can probably not be correctly deserialized via 'DefaultTrainer.from_checkpoint'.
warnings.warn(
[~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/torch_em/util/util.py:177](<redacted_url>/notebooks/miniconda3/envs/micro-sam/lib/python3.12/site-packages/torch_em/util/util.py#line=176): UserWarning: Constructor arguments for <class 'torch_em.model.unetr.UNETR'> cannot be deduced.
For this object, empty constructor arguments will be used.
The trainer can probably not be correctly deserialized via 'DefaultTrainer.from_checkpoint'.
warnings.warn(
Start fitting for 150 iterations / 10 epochs
with 15 iterations per epoch
Training with mixed precision
Epoch 0: 0%| | 0[/150](<redacted_url>/150) [00:00<?, ?it[/s](<redacted_url>/s)][~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/micro_sam/training/joint_sam_trainer.py:86](<redacted_url>/notebooks/miniconda3/envs/micro-sam/lib/python3.12/site-packages/micro_sam/training/joint_sam_trainer.py#line=85): FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
with forward_context():
[~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/micro_sam/training/joint_sam_trainer.py:95](~/notebooks/miniconda3/envs/micro-sam/lib/python3.12/site-packages/micro_sam/training/joint_sam_trainer.py#line=94): FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
with forward_context():
Epoch 0: 2%|▏ | 3[/150](<redacted_url>/150) [00:06<05:33, 2.27s[/it](<redacted_url>/it)]
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
Cell In[15], line 2
1 # Run training
----> 2 sam_training.train_sam(
3 name=checkpoint_name,
4 save_root=os.path.join(HOME, "models"),
5 model_type=model_type,
6 checkpoint_path=os.path.join(HOME, "models", "vit_b_lm", "vit_b.pt"),
7 train_loader=train_loader_from_segmentation_dataset,
8 # train_loader=train_loader_from_image_collection_dataset,
9 val_loader=val_loader_from_segmentation_dataset,
10 # val_loader=val_loader_from_image_collection_dataset,
11 n_epochs=n_epochs,
12 n_objects_per_batch=n_objects_per_batch,
13 with_segmentation_decoder=train_instance_segmentation,
14 device=device,
15 )
File [~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/micro_sam/training/training.py:282](<redacted_url>/notebooks/~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/micro_sam/training/training.py#line=281), in train_sam(name, model_type, train_loader, val_loader, n_epochs, early_stopping, n_objects_per_batch, checkpoint_path, with_segmentation_decoder, freeze, device, lr, n_sub_iteration, save_root, mask_prob, n_iterations, scheduler_class, scheduler_kwargs, save_every_kth_epoch, pbar_signals)
279 progress_bar_wrapper = _ProgressBarWrapper(pbar_signals)
280 trainer_fit_params["progress"] = progress_bar_wrapper
--> 282 trainer.fit(**trainer_fit_params)
File [~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/torch_em/trainer/default_trainer.py:583](<redacted_url>notebooks/~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/torch_em/trainer/default_trainer.py#line=582), in DefaultTrainer.fit(self, iterations, load_from_checkpoint, epochs, save_every_kth_epoch, progress)
580 pass
582 # Run training and validation for this epoch
--> 583 t_per_iter = train_epoch(progress)
584 current_metric = validate()
586 # perform all the post-epoch steps:
587
588 # apply the learning rate scheduler
File [~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/torch_em/trainer/default_trainer.py:647](<redacted_url>/notebooks/~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/torch_em/trainer/default_trainer.py#line=646), in DefaultTrainer._train_epoch_mixed(self, progress)
646 def _train_epoch_mixed(self, progress):
--> 647 return self._train_epoch_impl(progress, amp.autocast, self._backprop_mixed)
File [~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/micro_sam/training/joint_sam_trainer.py:89](<redacted_url>/notebooks/~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/micro_sam/training/joint_sam_trainer.py#line=88), in JointSamTrainer._train_epoch_impl(self, progress, forward_context, backprop)
84 self.optimizer.zero_grad()
86 with forward_context():
87 # 1. train for the interactive segmentation
88 (loss, mask_loss, iou_regression_loss, model_iou,
---> 89 sampled_binary_y) = self._interactive_train_iteration(x, labels_instances)
91 backprop(loss)
93 self.optimizer.zero_grad()
File [~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/micro_sam/training/sam_trainer.py:296](<redacted_url>/notebooks/~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/micro_sam/training/sam_trainer.py#line=295), in SamTrainer._interactive_train_iteration(self, x, y)
293 def _interactive_train_iteration(self, x, y):
294 n_pos, n_neg, get_boxes, multimask_output = self._get_prompt_and_multimasking_choices(self._iteration)
--> 296 batched_inputs, sampled_ids = self.convert_inputs(x, y, n_pos, n_neg, get_boxes, self.n_objects_per_batch)
297 batched_inputs, y_one_hot = self._preprocess_batch(batched_inputs, y, sampled_ids)
299 loss, mask_loss, iou_regression_loss, model_iou = self._compute_iterative_loss(
300 batched_inputs, y_one_hot,
301 num_subiter=self.n_sub_iteration, multimask_output=multimask_output
302 )
File [~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/micro_sam/training/util.py:172](<redacted_url>/notebooks/~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/micro_sam/training/util.py#line=171), in ConvertToSamInputs.__call__(self, x, y, n_pos, n_neg, get_boxes, n_samples)
170 for image, gt in zip(x, y):
171 gt = gt.squeeze().numpy().astype(np.int64)
--> 172 box_prompts, point_prompts, point_label_prompts, sampled_cell_ids = self._get_prompt_lists(
173 gt, n_samples, prompt_generator,
174 )
176 # check to be sure about the expected size of the no. of elements in different settings
177 if get_boxes:
File [~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/micro_sam/training/util.py:144](<redacted_url>/notebooks/~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/micro_sam/training/util.py#line=143), in ConvertToSamInputs._get_prompt_lists(self, gt, n_samples, prompt_generator)
141 bbox_coordinates = self._distort_boxes(bbox_coordinates, shape=gt.shape[-2:])
143 # convert the gt to the one-hot-encoded masks for the sampled cell ids
--> 144 object_masks = segmentation_to_one_hot(gt, None if n_samples is None else sampled_cell_ids)
146 # derive and return the prompts
147 point_prompts, point_label_prompts, box_prompts, _ = prompt_generator(object_masks, bbox_coordinates)
File [~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/micro_sam/util.py:950](<redacted_url>/notebooks/~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/micro_sam/util.py#line=949), in segmentation_to_one_hot(segmentation, segmentation_ids)
947 n_ids = int(segmentation.max())
949 else:
--> 950 assert segmentation_ids[0] != 0, "No objects were found."
952 # the segmentation ids have to be sorted
953 segmentation_ids = np.sort(segmentation_ids)
IndexError: index 0 is out of bounds for axis 0 with size 0
For reference, the dummy dataset with only valid image-mask pairs looks like this (with the first three images going to the training split and the last image going to the validation split):
Remark: I don't know if one of the warnings has to do with this, but the cluster from which I'm running this has no access to the Internet, so I have to pre-load the model checkpoints to train from. In sam_training.train_sam(...) I point to the vit_b_lm checkpoint via the file vit_l.pt, but I don't know if and how I should point to the corresponding pre-trained decoder vit_b_decoder.pt.
Re: warnings in training: I would say you can ignore them, they aren't anything critical.
Re: error in training: Hmm, seems like we are still getting samples with no foreground objects. That's strange. Can you visually validate from the dataloader outputs if that's the case, and report if you still see images without valid paired labels (remember to use sampler=MinInstanceSampler() in the torch_em.default_segmentation_dataset)?
You can visualize the loaders locally using napari with the script below:
train_loader = ...
val_loader = ...
from torch_em.util.debug import check_loader
# NOTE: I choose a large number of samples below to visualize all possible inputs
check_loader(train_loader, 50)
check_loader(val_loader, 50)
Re: cluster access to internet: I have a similar setup at my end. To tackle this, what I do is briefly run the scripts on the login node (which does have access to the internet), and while I see the training proceed (after 1/2 iterations), I just stop the job and submit it to the cluster. This would be my recommendation, as it ensures all the necessary automatic downloads (i.e. model checkpoints), and then you do not need to hard-code our model paths. (+ this would also make sure to load the decoder weights)
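If it helps, a rough sketch of what that first online run could look like, just to populate the local cache (I'm assuming micro_sam.util.get_sam_model here; the brief training run described above remains the safest way to also fetch the decoder weights):
from micro_sam.util import get_sam_model

# Run this once on a node with internet access; the checkpoint is downloaded and
# cached, so subsequent offline jobs can reuse it.
_ = get_sam_model(model_type="vit_b_lm")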
Following your suggestion, I've checked the data loader outputs on napari.
For Option 1, I saw that initially there were some rare image patches that were cropped from areas with no instance segmentation. So I used sampler=MinInstanceSampler() for its data loaders, and that seemed to visually solve the problem on the samples I've explored. However, when running training the same error appeared (IndexError: index 0 is out of bounds for axis 0 with size 0).
For Option 2 (which already had sampler=MinInstanceSampler()), I could see sample patches on napari of images that were still cropped from an area without annotated segmentation masks. The training fails in the same way with Option 2.
There is one thing that still may be an issue: only two of the channels in the images are informative (two fluorescence wavelengths). The third is all-zeros, put there by the creators of the dataset so that maybe the images could be read as RGB. Could that be causing the issue?
Actually, I just ran the check_loader on a bunch of samples from the Option 2 dataloader and could only see good pairs, so maybe I mistook what I saw the first time round (or it was some fluke).
Could you please double check this for Option 2 with the sampler? (I would recommend sticking with Option 2, as this works and is a bit simpler compared to changing axes around.)
If it still causes issues, we might need to investigate and fix this in torch-em.
In addition, could you report the output of the following: python -c "import torch_em; print(torch_em.__version__)"
Good morning, @anwai98. By the way, I appreciate this level of support. It's great to see and really motivates me to use the code base y'all developed.
Focusing on Option 2, I re-ran the training code today and this time it trained for 1.25 epochs before crashing with the same IndexError as before.
So once again I inspected the image/label pairs coming from the training and validation data loaders on napari through the check_loader function. Out of the 100 images I saw, they were all good pairs, except for 1. I took a screenshot of the thumbnails of the bad pair loaded on napari:
As you can see, it seems that the segmentation mask is all-zeros. Weirdly enough, though, if I split the stack into three channels, two of the channels in the segmentation mask are all-ones:
I zoomed into the extracted image patch to find the original image from which it had been extracted. I ended up finding it (red square, rotated 180 degrees):
As you can see, it is indeed a patch with no annotated segmentation instances. So I wonder how the MinInstanceSampler() is letting this pass.
This behavior seems uncommon enough that the trainer will sometimes go through a whole epoch without seeing such bad pairs, as stated in the beginning. But it will still happen at some point during training, leading to a crash.
Hi @rodrigo-pena,
Thanks for getting back with the detailed feedback.
Re: the sampler failing to skip empty labels for one patch: I remember we recently fixed this in torch-em. The first step here would be to check that you have the latest torch-em installed. Could you run python -c "import torch_em; print(torch_em.__version__)" and share the output with us?
This command returns
~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
0.7.3
Hmm, you do have the version after we made the fix (I double-checked, and we already made the changes in 0.6.2). I'll try to reproduce this over sparse labels myself and see what's going on.
(and I will also come back with a bit more theoretical feedback, eg. using dense labels for training automatic instance segmentation, the question of usage of channels in your case, etc. later once we take care of the above technical issue)
I tried to reproduce the mentioned effect on sparse labels and I have a few suspicions which might potentially get rid of this issue (in short: there might be tiny pixels getting considered as valid objects, which are passed to the model and might be somehow causing the issue).
Could you try again by adding an additional argument to the sampler we pass to the loader: sampler=MinInstanceSampler(min_size=25)?
NOTE: Remember to use the sampler in both training and validation loaders!
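For instance (a sketch; "..." stands for the other loader arguments shown earlier):
from torch_em.data.sampler import MinInstanceSampler

sampler = MinInstanceSampler(min_size=25)
train_loader = torch_em.default_segmentation_loader(..., sampler=sampler)  # all other arguments as before
val_loader = torch_em.default_segmentation_loader(..., sampler=sampler)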
EDIT: I forgot to mention something. For this, you need to install torch-em from source. You can just do this in your current working environment via the terminal using:
git clone https://github.com/constantinpape/torch-em
cd torch-em
pip install -e . --no-deps
Ok, that did the trick. Training has finished and the resulting AIS predictions look promising:
However, during the process of re-installing some packages, something broke in imageio and I couldn't write tiffs to file anymore. Here are the steps I took:
1. Started from my existing micro-sam environment.
2. Downgraded napari to <0.5 (related to the other issue #695).
3. Installed torch-em from source according to your previous comment.
After these 3 steps, whenever I called imageio.imwrite() I got the error AttributeError: 'TiffWriter' object has no attribute 'write'. I then checked the tifffile package version and it was a very old one: 2020.6.3. Back then, TiffWriter indeed didn't have a 'write' method; it was called 'save' instead. I had to force pip install tifffile --upgrade to fix the imageio error. Could the reason for this mismatch be the --no-deps flag when installing torch-em? I don't know what other functionality may be affected by further dependency mismatches.
Yayy, that's great to hear that the training works now.
Re: broken imageio
Could the reason for this mismatch be the --no-deps flag when installing torch-em?
This should not be responsible for breaking an existing dependency, because the idea of using --no-deps is exactly to avoid fetching additional dependencies. It might have broken while downgrading napari, but good to know that upgrading tifffile works.
And a few more fundamental points to share (which in theory are good to follow to get the best results with micro-sam):
- micro-sam expects dense annotations, else the AIS results might not be as expected.
- For annotating in micro-sam, you used the same 3-channel images. Now, seeing the annotations you provided come out of the dataloader, it looks like it works. I'd be inclined to say that this should then also work for finetuning on the same 3-channel images.
PS. We are happy to have all the feedback from you. We have a better idea now of the documentation and notebook improvements we want to make to improve the user experience. Thanks for your patience and interest in micro-sam.
Let us know if there's something else you would like to discuss.
For a summary, then, the solution that seems to work requires:
1. Installing micro-sam from conda-forge.
2. Installing torch-em from source (with dependencies, to be sure).
3. Removing from the training dataset any image/mask pair for which the mask is all-zeros (i.e., has no segmentation instances other than the background).
4. Calling torch_em.default_segmentation_loader with the arguments:
   - is_seg_dataset=False, to induce creating an ImageCollectionDataset;
   - n_samples=N, where N is an appropriately large number to allow a representative collection of random patches from the training images;
   - sampler=MinInstanceSampler(min_size=M), where M is an appropriately large number to avoid segmentation regions that are too small and could lead to errors.
I just wanted to make sure I understand the concept of "dense annotations": does this mean annotating everything that should be segmented in the image, or does that mean having the segmented instances occupy a fraction of the image similar to the background?
Where do you plan to add the improvements to the documentation? In the example finetuning notebooks of micro-sam, or in the example notebooks of torch-em?
Thanks for the nice summary. I'll leave a few mentions:
- Installing micro-sam from conda-forge
- Installing torch-em from source (with dependencies, to be sure)
That's correct. To add to this, we are planning to make a release rather soon which would merge point 2 into point 1. Then, ideally, a user would only need to install micro-sam (from conda-forge / from source) and things should work as expected.
- Removing from the training dataset any image/mask pair for which the mask is all-zeros (i.e., has no segmentation instances other than the background).
This is the reason we use sampler in torch_em.default_segmentation_loader. As long as you have the sampler, you don't need to remove empty label/image pairs. The sampler should take care of it.
And Re: dataloaders:
sampler=MinInstanceSampler(min_size=M), where M is an appropriately large number to avoid segmentation regions that are too small and could lead to errors
This value should match the minimum number of pixels an object must have to be considered valid for finetuning. Also, it should match the min_size value in label_transform.
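For example (a minimal sketch of keeping the two thresholds in sync; 25 is just an example value and the imports are the ones I'd expect from the earlier cells):
from torch_em.data.sampler import MinInstanceSampler
from torch_em.transform.label import PerObjectDistanceTransform

min_size = 25  # minimum object size in pixels, used consistently in both places
label_transform = PerObjectDistanceTransform(
    distances=True, boundary_distances=True, directed_distances=False,
    foreground=True, instances=True, min_size=min_size,
)
sampler = MinInstanceSampler(min_size=min_size)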
Where do you plan to add the improvements to the documentation? In the example finetuning notebooks of micro-sam or in the example notebooks of torch-em?
I would like to broaden the dataloader tutorials in the torch-em notebooks and update the micro-sam notebook to add notes and support for samplers.
EDIT:
I just wanted to make sure I understand the concept of "dense annotations": does this mean annotating everything that should be segmented in the image, or does that mean having the segmented instances occupy a fraction of the image similar to the background?
Yes. I would recommend fully annotating the desired objects as much as possible (ideally all objects) to get the best outcome with our automatic instance segmentation method.
Hi @anwai98,
This is the reason we use sampler in torch_em.default_segmentation_loader. As long as you have the sampler, you don't need to remove empty label/image pairs. The sampler should take care of it.
This does not seem to be the case. I just tried running my training code on the full dataset, including the images with blank label pairs and I get the following error:
Start fitting for 500 iterations / 10 epochs
with 50 iterations per epoch
Training with mixed precision
Epoch 0: 1%| | 6/500 [00:20<22:55, 2.78s/it]
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[17], line 2
1 # Run training
----> 2 sam_training.train_sam(
3 name=checkpoint_name,
4 save_root=os.path.join(HOME, "models"),
5 model_type=model_type,
6 checkpoint_path=os.path.join(HOME, "models", "vit_b_lm", "vit_b.pt"),
7 # train_loader=train_loader_from_segmentation_dataset,
8 train_loader=train_loader_from_image_collection_dataset,
9 # val_loader=val_loader_from_segmentation_dataset,
10 val_loader=val_loader_from_image_collection_dataset,
11 n_epochs=n_epochs,
12 n_objects_per_batch=n_objects_per_batch,
13 with_segmentation_decoder=train_instance_segmentation,
14 device=device,
15 )
File ~/miniconda3/envs/micro-sam/lib/python3.11/site-packages/micro_sam/training/training.py:282, in train_sam(name, model_type, train_loader, val_loader, n_epochs, early_stopping, n_objects_per_batch, checkpoint_path, with_segmentation_decoder, freeze, device, lr, n_sub_iteration, save_root, mask_prob, n_iterations, scheduler_class, scheduler_kwargs, save_every_kth_epoch, pbar_signals)
279 progress_bar_wrapper = _ProgressBarWrapper(pbar_signals)
280 trainer_fit_params["progress"] = progress_bar_wrapper
--> 282 trainer.fit(**trainer_fit_params)
File ~/torch-em/torch_em/trainer/default_trainer.py:586, in DefaultTrainer.fit(self, iterations, load_from_checkpoint, epochs, save_every_kth_epoch, progress)
583 pass
585 # Run training and validation for this epoch
--> 586 t_per_iter = train_epoch(progress)
587 current_metric = validate()
589 # perform all the post-epoch steps:
590
591 # apply the learning rate scheduler
File ~/torch-em/torch_em/trainer/default_trainer.py:650, in DefaultTrainer._train_epoch_mixed(self, progress)
649 def _train_epoch_mixed(self, progress):
--> 650 return self._train_epoch_impl(
651 progress, partial(torch.autocast, device_type="cpu" if self.device.type == "cpu" else "cuda"),
652 self._backprop_mixed
653 )
File ~/miniconda3/envs/micro-sam/lib/python3.11/site-packages/micro_sam/training/joint_sam_trainer.py:78, in JointSamTrainer._train_epoch_impl(self, progress, forward_context, backprop)
76 n_iter = 0
77 t_per_iter = time.time()
---> 78 for x, y in self.train_loader:
79 labels_instances = y[:, 0, ...].unsqueeze(1)
80 labels_for_unetr = y[:, 1:, ...]
File ~/miniconda3/envs/micro-sam/lib/python3.11/site-packages/torch/utils/data/dataloader.py:630, in _BaseDataLoaderIter.__next__(self)
627 if self._sampler_iter is None:
628 # TODO(https://github.com/pytorch/pytorch/issues/76750)
629 self._reset() # type: ignore[call-arg]
--> 630 data = self._next_data()
631 self._num_yielded += 1
632 if self._dataset_kind == _DatasetKind.Iterable and \
633 self._IterableDataset_len_called is not None and \
634 self._num_yielded > self._IterableDataset_len_called:
File ~/miniconda3/envs/micro-sam/lib/python3.11/site-packages/torch/utils/data/dataloader.py:673, in _SingleProcessDataLoaderIter._next_data(self)
671 def _next_data(self):
672 index = self._next_index() # may raise StopIteration
--> 673 data = self._dataset_fetcher.fetch(index) # may raise StopIteration
674 if self._pin_memory:
675 data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
File ~/miniconda3/envs/micro-sam/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py:52, in _MapDatasetFetcher.fetch(self, possibly_batched_index)
50 data = self.dataset.__getitems__(possibly_batched_index)
51 else:
---> 52 data = [self.dataset[idx] for idx in possibly_batched_index]
53 else:
54 data = self.dataset[possibly_batched_index]
File ~/miniconda3/envs/micro-sam/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py:52, in <listcomp>(.0)
50 data = self.dataset.__getitems__(possibly_batched_index)
51 else:
---> 52 data = [self.dataset[idx] for idx in possibly_batched_index]
53 else:
54 data = self.dataset[possibly_batched_index]
File ~/torch-em/torch_em/data/image_collection_dataset.py:194, in ImageCollectionDataset.__getitem__(self, index)
193 def __getitem__(self, index):
--> 194 raw, labels = self._get_sample(index)
195 initial_label_dtype = labels.dtype
197 if self.raw_transform is not None:
File ~/torch-em/torch_em/data/image_collection_dataset.py:185, in ImageCollectionDataset._get_sample(self, index)
182 raw, label, shape, prefix_box, have_raw_channels = self._load_data(raw_path, label_path)
184 if sample_id > self.max_sampling_attempts:
--> 185 raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")
187 # to channel first
188 if have_raw_channels and len(prefix_box) == 0:
RuntimeError: Could not sample a valid batch in 500 attempts
Then, when I train only with the "valid" image-label pairs everything goes smoothly.
It seems to me that the MinInstanceSampler needs at least one valid segmentation instance in the image to work.
Hi @rodrigo-pena,
RuntimeError: Could not sample a valid batch in 500 attempts
The reported issue originates when the sampler cannot find valid samples after "n" attempts. This can be fixed by increasing the sampling attempts for fetching a valid sample after creating the dataloaders as follows:
train_loader = ...
val_loader = ...
print(train_loader.dataset.max_sampling_attempts) # you should see the current default output as 500, i.e. torch-em will try to run the sampler 500 times in search of a valid input, else throw an error.
train_loader.dataset.max_sampling_attempts = 5000
val_loader.dataset.max_sampling_attempts = 5000
# let's verify whether our argument has been updated
print(train_loader.dataset.max_sampling_attempts) # the output of this line should show the number we increased the attempts to, i.e. 5000
Let us know if this works for you.
Training seems to be working now for the full dataset after setting max_sampling_attempts = 5000. I'm still puzzled as to why I didn't see this issue show up when I was using only images with at least one segmentation instance. Do you have any idea? Was it pure chance?
Training seems to be working now for the full dataset
Happy to hear that the training works now.
why I didn't see this issue show up when I was using only images with at least one segmentation instance.
My guess is when there are no empty labelled sample-pairs, the sampler might be getting a valid sample in under 500 attempts.
I'm also working with RGB data, but I'm running into a slightly different problem. I have a folder with RGB images in .tif format (reshaped so that the RGB channel is first). I have set up the data loader using both option 1 and 2. Here is my version of option 2:
image_paths = sorted(glob(os.path.join(image_dir, "*.tif")))
label_paths = sorted(glob(os.path.join(segmentation_dir, "*.tif")))
n_images = len(image_paths)
n_train = int(0.85 * n_images)
train_image_paths, train_label_paths = image_paths[:n_train], label_paths[:n_train]
val_image_paths, val_label_paths = image_paths[n_train:], label_paths[n_train:]
batch_size = 1 # the training batch size
patch_shape = (512, 512) # the size of patches for training
# Train an additional convolutional decoder for end-to-end automatic instance segmentation
train_instance_segmentation = True
# Label transform is used to convert the ground-truth labels to the desired instances for finetuning Segment Anything.
# or, to learn the foreground and distances to the object centers and object boundaries for automatic segmentation.
if train_instance_segmentation:
    # Computes the distance transform for objects to jointly perform the additional decoder-based automatic instance segmentation (AIS) and finetune Segment Anything.
    label_transform = PerObjectDistanceTransform(
        distances=True,
        boundary_distances=True,
        directed_distances=False,
        foreground=True,
        instances=True,
        min_size=5
    )
else:
    # Ensures the individual object instances to finetune the classical Segment Anything.
    label_transform = torch_em.transform.label.connected_components

train_loader = torch_em.default_segmentation_loader(
    raw_paths=train_image_paths,
    raw_key=None,
    label_paths=train_label_paths,
    label_key=None,
    patch_shape=patch_shape,
    batch_size=batch_size,
    ndim=2,
    is_seg_dataset=False,
    n_samples=50,
    raw_transform=sam_training.identity,
    sampler=MinInstanceSampler(),
)
val_loader = torch_em.default_segmentation_loader(
    raw_paths=val_image_paths,
    raw_key=None,
    label_paths=val_label_paths,
    label_key=None,
    patch_shape=patch_shape,
    batch_size=batch_size,
    ndim=2,
    is_seg_dataset=False,
    n_samples=50,
    raw_transform=sam_training.identity,
    sampler=MinInstanceSampler(),
)
When I run the training script, I'm getting the following error:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[68], line 2
1 # Run training
----> 2 sam_training.train_sam(
3 name=checkpoint_name,
4 save_root=os.path.join(training_dir, "models"),
5 model_type=model_type,
6 train_loader=train_loader,
7 val_loader=val_loader,
8 n_epochs=n_epochs,
9 n_objects_per_batch=n_objects_per_batch,
10 with_segmentation_decoder=train_instance_segmentation,
11 device=device,
12 )
File C:\ProgramData\anaconda3\envs\micro-sam\Lib\site-packages\micro_sam\training\training.py:184, in train_sam(name, model_type, train_loader, val_loader, n_epochs, early_stopping, n_objects_per_batch, checkpoint_path, with_segmentation_decoder, freeze, device, lr, n_sub_iteration, save_root, mask_prob, n_iterations, scheduler_class, scheduler_kwargs, save_every_kth_epoch, pbar_signals)
128 def train_sam(
129 name: str,
130 model_type: str,
(...)
148 pbar_signals: Optional[QObject] = None,
149 ) -> None:
150 """Run training for a SAM model.
151
152 Args:
(...)
182 pbar_signals: Controls for napari progress bar.
183 """
--> 184 _check_loader(train_loader, with_segmentation_decoder)
185 _check_loader(val_loader, with_segmentation_decoder)
187 device = get_device(device)
File C:\ProgramData\anaconda3\envs\micro-sam\Lib\site-packages\micro_sam\training\training.py:75, in _check_loader(loader, with_segmentation_decoder)
73 if with_segmentation_decoder:
74 if n_channels_y != 4:
---> 75 raise ValueError(
76 "Invalid number of channels in the target data from the data loader. "
77 "Expect 4 channel for training with an instance segmentation decoder, "
78 f"but got {n_channels_y} channels."
79 )
80 check_instance_channel(y[:, 0])
82 targets_min, targets_max = y[:, 1:].min(), y[:, 1:].max()
ValueError: Invalid number of channels in the target data from the data loader. Expect 4 channel for training with an instance segmentation decoder, but got 1 channels.
When I check the shape of the data loader output, I'm getting 'torch.Size([1, 3, 512, 512])' and 'torch.Size([1, 1, 512, 512])' for the image and label respectively. I've also checked several times to see if it is pulling any labels that have no object, and it always has an object.
@rodrigo-pena : thanks for all your feedback here. This helped us understand corner cases of the training with empty label images better and we will incorporate this into the next release to improve handling of these cases. It seems like it's working now for you so I am closing the issue. p.s.: from what I see here your labels are probably not optimal for finetuning micro_sam for automatic instance segmentation and it would help to label as many objects per image as possible rather than just labeling a subset. Feel free to open a separate issue if you run into bad segmentation quality due to this.
@jalexs82 : I have created a new issue #716 in order to keep the discussion of your problem separate. This makes it easier to keep the discussion focused.
I'm working on fine-tuning micro-sam on some fluorescence imaging of myotubes. For prototyping my pipeline, I've downloaded the image data associated with the MyoCount paper [Murphy et al. 2019]. These images are .tif files of shape (1040, 1392, 3), with one channel for nuclei, one for cytoplasm, and one dummy channel full of zeros.
Next, I hand-annotated those images myself using micro-sam's napari plugin "Annotator 2D". The saved annotation masks are single-channel masks of shape (1040, 1392).
I then proceeded to begin adapting the sample fine-tuning notebook to make the vit_l_lm model fit my desired annotations better. However, when it came to defining the data loaders, I reached an error which I cannot solve. It seems that under the hood torch_em.default_segmentation_loader is trying to assert that the images and annotation masks have the same size (they do not, as explained above).
Since torch_em.default_segmentation_loader loads the images and labels directly from their directory, is the solution to create two dummy channels on the annotation masks so that they are the same size as their respective images? Are there other issues I should expect when working with multi-channel fluorescence imaging (as opposed to single-channel phase contrast or EM)?
As a side note, I also tried fine-tuning via the provided napari plugin, but I stumble upon a different error: "The path to the raw data is missing or does not exist. The path to label data is missing or does not exist." I wonder what the actual underlying error is, because the path obviously exists, as I can load and display the images and annotation masks in the fine-tuning notebook.