computational-cell-analytics / micro-sam

Segment Anything for Microscopy
https://computational-cell-analytics.github.io/micro-sam/
MIT License

Problem adapting finetuning notebook to multi-channel fluorescence imaging data #701

Closed rodrigo-pena closed 1 month ago

rodrigo-pena commented 1 month ago

I’m working on fine-tuning micro-sam on some fluorescence imaging of myotubes. For prototyping my pipeline, I’ve downloaded the image data associated with the MyoCount paper [Murphy et al. 2019]. These images are .tif files of shape (1040, 1392, 3), with one channel for nuclei, one for cytoplasm, and one dummy channel full of zeros.

Next, I hand-annotated those images myself using micro-sam’s napari plugin “Annotator 2D”. The saved annotation masks are single-channel masks of shape (1040, 1392).

I then proceeded to adapt the sample fine-tuning notebook to make the vit_l_lm model fit my desired annotations better. However, when it came to defining the data loaders, I hit an error which I cannot solve. It seems that, under the hood, torch_em.default_segmentation_loader asserts that the images and annotation masks have the same shape (they do not, as explained above):

AssertionError: (17, 1040, 1392, 3), (17, 1040, 1392)

Since torch_em.default_segmentation_loader loads the images and labels directly from their directories, is the solution to add two dummy channels to the annotation masks so that they have the same shape as their respective images? Are there other issues I should expect when working with multi-channel fluorescence imaging (as opposed to single-channel phase contrast or EM)?

As a side note, I tried also fine-tuning via the provided napari plugin, but I stumble upon a different error:

“The path to the raw data is missing or does not exist. The path to label data is missing or does not exist.”

I wonder what the actual underlying error is, because the path obviously exists, as I can load and display the images and annotation masks in the fine-tuning notebook.

anwai98 commented 1 month ago

Hi @rodrigo-pena,

Could you share the entire error trace from the finetuning notebook with us? (I have a suspicion about what's happening - torch-em expects the inputs to be channels-first - but it would be good to verify this by looking at the trace.)

In addition, could you attach a screenshot of the finetuning GUI widget in napari, where you initialize the parameters for finetuning?

rodrigo-pena commented 1 month ago

Hi @anwai98, here's the code cell I run:

batch_size = 1  # the training batch size
patch_shape = (512, 512)  # the size of patches for training

# Train an additional convolutional decoder for end-to-end automatic instance segmentation
train_instance_segmentation = True

# Label transform is used to convert the ground-truth labels to the desired instances for finetuning Segment Anything.
# Or to learn the foreground and distances to the object centers and object boundaries for automatic segmentation.
if train_instance_segmentation:
    # Computes the distance transform for objects to jointly perform the additional decoder-based automatic 
    # instance segmentation (AIS) and finetune Segment Anything.
    label_transform = PerObjectDistanceTransform(
        distances=True,
        boundary_distances=True,
        directed_distances=False,
        foreground=True,
        instances=True,
        min_size=25,
    )
else:
    # Ensures the individual object instances to finetune the classical Segment Anything.
    label_transform = torch_em.transform.label.connected_components

train_loader = torch_em.default_segmentation_loader(
    raw_paths=image_dir,
    raw_key=raw_key,
    label_paths=segmentation_dir,
    label_key=label_key,
    patch_shape=patch_shape,
    batch_size=batch_size,
    ndim=2,
    is_seg_dataset=True,
    rois=train_roi,
    label_transform=label_transform,
    shuffle=True,
    raw_transform=sam_training.identity,
)
val_loader = torch_em.default_segmentation_loader(
    raw_paths=image_dir,
    raw_key=raw_key,
    label_paths=segmentation_dir,
    label_key=label_key,
    patch_shape=patch_shape,
    batch_size=batch_size,
    ndim=2,
    is_seg_dataset=True,
    rois=val_roi,
    label_transform=label_transform,
    shuffle=True,
    raw_transform=sam_training.identity,
)

And here's the corresponding full error traceback:

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[6], line 24
     20 else:
     21     # Ensures the individual object instances to finetune the classical Segment Anything.
     22     label_transform = torch_em.transform.label.connected_components
---> 24 train_loader = torch_em.default_segmentation_loader(
     25     raw_paths=image_dir,
     26     raw_key=raw_key,
     27     label_paths=segmentation_dir,
     28     label_key=label_key,
     29     patch_shape=patch_shape,
     30     batch_size=batch_size,
     31     ndim=2,
     32     is_seg_dataset=True,
     33     rois=train_roi,
     34     label_transform=label_transform,
     35     shuffle=True,
     36     raw_transform=sam_training.identity,
     37 )
     38 val_loader = torch_em.default_segmentation_loader(
     39     raw_paths=image_dir,
     40     raw_key=raw_key,
   (...)
     50     raw_transform=sam_training.identity,
     51 )

File [~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/torch_em/segmentation.py:212](https://ood-ubuntu.scicore.unibas.ch/node/sgi01.deploy.int/39288/~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/torch_em/segmentation.py#line=211), in default_segmentation_loader(raw_paths, raw_key, label_paths, label_key, batch_size, patch_shape, label_transform, label_transform2, raw_transform, transform, dtype, label_dtype, rois, n_samples, sampler, ndim, is_seg_dataset, with_channels, with_label_channels, verify_paths, **loader_kwargs)
    189 def default_segmentation_loader(
    190     raw_paths,
    191     raw_key,
   (...)
    210     **loader_kwargs,
    211 ):
--> 212     ds = default_segmentation_dataset(
    213         raw_paths=raw_paths,
    214         raw_key=raw_key,
    215         label_paths=label_paths,
    216         label_key=label_key,
    217         patch_shape=patch_shape,
    218         label_transform=label_transform,
    219         label_transform2=label_transform2,
    220         raw_transform=raw_transform,
    221         transform=transform,
    222         dtype=dtype,
    223         label_dtype=label_dtype,
    224         rois=rois,
    225         n_samples=n_samples,
    226         sampler=sampler,
    227         ndim=ndim,
    228         is_seg_dataset=is_seg_dataset,
    229         with_channels=with_channels,
    230         with_label_channels=with_label_channels,
    231         verify_paths=verify_paths,
    232     )
    233     return get_data_loader(ds, batch_size=batch_size, **loader_kwargs)

File [~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/torch_em/segmentation.py:274](https://ood-ubuntu.scicore.unibas.ch/node/sgi01.deploy.int/39288/~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/torch_em/segmentation.py#line=273), in default_segmentation_dataset(raw_paths, raw_key, label_paths, label_key, patch_shape, label_transform, label_transform2, raw_transform, transform, dtype, label_dtype, rois, n_samples, sampler, ndim, is_seg_dataset, with_channels, with_label_channels, verify_paths)
    269     transform = _get_default_transform(
    270         raw_paths if isinstance(raw_paths, str) else raw_paths[0], raw_key, is_seg_dataset, ndim
    271     )
    273 if is_seg_dataset:
--> 274     ds = _load_segmentation_dataset(
    275         raw_paths,
    276         raw_key,
    277         label_paths,
    278         label_key,
    279         patch_shape=patch_shape,
    280         raw_transform=raw_transform,
    281         label_transform=label_transform,
    282         label_transform2=label_transform2,
    283         transform=transform,
    284         rois=rois,
    285         n_samples=n_samples,
    286         sampler=sampler,
    287         ndim=ndim,
    288         dtype=dtype,
    289         label_dtype=label_dtype,
    290         with_channels=with_channels,
    291         with_label_channels=with_label_channels,
    292     )
    293 else:
    294     ds = _load_image_collection_dataset(
    295         raw_paths,
    296         raw_key,
   (...)
    308         label_dtype=label_dtype,
    309     )

File [~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/torch_em/segmentation.py:102](https://ood-ubuntu.scicore.unibas.ch/node/sgi01.deploy.int/39288/~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/torch_em/segmentation.py#line=101), in _load_segmentation_dataset(raw_paths, raw_key, label_paths, label_key, **kwargs)
    100         if isinstance(rois, tuple):
    101             assert all(isinstance(roi, slice) for roi in rois)
--> 102     ds = SegmentationDataset(raw_paths, raw_key, label_paths, label_key, roi=rois, **kwargs)
    103 else:
    104     assert len(raw_paths) > 0

File [~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/torch_em/data/segmentation_dataset.py:66](https://ood-ubuntu.scicore.unibas.ch/node/sgi01.deploy.int/39288/~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/torch_em/data/segmentation_dataset.py#line=65), in SegmentationDataset.__init__(self, raw_path, raw_key, label_path, label_key, patch_shape, raw_transform, label_transform, label_transform2, transform, roi, dtype, label_dtype, n_samples, sampler, ndim, with_channels, with_label_channels)
     64 shape_raw = self.raw.shape[1:] if self._with_channels else self.raw.shape
     65 shape_label = self.labels.shape[1:] if self._with_label_channels else self.labels.shape
---> 66 assert shape_raw == shape_label, f"{shape_raw}, {shape_label}"
     68 self.shape = shape_raw
     69 self.roi = roi

AssertionError: (17, 1040, 1392, 3), (17, 1040, 1392)
rodrigo-pena commented 1 month ago

I'm also adding the screenshot from the GUI and the error I get when I try to fine-tune the vit_l_lm model via it (I'm censoring parts of the paths to avoid disclosing locations in my HPC system). The error appears right after I click the "Start Training" button.

Screenshot 2024-09-25 at 12 06 15
anwai98 commented 1 month ago

Thanks for sharing the trace and the GUI screenshot.

Re: notebook error trace: Could you also provide information about the paths you pass as inputs (this would help me understand the nature of the inputs you are trying to pass)?

Re: GUI finetuning screenshot: I think one thing you are missing is passing values to Image data key and Label data key, which would search the inputs for you and collect them for creating the dataloaders. (I think this question would probably be answered once I get answers to the aforementioned questions)

rodrigo-pena commented 1 month ago

Here's the cell where I define raw_key and label_key:

# Load images from multiple files in folder via pattern (here: all tif files)
raw_key, label_key = "*.tif", "*.tif"

# Alternative: if you have tif stacks you can just set raw_key and label_key to None
# raw_key, label_key = None, None

# The 'roi' argument can be used to subselect parts of the data.
# Here, we use it to select the first 85% of images for the training split and the other frames for the validation split.
n_images = len(image_paths)
n_train = int(0.85 * n_images)
train_roi = np.s_[:n_train, :, :]
val_roi = np.s_[n_train:, :, :]

Indeed, image_dir points to /.../train_images and label_dir to /.../train_labels. And all images in there are .tif files:

DATA_FOLDER = os.path.join(HOME, 'data', 'myocount-validation-data')
image_dir = os.path.join(DATA_FOLDER, 'train_images')
segmentation_dir = os.path.join(DATA_FOLDER, 'train_labels')

Regarding your second bullet point, I don't know if we can interpret the images as 3D. They do have 3 channels (one for the nuclei, one for the cytoplasm, and a dummy one), but each channel carries very different information. So these images are different from what one would get in, e.g., z-stacks, which I would consider 3D data.

As for the GUI screenshot, thanks for catching the missing ".tif" keys. However, I tried again (this time properly setting ".tif" under "Image data key" and "Label data key") and the error remains the same.

anwai98 commented 1 month ago

Okay, I can reproduce the error and I think I know what's causing this. It's the way we are fetching the multi-channel inputs, which are getting stacked together and hence causing the issue. Let's fix the issues one-by-one.

For the notebook-based finetuning on your provided data:

OPTION 1:

We are currently using the torch_em-supported SegmentationDataset. This expects the inputs to be channels-first (i.e. the input images are currently of shape (1040, 1392, 3); we need to change them to (3, 1040, 1392)). To do this, you could do the following:

import os
from glob import glob
from pathlib import Path

import imageio.v3 as imageio

image_paths = sorted(glob(os.path.join(image_dir, "*.tif")))
for image_path in image_paths:
    pxpath = Path(image_path)
    # Make sure the output directory exists before writing the channels-first images.
    target_dir = os.path.join(pxpath.parent, "image_preprocessed_dir")
    os.makedirs(target_dir, exist_ok=True)
    target_path = os.path.join(target_dir, pxpath.name)

    image = imageio.imread(image_path)
    image = image.transpose(2, 0, 1)  # (1040, 1392, 3) -> (3, 1040, 1392)
    imageio.imwrite(target_path, image)
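
As a quick sanity check (a minimal, untested sketch; it assumes image_dir still points to the original image directory and that the loop above has run), you can confirm the new files are channels-first:

preprocessed_paths = sorted(glob(os.path.join(image_dir, "image_preprocessed_dir", "*.tif")))
print(imageio.imread(preprocessed_paths[0]).shape)  # expected: (3, 1040, 1392)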

Next, we need to update the structure in which we pass the inputs to the dataloader:

image_dir = ...  # NOTE: this needs to point to "image_preprocessed_dir" now
# Fetching all inputs in the respective directories
image_paths = sorted(glob(os.path.join(image_dir, "*.tif")))
label_paths = sorted(glob(os.path.join(segmentation_dir, "*.tif")))  # the label directory defined earlier

# To make valid splits for the inputs
n_images = len(image_paths)
n_train = int(0.85 * n_images)

# Let's split the input paths
train_image_paths, train_label_paths = image_paths[:n_train], label_paths[:n_train]
val_image_paths, val_label_paths = image_paths[n_train:], label_paths[n_train:]

# We create our dataloaders
train_loader = torch_em.default_segmentation_loader(
    raw_paths=train_image_paths,
    raw_key=None,
    label_paths=train_label_paths,
    label_key=None,
    batch_size=1,
    patch_shape=(512, 512),
    ndim=2,
    is_seg_dataset=True,
    with_channels=True,
)
val_loader = ...  # same snippet as above with "val"-related paths

OPTION 2:

We can choose to leave the inputs as they are and use another torch_em-supported dataset (ImageCollectionDataset) to fetch the inputs:

image_dir = ...  # NOTE: this needs to point to the old (original) image directory here
# Fetching all inputs in the respective directories
image_paths = sorted(glob(os.path.join(image_dir, "*.tif")))
label_paths = sorted(glob(os.path.join(segmentation_dir, "*.tif")))  # the label directory defined earlier

# To make valid splits for the inputs
n_images = len(image_paths)
n_train = int(0.85 * n_images)

# Let's split the input paths
train_image_paths, train_label_paths = image_paths[:n_train], label_paths[:n_train]
val_image_paths, val_label_paths = image_paths[n_train:], label_paths[n_train:]

# We create our dataloaders
train_loader = torch_em.default_segmentation_loader(
    raw_paths=train_image_paths,
    raw_key=None,
    label_paths=train_label_paths,
    label_key=None,
    batch_size=1,
    patch_shape=(512, 512),
    ndim=2,
    is_seg_dataset=False,
    n_samples=50,  # This oversamples the inputs, i.e. it reruns the dataset over the images (once it has gone through all of them) in order to fetch more patches
)
val_loader = ...  # same snippet as above with "val"-related paths

PS. I haven't tested the code written here, but in theory it should work. Let us know how it goes.

anwai98 commented 1 month ago

Ah, and regarding the issue with GUI-based finetuning: I see that the only check we do in the code is to verify whether the provided path exists.

Could you double-check two things (just to be doubly sure):
a) whether the paths for both the label and the image directory exist, via your terminal: python -c 'import os; print(os.path.exists("<PATHS_FROM_GUI>"))'
b) the output of: python -c 'import micro_sam; print(micro_sam.__version__)'

rodrigo-pena commented 1 month ago

Thanks for the options, I'll try them tomorrow and report the results here. Does the napari plugin "Annotator 2D" use a different dataloader under the hood? After all, I was able to compute embeddings and run the pre-trained vit_l_lm model on it to assist with the annotations?

About the points to double check re: fine-tuning GUI: a) Returns True b) Returns 1.0.1

anwai98 commented 1 month ago

Does the napari plugin "Annotator 2D" use a different dataloader under the hood?

The "Annotator 2d" just takes the input path for the image and opens it up the array in a layer. It does not use the dataloader schema.

After all, I was able to compute embeddings and run the pre-trained vit_l_lm model on it to assist with the annotations?

Yes, it's because of the aforementioned reason. The dataloaders are a bit specific to the finetuning schema, but with the recommended options to adapt, they should work now. (I will look into increasing support for different image types in our dataloader tutorial notebook)

About the points to double check re: fine-tuning GUI: a) Returns True b) Returns 1.0.1

Thanks for checking this out. Okay, it seems that the directories are visible to the model. Hmm, that's strange. Can you confirm if you have made the installation from source / conda-forge?

rodrigo-pena commented 1 month ago

I have created a separate mamba environment for micro-sam and installed the package in it via mamba using the conda-forge channel (plus pytorch with GPU capability), according to these instructions

rodrigo-pena commented 1 month ago

Hi @anwai98, I've tried Options 1 and 2 that you gave me above, and both seem to work for creating the dataloaders. But it is still not clear to me why in one option I have to specify the flag with_channels=True and in the other n_samples=50. Are these documented somewhere?

Regardless, there are some differences in the created dataloaders: when using torch_em.util.debug.check_loader to display some images, the loader from Option 1 raises a warning ("Multi-channel input data is not yet supported, will only show channel 0"), while the one from Option 2 displays images without warning.

Moving on, when I run the following training code

# All hyperparameters for training.
n_objects_per_batch = 5  # the number of objects per batch that will be sampled
device = "cuda" if torch.cuda.is_available() else "cpu" # the device/GPU used for training
n_epochs = 10  # how long we train (in epochs)

# The model_type determines which base model is used to initialize the weights that are finetuned.
# Here we start from vit_l_lm (the notebook's default is vit_b, which trains faster; vit_h usually yields higher-quality results).
model_type = "vit_l_lm"

# The name of the checkpoint. The checkpoints will be stored in './checkpoints/<checkpoint_name>'
checkpoint_name = "vit_l_lm_finetuned"

# Run training
sam_training.train_sam(
    name=checkpoint_name,
    save_root=os.path.join(HOME, "models"),
    model_type=model_type,
    checkpoint_path=os.path.join(HOME, "models", "vit_l_lm", "vit_l.pt"),
    train_loader=train_loader,
    val_loader=val_loader,
    n_epochs=n_epochs,
    n_objects_per_batch=n_objects_per_batch,
    with_segmentation_decoder=train_instance_segmentation,
    device=device,
)

I get the error traceback

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[29], line 2
      1 # Run training
----> 2 sam_training.train_sam(
      3     name=checkpoint_name,
      4     save_root=os.path.join(HOME, "models"),
      5     model_type=model_type,
      6     checkpoint_path=os.path.join(HOME, "models", "vit_l_lm", "vit_l.pt"),
      7     train_loader=train_loader,
      8     val_loader=val_loader,
      9     n_epochs=n_epochs,
     10     n_objects_per_batch=n_objects_per_batch,
     11     with_segmentation_decoder=train_instance_segmentation,
     12     device=device,
     13 )

File [~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/micro_sam/training/training.py:184](https://ood-ubuntu.scicore.unibas.ch/node/sgi01.deploy.int/35541/lab/tree/myotube-characterization/notebooks/~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/micro_sam/training/training.py#line=183), in train_sam(name, model_type, train_loader, val_loader, n_epochs, early_stopping, n_objects_per_batch, checkpoint_path, with_segmentation_decoder, freeze, device, lr, n_sub_iteration, save_root, mask_prob, n_iterations, scheduler_class, scheduler_kwargs, save_every_kth_epoch, pbar_signals)
    128 def train_sam(
    129     name: str,
    130     model_type: str,
   (...)
    148     pbar_signals: Optional[QObject] = None,
    149 ) -> None:
    150     """Run training for a SAM model.
    151 
    152     Args:
   (...)
    182         pbar_signals: Controls for napari progress bar.
    183     """
--> 184     _check_loader(train_loader, with_segmentation_decoder)
    185     _check_loader(val_loader, with_segmentation_decoder)
    187     device = get_device(device)

File [~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/micro_sam/training/training.py:80](https://ood-ubuntu.scicore.unibas.ch/node/sgi01.deploy.int/35541/lab/tree/myotube-characterization/notebooks/~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/micro_sam/training/training.py#line=79), in _check_loader(loader, with_segmentation_decoder)
     74 if n_channels_y != 4:
     75     raise ValueError(
     76         "Invalid number of channels in the target data from the data loader. "
     77         "Expect 4 channel for training with an instance segmentation decoder, "
     78         f"but got {n_channels_y} channels."
     79     )
---> 80 check_instance_channel(y[:, 0])
     82 targets_min, targets_max = y[:, 1:].min(), y[:, 1:].max()
     83 if targets_min < 0 or targets_min > 1:

File [~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/micro_sam/training/training.py:64](https://ood-ubuntu.scicore.unibas.ch/node/sgi01.deploy.int/35541/lab/tree/myotube-characterization/notebooks/~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/micro_sam/training/training.py#line=63), in _check_loader.<locals>.check_instance_channel(instance_channel)
     60     raise ValueError(
     61         "The target channel with the instance segmentation must not have negative values."
     62     )
     63 if len(unique_vals) == 1:
---> 64     raise ValueError(
     65         "The target channel with the instance segmentation must have at least one instance."
     66     )
     67 if not torch.allclose(unique_vals, unique_vals.round(), atol=1e-7):
     68     raise ValueError(
     69         "All values in the target channel with the instance segmentation must be integer."
     70     )

ValueError: The target channel with the instance segmentation must have at least one instance.

I think this might be because some of my annotation masks are all zeros, meaning that no object should be segmented in that image. These all-zero masks were obtained by "skipping" the image in "Image Series Annotator" (i.e., pressing the Next Image button). If the error is because of that, does it mean that I have to manually discard from my dataset all images whose corresponding segmentation masks are all zeros?

anwai98 commented 1 month ago

Thanks for confirming that the dataloaders now work.

But it is still not clear for me why in one options I have specify the flag with_channels=True and in the other n_samples=50. Are these documented somewhere?

It's down to how we create and handle dataloaders in torch-em (if you look closely, is_seg_dataset has different values for the two options). Option 1 uses SegmentationDataset, which fetches valid patches (say, of the desired patch shape (512, 512)) multiple times from an image larger than the patch shape. Option 2 uses ImageCollectionDataset, which fetches only one randomly cropped patch per image (so if you have 3 training images, you will get 3 patches in total) - which is why we increase n_samples to a higher integer, so that the dataset is run over the entire data multiple times to get more valid patches.
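
As a rough, untested sketch of what this means in practice: you can inspect how many patches one epoch will draw from each loader. For the Option 2 loader this should equal the n_samples you passed, while the Option 1 loader derives its length from the image and patch shapes.

# Assumes `train_loader` was built with either Option 1 or Option 2 from above.
print(len(train_loader.dataset))  # Option 2: equals n_samples (e.g. 50); Option 1: derived from the image/patch geometry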

We have an introductory documentation for torch-em datasets here: https://constantinpape.github.io/torch-em/ (we are looking into extending this in the near future). We would be happy to answer any questions from your side on this.

Regardless, there are some differences in the created dataloaders: when using torch_em.util.debug.check_loader to display some images, the loader from Option 1 raises a warning ("Multi-channel input data is not yet supported, will only show channel 0"), while the one from Option 2 displays images without warning.

My suspicion is that this happens because matplotlib does not support channels-first plotting of the inputs. I can reproduce the warning in both cases (the inputs come out as channels-first tensors with either of the two recommended options), but as long as your images visually look fine, the warning is not relevant.

You can check if your loader outputs from both look as expected:

loader = ...
inputs = next(iter(loader))
print(inputs[0].shape)  # the inputs should look like: torch.Size([1, 3, 512, 512])
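# A hedged extra check: if you also pass the PerObjectDistanceTransform label transform from the
# notebook, the labels should come out with 4 channels (instance ids, foreground, and the two distance maps).
print(inputs[1].shape)  # e.g. torch.Size([1, 4, 512, 512]); without the label transform, a single channel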

If the error is because of that, does it mean that I have to manually discard from my dataset all images whose corresponding segmentation masks are all zeros?

Yes, exactly. The error notifies you that the provided labels do not contain any valid objects for finetuning. For this case you can use a sampler supported by torch-em:

# The provided code with its default values checks whether your labels have at least one valid object to segment
from torch_em.data.sampler import MinInstanceSampler
loader = torch_em.default_segmentation_loader(
    ...,  # all other arguments
    sampler=MinInstanceSampler(),
)

Let us know if this makes the finetuning work.

rodrigo-pena commented 1 month ago

Update: I have trimmed the dataset to contain only image-mask pairs with at least one instance, and I've used the MinInstanceSampler for the Option 2 dataloader.

However, the sam_training.train_sam(...) call now gives me the following new error traceback:

[~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/torch/optim/lr_scheduler.py:60](<redacted_url>/notebooks/miniconda3/envs/micro-sam/lib/python3.12/site-packages/torch/optim/lr_scheduler.py#line=59): UserWarning: The verbose parameter is deprecated. Please use get_last_lr() to access the learning rate.
  warnings.warn(
[~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/torch_em/trainer/default_trainer.py:76](<redacted_url>/notebooks/miniconda3/envs/micro-sam/lib/python3.12/site-packages/torch_em/trainer/default_trainer.py#line=75): FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
  self.scaler = amp.GradScaler() if mixed_precision else None
[~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/torch_em/util/util.py:177](<redacted_url>/notebooks/miniconda3/envs/micro-sam/lib/python3.12/site-packages/torch_em/util/util.py#line=176): UserWarning: Constructor arguments for <class 'micro_sam.training.trainable_sam.TrainableSAM'> cannot be deduced.
For this object, empty constructor arguments will be used.
The trainer can probably not be correctly deserialized via 'DefaultTrainer.from_checkpoint'.
  warnings.warn(
[~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/torch_em/util/util.py:177](<redacted_url>/notebooks/miniconda3/envs/micro-sam/lib/python3.12/site-packages/torch_em/util/util.py#line=176): UserWarning: Constructor arguments for <class 'torch_em.model.unetr.UNETR'> cannot be deduced.
For this object, empty constructor arguments will be used.
The trainer can probably not be correctly deserialized via 'DefaultTrainer.from_checkpoint'.
  warnings.warn(

Start fitting for 150 iterations /  10 epochs
with 15 iterations per epoch
Training with mixed precision

Epoch 0:   0%|          | 0[/150](<redacted_url>/150) [00:00<?, ?it[/s](<redacted_url>/s)][~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/micro_sam/training/joint_sam_trainer.py:86](<redacted_url>/notebooks/miniconda3/envs/micro-sam/lib/python3.12/site-packages/micro_sam/training/joint_sam_trainer.py#line=85): FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
  with forward_context():
[~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/micro_sam/training/joint_sam_trainer.py:95](~/notebooks/miniconda3/envs/micro-sam/lib/python3.12/site-packages/micro_sam/training/joint_sam_trainer.py#line=94): FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
  with forward_context():
Epoch 0:   2%|▏         | 3[/150](<redacted_url>/150) [00:06<05:33,  2.27s[/it](<redacted_url>/it)]

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[15], line 2
      1 # Run training
----> 2 sam_training.train_sam(
      3     name=checkpoint_name,
      4     save_root=os.path.join(HOME, "models"),
      5     model_type=model_type,
      6     checkpoint_path=os.path.join(HOME, "models", "vit_b_lm", "vit_b.pt"),
      7     train_loader=train_loader_from_segmentation_dataset,
      8     # train_loader=train_loader_from_image_collection_dataset,
      9     val_loader=val_loader_from_segmentation_dataset,
     10     # val_loader=val_loader_from_image_collection_dataset,
     11     n_epochs=n_epochs,
     12     n_objects_per_batch=n_objects_per_batch,
     13     with_segmentation_decoder=train_instance_segmentation,
     14     device=device,
     15 )

File [~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/micro_sam/training/training.py:282](<redacted_url>/notebooks/~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/micro_sam/training/training.py#line=281), in train_sam(name, model_type, train_loader, val_loader, n_epochs, early_stopping, n_objects_per_batch, checkpoint_path, with_segmentation_decoder, freeze, device, lr, n_sub_iteration, save_root, mask_prob, n_iterations, scheduler_class, scheduler_kwargs, save_every_kth_epoch, pbar_signals)
    279     progress_bar_wrapper = _ProgressBarWrapper(pbar_signals)
    280     trainer_fit_params["progress"] = progress_bar_wrapper
--> 282 trainer.fit(**trainer_fit_params)

File [~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/torch_em/trainer/default_trainer.py:583](<redacted_url>notebooks/~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/torch_em/trainer/default_trainer.py#line=582), in DefaultTrainer.fit(self, iterations, load_from_checkpoint, epochs, save_every_kth_epoch, progress)
    580     pass
    582 # Run training and validation for this epoch
--> 583 t_per_iter = train_epoch(progress)
    584 current_metric = validate()
    586 # perform all the post-epoch steps:
    587 
    588 # apply the learning rate scheduler

File [~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/torch_em/trainer/default_trainer.py:647](<redacted_url>/notebooks/~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/torch_em/trainer/default_trainer.py#line=646), in DefaultTrainer._train_epoch_mixed(self, progress)
    646 def _train_epoch_mixed(self, progress):
--> 647     return self._train_epoch_impl(progress, amp.autocast, self._backprop_mixed)

File [~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/micro_sam/training/joint_sam_trainer.py:89](<redacted_url>/notebooks/~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/micro_sam/training/joint_sam_trainer.py#line=88), in JointSamTrainer._train_epoch_impl(self, progress, forward_context, backprop)
     84 self.optimizer.zero_grad()
     86 with forward_context():
     87     # 1. train for the interactive segmentation
     88     (loss, mask_loss, iou_regression_loss, model_iou,
---> 89      sampled_binary_y) = self._interactive_train_iteration(x, labels_instances)
     91 backprop(loss)
     93 self.optimizer.zero_grad()

File [~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/micro_sam/training/sam_trainer.py:296](<redacted_url>/notebooks/~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/micro_sam/training/sam_trainer.py#line=295), in SamTrainer._interactive_train_iteration(self, x, y)
    293 def _interactive_train_iteration(self, x, y):
    294     n_pos, n_neg, get_boxes, multimask_output = self._get_prompt_and_multimasking_choices(self._iteration)
--> 296     batched_inputs, sampled_ids = self.convert_inputs(x, y, n_pos, n_neg, get_boxes, self.n_objects_per_batch)
    297     batched_inputs, y_one_hot = self._preprocess_batch(batched_inputs, y, sampled_ids)
    299     loss, mask_loss, iou_regression_loss, model_iou = self._compute_iterative_loss(
    300         batched_inputs, y_one_hot,
    301         num_subiter=self.n_sub_iteration, multimask_output=multimask_output
    302     )

File [~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/micro_sam/training/util.py:172](<redacted_url>/notebooks/~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/micro_sam/training/util.py#line=171), in ConvertToSamInputs.__call__(self, x, y, n_pos, n_neg, get_boxes, n_samples)
    170 for image, gt in zip(x, y):
    171     gt = gt.squeeze().numpy().astype(np.int64)
--> 172     box_prompts, point_prompts, point_label_prompts, sampled_cell_ids = self._get_prompt_lists(
    173         gt, n_samples, prompt_generator,
    174     )
    176     # check to be sure about the expected size of the no. of elements in different settings
    177     if get_boxes:

File [~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/micro_sam/training/util.py:144](<redacted_url>/notebooks/~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/micro_sam/training/util.py#line=143), in ConvertToSamInputs._get_prompt_lists(self, gt, n_samples, prompt_generator)
    141     bbox_coordinates = self._distort_boxes(bbox_coordinates, shape=gt.shape[-2:])
    143 # convert the gt to the one-hot-encoded masks for the sampled cell ids
--> 144 object_masks = segmentation_to_one_hot(gt, None if n_samples is None else sampled_cell_ids)
    146 # derive and return the prompts
    147 point_prompts, point_label_prompts, box_prompts, _ = prompt_generator(object_masks, bbox_coordinates)

File [~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/micro_sam/util.py:950](<redacted_url>/notebooks/~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/micro_sam/util.py#line=949), in segmentation_to_one_hot(segmentation, segmentation_ids)
    947     n_ids = int(segmentation.max())
    949 else:
--> 950     assert segmentation_ids[0] != 0, "No objects were found."
    952     # the segmentation ids have to be sorted
    953     segmentation_ids = np.sort(segmentation_ids)

IndexError: index 0 is out of bounds for axis 0 with size 0

For reference, the dummy dataset with only valid image-mask pairs looks like this (with the first three images going to the training split and the last image going to the validation split):

Screenshot 2024-09-26 at 16 09 18

Remark: I don't know if one of the warnings has to do with this, but the cluster I'm running this on has no internet access, so I have to pre-load the model checkpoints to train from. In sam_training.train_sam(...) I point to the vit_b_lm checkpoint via the file vit_l.pt, but I don't know if and how I should point to the corresponding pre-trained decoder vit_b_decoder.pt

anwai98 commented 1 month ago

Re: warnings in training: I would say you can ignore them, they aren't anything critical.

Re: error in training: Hmm, it seems like we are still getting samples with no foreground objects. That's strange. Can you visually validate from the dataloader outputs whether that's the case, and report if you still see images without valid paired labels (remember to use sampler=MinInstanceSampler() in torch_em.default_segmentation_loader)?

(You can visualize the loaders locally using napari with the script below.)

train_loader = ...
val_loader = ...

from torch_em.util.debug import check_loader
# NOTE: I choose a large number of samples below to visualize all possible inputs
check_loader(train_loader, 50)
check_loader(val_loader, 50)

Re: cluster access to the internet: I have a similar setup on my end. To tackle this, I briefly run the scripts on the login node (which does have internet access), and once I see the training proceed (after 1-2 iterations), I stop the job and submit it to the cluster. This would be my recommendation, as it ensures all the necessary automatic downloads happen (i.e. model checkpoints), so you do not need to hard-code the model paths (and it also makes sure the decoder weights are loaded).
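
If you prefer not to run the full training script on the login node, a minimal sketch of pre-populating the cache could look like the following (this assumes micro_sam.util.get_sam_model downloads the checkpoint for the given model_type into the local cache; whether the separate decoder weights are fetched as well may depend on the version, so briefly starting the actual training as described above remains the safest route):

# Run once on a node with internet access; afterwards the cached checkpoint is reused offline.
from micro_sam.util import get_sam_model

predictor = get_sam_model(model_type="vit_l_lm")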

rodrigo-pena commented 1 month ago

Following your suggestion, I've checked the data loader outputs on napari.

For Option 1, I saw that initially there were some rare image patches cropped from areas with no instance segmentation. So I used sampler=MinInstanceSampler() for its data loaders, and that seemed to visually solve the problem on the samples I explored. However, when running training, the same error appeared (IndexError: index 0 is out of bounds for axis 0 with size 0).

For Option 2 (which already had sampler=MinInstanceSampler()), I could still see sample patches in napari that were cropped from areas without annotated segmentation masks. Training fails in the same way with Option 2.

There is one thing that may still be an issue: only two of the channels in the images are informative (two fluorescence wavelengths). The third is all zeros, put there by the creators of the dataset presumably so that the images could be read as RGB. Could that be causing the issue?

rodrigo-pena commented 1 month ago

Actually, I just ran check_loader on a bunch of samples from the Option 2 dataloader and could only see good pairs, so maybe I mistook what I saw the first time round (or it was some fluke).

anwai98 commented 1 month ago

Could you please double-check this for Option 2 with the sampler? (I would recommend sticking with Option 2, as it works and is a bit simpler than changing axes around.)

If it still causes issues, we might need to investigate this and fix this in torch-em.

In addition, could you report the output of the following: python -c "import torch_em; print(torch_em.__version__)"

rodrigo-pena commented 1 month ago

Good morning, @anwai98. By the way, I appreciate this level of support. It's great to see and really motivates me to use the code base y'all developed.

Focusing on Option 2, I re-ran the training code today and this time it trained for 1.25 epochs before crashing with the same IndexError as before.

So once again I inspected the image/label pairs coming from the training and validation data loaders in napari through the check_loader function. Out of the 100 images I saw, all were good pairs except for one. I took a screenshot of the thumbnails of the bad pair loaded in napari:

Screenshot 2024-09-27 at 09 47 32

As you can see, it seems that the segmentation mask is all-zeros. Weirdly enough, though, if I split the stack into three channels, two of the channels in the segmentation mask are all-ones:

Screenshot 2024-09-27 at 09 48 16

I zoomed into the extracted image patch,

Screenshot 2024-09-27 at 09 49 02

to search for the original image it had been extracted from. I ended up finding it (red square, rotated 180 degrees):

Screenshot 2024-09-27 at 09 49 35

As you can see, it is indeed a patch with no annotated segmentation instances. So I wonder how the MinInstanceSampler() is letting this pass.

This behavior seems uncommon enough that the trainer will sometimes go through a whole epoch without seeing such a bad pair, as stated above. But it still happens at some point during training, leading to a crash.

anwai98 commented 1 month ago

Hi @rodrigo-pena,

Thanks for getting back with the detailed feedback.

Re: sampler failing to ignore empty labels for one patch: I remember we recently fixed this in torch-em. The first step would be to check whether you have the latest torch-em installed. Could you run python -c "import torch_em; print(torch_em.__version__)" and share the output with us?

rodrigo-pena commented 1 month ago

This command returns

~/miniconda3/envs/micro-sam/lib/python3.12/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
  @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
0.7.3
anwai98 commented 1 month ago

Hmm, you do have a version from after we made the fix (I double-checked, and the changes were already in 0.6.2). I'll try to reproduce this over sparse labels myself and see what's going on.

(and I will also come back later with a bit more conceptual feedback, e.g. using dense labels for training automatic instance segmentation, the question of how to use the channels in your case, etc., once we take care of the technical issue above)

anwai98 commented 1 month ago

I tried to reproduce the mentioned effect on sparse labels and I have a suspicion about what might be causing this (in short: tiny objects of only a few pixels might be considered valid, get passed to the model, and somehow cause the error).

Could you try again by adding an additional argument to the sampler we pass to the loader: sampler=MinInstanceSampler(min_size=25)?

NOTE: Remember to use the sampler in both training and validation loaders!
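
For concreteness, a minimal, untested sketch of how this would look for the Option 2 loaders (variable names taken from the earlier snippets):

from torch_em.data.sampler import MinInstanceSampler

# Objects smaller than 25 pixels are no longer counted as valid instances when sampling patches.
sampler = MinInstanceSampler(min_size=25)

train_loader = torch_em.default_segmentation_loader(
    raw_paths=train_image_paths, raw_key=None,
    label_paths=train_label_paths, label_key=None,
    batch_size=1, patch_shape=(512, 512), ndim=2,
    is_seg_dataset=False, n_samples=50,
    label_transform=label_transform, raw_transform=sam_training.identity,
    sampler=sampler,
)
val_loader = torch_em.default_segmentation_loader(
    raw_paths=val_image_paths, raw_key=None,
    label_paths=val_label_paths, label_key=None,
    batch_size=1, patch_shape=(512, 512), ndim=2,
    is_seg_dataset=False, n_samples=50,
    label_transform=label_transform, raw_transform=sam_training.identity,
    sampler=sampler,
)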

EDIT: I forgot to mention something. For this, you need to install torch-em from source. You can just do this in your current working environment via terminal using:

git clone https://github.com/constantinpape/torch-em
cd torch-em
pip install -e . --no-deps
rodrigo-pena commented 1 month ago

Ok, that did the trick. Training has finished and the resulting AIS predictions look promising:

Screenshot 2024-09-27 at 15 38 46

However, during the process of re-installing some packages, something broke in imageio and I couldn't write tiffs to file anymore. Here are the steps I took:

  1. Delete my old micro-sam environment.
  2. Create a new one according to the mamba installation instructions, to have napari pinned to <0.5 (related to the other issue #695 )
  3. Install torch-em from source according to your previous comment.

After these 3 steps, whenever I called imageio.imwrite() I got the error AttributeError: 'TiffWriter' object has no attribute 'write'. I then checked the tifffile package version and it was a very old one: 2020.6.3. Back then, TiffWriter indeed didn't have a 'write' method; it was called 'save' instead. I had to force pip install tifffile --upgrade to fix the imageio error. Could the reason for this mismatch be the --no-deps flag when installing torch-em? I don't know what other functionality may be affected by further dependency mismatches.
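
For reference, a small sketch of the check that confirmed the fix (nothing micro-sam specific, just the package version and the method that used to be missing):

import tifffile

print(tifffile.__version__)  # was 2020.6.3; after `pip install --upgrade tifffile` a recent release
print(hasattr(tifffile.TiffWriter, "write"))  # True on recent versions; the old API only had `save`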

anwai98 commented 1 month ago

Yayy, that's great to hear that the training works now.

Re: broken imageio

Could the reason for this mismatch be the --no-deps flag when installing torch-em?

This should not be responsible for breaking an existing dependency, because the idea of using --no-deps is exactly to avoid fetching additional dependencies. It might have broken while downgrading napari, but good to know that upgrading tifffile works.

anwai98 commented 1 month ago

And a few more fundamental points to share (which are, in theory, good practices for getting the best results with micro-sam):

PS. We are happy to have all the feedback from you. We have a better idea now of the documentation and notebooks improvements we want to make to improve user experience. Thanks for your patience and interest in micro-sam.

Let us know if there's something else you would like to discuss.

rodrigo-pena commented 1 month ago

For a summary, then, the solution that seems to work requires:

  1. Installing micro-sam from conda-forge
  2. Installing torch-em from source (with dependencies, to be sure)
  3. Removing from the training dataset any image/mask pair for which the mask is all-zeros (i.e., has no segmentation instances other than the background)
  4. Creating the dataloaders using torch_em.default_segmentation_loader with arguments:
    • is_seg_dataset=False to induce creating an ImageCollectionDataset
    • n_samples=N, where N is an appropriately large number to allow a representative collection of random patches from the training images
    • sampler=MinInstanceSampler(min_size=M), where M is an appropriately large number to avoid segmentation regions that are too small and could lead to errors

I just wanted to make sure I understand the concept of "dense annotations": does this mean annotating everything that should be segmented in the image, or does it mean having the segmented instances occupy a fraction of the image similar to the background?

Where do you plan to add the improvements to the documentation? In the example finetuning notebooks of micro-sam or in the example notebooks of torch-em?

anwai98 commented 1 month ago

Thanks for the nice summary. I'll leave a few mentions:

  1. Installing micro-sam from conda-forge
  2. Installing torch-em from source (with dependencies, to be sure)

That's correct. To add to this, we are planning to make a release rather soon which would merge point 2 into point 1. Then, ideally, a user would only need to install micro-sam (from conda-forge / from source) and things should work as expected.

  1. Removing from the training dataset any image/mask pair for which the mask is all-zeros (i.e., has no segmentation instances other than the background)

This is the reason we use sampler in torch_em.default_segmentation_loader. As long as you have the sampler, you don't need to remove empty label/image pairs. The sampler should take care of it.

And Re: dataloaders:

sampler=MinInstanceSampler(min_size=M), where M is an appropriately large number to avoid segmentation regions that are too small and could lead to errors

This value should correspond to the minimum number of pixels an object must have to be considered valid for finetuning. It should also match the min_size value used in the label_transform.
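
As a small illustrative sketch of keeping the two values in sync (MIN_SIZE = 25 is just the value used earlier in this thread; imports as in the finetuning notebook):

from torch_em.data.sampler import MinInstanceSampler
from torch_em.transform.label import PerObjectDistanceTransform

MIN_SIZE = 25  # minimum object size (in pixels) considered a valid instance

label_transform = PerObjectDistanceTransform(
    distances=True, boundary_distances=True, directed_distances=False,
    foreground=True, instances=True, min_size=MIN_SIZE,
)
sampler = MinInstanceSampler(min_size=MIN_SIZE)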

Where do you plan to add the improvements to the documentation? In the example finetuning notebooks of micro-sam or in the example notebooks of torch-em?

I would like to broaden the dataloader tutorials in torch-em notebooks and update micro-sam notebook to add notes and support for samplers.

EDIT:

I just wanted to make sure I understand the concept of "dense annotations": does this mean annotating everything that should be segmented in the image, or does that mean having the segmented instances occupy a fraction of the image similar to the background?

Yes. I would recommend fully annotating the desired objects as much as possible (ideally all objects) to get the best outcome with our automatic instance segmentation method.

rodrigo-pena commented 1 month ago

Hi @anwai98,

This is the reason we use sampler in torch_em.default_segmentation_loader. As long as you have the sampler, you don't need to remove empty label/image pairs. The sampler should take care of it.

This does not seem to be the case. I just tried running my training code on the full dataset, including the images with blank label pairs and I get the following error:

Start fitting for 500 iterations /  10 epochs
with 50 iterations per epoch
Training with mixed precision

Epoch 0:   1%|          | 6/500 [00:20<22:55,  2.78s/it]

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[17], line 2
      1 # Run training
----> 2 sam_training.train_sam(
      3     name=checkpoint_name,
      4     save_root=os.path.join(HOME, "models"),
      5     model_type=model_type,
      6     checkpoint_path=os.path.join(HOME, "models", "vit_b_lm", "vit_b.pt"),
      7     # train_loader=train_loader_from_segmentation_dataset,
      8     train_loader=train_loader_from_image_collection_dataset,
      9     # val_loader=val_loader_from_segmentation_dataset,
     10     val_loader=val_loader_from_image_collection_dataset,
     11     n_epochs=n_epochs,
     12     n_objects_per_batch=n_objects_per_batch,
     13     with_segmentation_decoder=train_instance_segmentation,
     14     device=device,
     15 )

File ~/miniconda3/envs/micro-sam/lib/python3.11/site-packages/micro_sam/training/training.py:282, in train_sam(name, model_type, train_loader, val_loader, n_epochs, early_stopping, n_objects_per_batch, checkpoint_path, with_segmentation_decoder, freeze, device, lr, n_sub_iteration, save_root, mask_prob, n_iterations, scheduler_class, scheduler_kwargs, save_every_kth_epoch, pbar_signals)
    279     progress_bar_wrapper = _ProgressBarWrapper(pbar_signals)
    280     trainer_fit_params["progress"] = progress_bar_wrapper
--> 282 trainer.fit(**trainer_fit_params)

File ~/torch-em/torch_em/trainer/default_trainer.py:586, in DefaultTrainer.fit(self, iterations, load_from_checkpoint, epochs, save_every_kth_epoch, progress)
    583     pass
    585 # Run training and validation for this epoch
--> 586 t_per_iter = train_epoch(progress)
    587 current_metric = validate()
    589 # perform all the post-epoch steps:
    590 
    591 # apply the learning rate scheduler

File ~/torch-em/torch_em/trainer/default_trainer.py:650, in DefaultTrainer._train_epoch_mixed(self, progress)
    649 def _train_epoch_mixed(self, progress):
--> 650     return self._train_epoch_impl(
    651         progress, partial(torch.autocast, device_type="cpu" if self.device.type == "cpu" else "cuda"),
    652         self._backprop_mixed
    653     )

File ~/miniconda3/envs/micro-sam/lib/python3.11/site-packages/micro_sam/training/joint_sam_trainer.py:78, in JointSamTrainer._train_epoch_impl(self, progress, forward_context, backprop)
     76 n_iter = 0
     77 t_per_iter = time.time()
---> 78 for x, y in self.train_loader:
     79     labels_instances = y[:, 0, ...].unsqueeze(1)
     80     labels_for_unetr = y[:, 1:, ...]

File ~/miniconda3/envs/micro-sam/lib/python3.11/site-packages/torch/utils/data/dataloader.py:630, in _BaseDataLoaderIter.__next__(self)
    627 if self._sampler_iter is None:
    628     # TODO(https://github.com/pytorch/pytorch/issues/76750)
    629     self._reset()  # type: ignore[call-arg]
--> 630 data = self._next_data()
    631 self._num_yielded += 1
    632 if self._dataset_kind == _DatasetKind.Iterable and \
    633         self._IterableDataset_len_called is not None and \
    634         self._num_yielded > self._IterableDataset_len_called:

File ~/miniconda3/envs/micro-sam/lib/python3.11/site-packages/torch/utils/data/dataloader.py:673, in _SingleProcessDataLoaderIter._next_data(self)
    671 def _next_data(self):
    672     index = self._next_index()  # may raise StopIteration
--> 673     data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    674     if self._pin_memory:
    675         data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)

File ~/miniconda3/envs/micro-sam/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py:52, in _MapDatasetFetcher.fetch(self, possibly_batched_index)
     50         data = self.dataset.__getitems__(possibly_batched_index)
     51     else:
---> 52         data = [self.dataset[idx] for idx in possibly_batched_index]
     53 else:
     54     data = self.dataset[possibly_batched_index]

File ~/miniconda3/envs/micro-sam/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py:52, in <listcomp>(.0)
     50         data = self.dataset.__getitems__(possibly_batched_index)
     51     else:
---> 52         data = [self.dataset[idx] for idx in possibly_batched_index]
     53 else:
     54     data = self.dataset[possibly_batched_index]

File ~/torch-em/torch_em/data/image_collection_dataset.py:194, in ImageCollectionDataset.__getitem__(self, index)
    193 def __getitem__(self, index):
--> 194     raw, labels = self._get_sample(index)
    195     initial_label_dtype = labels.dtype
    197     if self.raw_transform is not None:

File ~/torch-em/torch_em/data/image_collection_dataset.py:185, in ImageCollectionDataset._get_sample(self, index)
    182             raw, label, shape, prefix_box, have_raw_channels = self._load_data(raw_path, label_path)
    184         if sample_id > self.max_sampling_attempts:
--> 185             raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")
    187 # to channel first
    188 if have_raw_channels and len(prefix_box) == 0:

RuntimeError: Could not sample a valid batch in 500 attempts

Then, when I train only with the "valid" image-label pairs everything goes smoothly.

It seems to me that the MinInstanceSampler needs at least one valid segmentation instance in the image to work.

anwai98 commented 1 month ago

Hi @rodrigo-pena,

RuntimeError: Could not sample a valid batch in 500 attempts

The reported issue arises when the sampler cannot find a valid sample within "n" attempts. This can be fixed by increasing the number of sampling attempts after creating the dataloaders, as follows:

train_loader = ...
val_loader = ...

print(train_loader.dataset.max_sampling_attempts)  # you should see the current default output as 500, i.e. torch-em will try to run the sampler 500 times in search of a valid input, else throw an error.

train_loader.dataset.max_sampling_attempts = 5000
val_loader.dataset.max_sampling_attempts = 5000

# let's verify whether our argument has been updated
print(train_loader.dataset.max_sampling_attempts)  # the output of this line should show the number we increased the attempts to, i.e. 5000 

Let us know if this works for you.

rodrigo-pena commented 1 month ago

Training seems to be working now for the full dataset after setting max_sampling_attempts = 5000. I'm still puzzled as to why I didn't see this issue show up when I was using only images with at least one segmentation instance. Do you have any idea? Was it pure chance?

anwai98 commented 1 month ago

Training seems to be working now for the full dataset

Happy to hear that the training works now.

why I didn't see this issue show up when I was using only images with at least one segmentation instance.

My guess is that when there are no empty labelled sample pairs, the sampler manages to find a valid sample in under 500 attempts.

jalexs82 commented 1 month ago

I'm also working with RGB data, but I'm running into a slightly different problem. I have a folder with RGB images in .tif format (reshaped so that the RGB channel is first). I have set up the data loader using both option 1 and 2. Here is my version of option 2:

image_paths = sorted(glob(os.path.join(image_dir, "*.tif")))
label_paths = sorted(glob(os.path.join(segmentation_dir, "*.tif")))

n_images = len(image_paths)
n_train = int(0.85 * n_images)

train_image_paths, train_label_paths = image_paths[:n_train], label_paths[:n_train]
val_image_paths, val_label_paths = image_paths[n_train:], label_paths[n_train:]

batch_size = 1  # the training batch size
patch_shape = (512, 512)  # the size of patches for training

# Train an additional convolutional decoder for end-to-end automatic instance segmentation
train_instance_segmentation = True

# Label transform is used to convert the ground-truth labels to the desired instances for finetuning Segment Anything.
# or, to learn the foreground and distances to the object centers and object boundaries for automatic segmentation.
if train_instance_segmentation:
    # Computes the distance transform for objects to jointly perform the additional decoder-based automatic instance segmentation (AIS) and finetune Segment Anything.
    label_transform = PerObjectDistanceTransform(
        distances=True,
        boundary_distances=True,
        directed_distances=False,
        foreground=True,
        instances=True,
        min_size=5
    )
else:
    # Ensures the individual object instances to finetune the classical Segment Anything.
    label_transform = torch_em.transform.label.connected_components

train_loader = torch_em.default_segmentation_loader(
    raw_paths=train_image_paths,
    raw_key=None,
    label_paths=train_label_paths,
    label_key=None,
    patch_shape=patch_shape,
    batch_size=batch_size,
    ndim=2,
    is_seg_dataset=False,
    n_samples=50,
    raw_transform=sam_training.identity,
    sampler=MinInstanceSampler(),
)
val_loader = torch_em.default_segmentation_loader(
    raw_paths=val_image_paths,
    raw_key=None,
    label_paths=val_label_paths,
    label_key=None,
    patch_shape=patch_shape,
    batch_size=batch_size,
    ndim=2,
    is_seg_dataset=False,
    n_samples=50,
    raw_transform=sam_training.identity,
    sampler=MinInstanceSampler(),
)

When I run the training script, I'm getting the following error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[68], line 2
      1 # Run training
----> 2 sam_training.train_sam(
      3     name=checkpoint_name,
      4     save_root=os.path.join(training_dir, "models"),
      5     model_type=model_type,
      6     train_loader=train_loader,
      7     val_loader=val_loader,
      8     n_epochs=n_epochs,
      9     n_objects_per_batch=n_objects_per_batch,
     10     with_segmentation_decoder=train_instance_segmentation,
     11     device=device,
     12 )

File C:\ProgramData\anaconda3\envs\micro-sam\Lib\site-packages\micro_sam\training\training.py:184, in train_sam(name, model_type, train_loader, val_loader, n_epochs, early_stopping, n_objects_per_batch, checkpoint_path, with_segmentation_decoder, freeze, device, lr, n_sub_iteration, save_root, mask_prob, n_iterations, scheduler_class, scheduler_kwargs, save_every_kth_epoch, pbar_signals)
    128 def train_sam(
    129     name: str,
    130     model_type: str,
   (...)
    148     pbar_signals: Optional[QObject] = None,
    149 ) -> None:
    150     """Run training for a SAM model.
    151 
    152     Args:
   (...)
    182         pbar_signals: Controls for napari progress bar.
    183     """
--> 184     _check_loader(train_loader, with_segmentation_decoder)
    185     _check_loader(val_loader, with_segmentation_decoder)
    187     device = get_device(device)

File C:\ProgramData\anaconda3\envs\micro-sam\Lib\site-packages\micro_sam\training\training.py:75, in _check_loader(loader, with_segmentation_decoder)
     73 if with_segmentation_decoder:
     74     if n_channels_y != 4:
---> 75         raise ValueError(
     76             "Invalid number of channels in the target data from the data loader. "
     77             "Expect 4 channel for training with an instance segmentation decoder, "
     78             f"but got {n_channels_y} channels."
     79         )
     80     check_instance_channel(y[:, 0])
     82     targets_min, targets_max = y[:, 1:].min(), y[:, 1:].max()

ValueError: Invalid number of channels in the target data from the data loader. Expect 4 channel for training with an instance segmentation decoder, but got 1 channels.

When I check the shape of the data loader output, I'm getting 'torch.Size([1, 3, 512, 512])' and 'torch.Size([1, 1, 512, 512])' for the image and label respectively. I've also checked several times to see if it is pulling any labels that have no object, and it always has an object.

constantinpape commented 1 month ago

@rodrigo-pena: thanks for all your feedback here. This helped us better understand the corner cases of training with empty label images, and we will incorporate this into the next release to improve the handling of these cases. It seems like it's working for you now, so I am closing the issue.

P.S.: from what I see here, your labels are probably not optimal for finetuning micro_sam for automatic instance segmentation; it would help to label as many objects per image as possible rather than just a subset. Feel free to open a separate issue if you run into bad segmentation quality due to this.

@jalexs82 : I have created a new issue #716 in order to keep the discussion of your problem separate. This makes it easier to keep the discussion focused.