computational-cell-analytics / micro-sam

Segment Anything for Microscopy
https://computational-cell-analytics.github.io/micro-sam/
MIT License

Issue with RGB data and instance labels in training #716

Open constantinpape opened 5 hours ago

constantinpape commented 5 hours ago

I'm also working with RGB data, but I'm running into a slightly different problem. I have a folder with RGB images in .tif format (reshaped so that the RGB channel is first). I have set up the data loader using both options 1 and 2. Here is my version of option 2:

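import os
from glob import glob

import torch_em
from torch_em.data import MinInstanceSampler
from torch_em.transform.label import PerObjectDistanceTransform

import micro_sam.training as sam_training

image_paths = sorted(glob(os.path.join(image_dir, "*.tif")))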
label_paths = sorted(glob(os.path.join(segmentation_dir, "*.tif")))

n_images = len(image_paths)
n_train = int(0.85 * n_images)

train_image_paths, train_label_paths = image_paths[:n_train], label_paths[:n_train]
val_image_paths, val_label_paths = image_paths[n_train:], label_paths[n_train:]

batch_size = 1  # the training batch size
patch_shape = (512, 512)  # the size of patches for training

# Train an additional convolutional decoder for end-to-end automatic instance segmentation
train_instance_segmentation = True

# The label transform converts the ground-truth labels to the desired instances for finetuning Segment Anything,
# or to the foreground and the distances to object centers and object boundaries for automatic segmentation.
if train_instance_segmentation:
    # Computes the distance transform for objects to jointly perform the additional decoder-based automatic instance segmentation (AIS) and finetune Segment Anything.
    label_transform = PerObjectDistanceTransform(
        distances=True,
        boundary_distances=True,
        directed_distances=False,
        foreground=True,
        instances=True,
        min_size=5
    )
else:
    # Computes connected components to get individual object instances for finetuning the classical Segment Anything.
    label_transform = torch_em.transform.label.connected_components

train_loader = torch_em.default_segmentation_loader(
    raw_paths=train_image_paths,
    raw_key=None,
    label_paths=train_label_paths,
    label_key=None,
    patch_shape=patch_shape,
    batch_size=batch_size,
    ndim=2,
    is_seg_dataset=False,
    n_samples=50,
    raw_transform=sam_training.identity,
    sampler=MinInstanceSampler(),
)
val_loader = torch_em.default_segmentation_loader(
    raw_paths=val_image_paths,
    raw_key=None,
    label_paths=val_label_paths,
    label_key=None,
    patch_shape=patch_shape,
    batch_size=batch_size,
    ndim=2,
    is_seg_dataset=False,
    n_samples=50,
    raw_transform=sam_training.identity,
    sampler=MinInstanceSampler(),
)
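The split above relies on `sorted(glob(...))` returning images and labels in matching order, which only holds if each image and its label share the same file name. A minimal sketch of a guard for this, using a hypothetical helper `split_pairs` (not part of torch_em or micro_sam) and assuming identical file names in both folders, keeps the same 85/15 split:

```python
import os


def split_pairs(image_paths, label_paths, train_fraction=0.85):
    """Split paired image/label paths into train and validation lists.

    Hypothetical helper: assumes each image and its label have the same
    file name, e.g. images/cell_01.tif and labels/cell_01.tif.
    """
    image_paths, label_paths = sorted(image_paths), sorted(label_paths)
    image_names = [os.path.basename(p) for p in image_paths]
    label_names = [os.path.basename(p) for p in label_paths]
    # Fail early instead of silently training on mismatched pairs.
    if image_names != label_names:
        raise ValueError("Image and label file names do not match one-to-one.")

    n_train = int(train_fraction * len(image_paths))
    train = (image_paths[:n_train], label_paths[:n_train])
    val = (image_paths[n_train:], label_paths[n_train:])
    return train, val


images = [f"images/img_{i:02d}.tif" for i in range(20)]
labels = [f"labels/img_{i:02d}.tif" for i in range(20)]
(train_images, train_labels), (val_images, val_labels) = split_pairs(images, labels)
print(len(train_images), len(val_images))  # 17 3
```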

When I run the training script, I'm getting the following error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[68], line 2
      1 # Run training
----> 2 sam_training.train_sam(
      3     name=checkpoint_name,
      4     save_root=os.path.join(training_dir, "models"),
      5     model_type=model_type,
      6     train_loader=train_loader,
      7     val_loader=val_loader,
      8     n_epochs=n_epochs,
      9     n_objects_per_batch=n_objects_per_batch,
     10     with_segmentation_decoder=train_instance_segmentation,
     11     device=device,
     12 )

File C:\ProgramData\anaconda3\envs\micro-sam\Lib\site-packages\micro_sam\training\training.py:184, in train_sam(name, model_type, train_loader, val_loader, n_epochs, early_stopping, n_objects_per_batch, checkpoint_path, with_segmentation_decoder, freeze, device, lr, n_sub_iteration, save_root, mask_prob, n_iterations, scheduler_class, scheduler_kwargs, save_every_kth_epoch, pbar_signals)
    128 def train_sam(
    129     name: str,
    130     model_type: str,
   (...)
    148     pbar_signals: Optional[QObject] = None,
    149 ) -> None:
    150     """Run training for a SAM model.
    151 
    152     Args:
   (...)
    182         pbar_signals: Controls for napari progress bar.
    183     """
--> 184     _check_loader(train_loader, with_segmentation_decoder)
    185     _check_loader(val_loader, with_segmentation_decoder)
    187     device = get_device(device)

File C:\ProgramData\anaconda3\envs\micro-sam\Lib\site-packages\micro_sam\training\training.py:75, in _check_loader(loader, with_segmentation_decoder)
     73 if with_segmentation_decoder:
     74     if n_channels_y != 4:
---> 75         raise ValueError(
     76             "Invalid number of channels in the target data from the data loader. "
     77             "Expect 4 channel for training with an instance segmentation decoder, "
     78             f"but got {n_channels_y} channels."
     79         )
     80     check_instance_channel(y[:, 0])
     82     targets_min, targets_max = y[:, 1:].min(), y[:, 1:].max()

ValueError: Invalid number of channels in the target data from the data loader. Expect 4 channel for training with an instance segmentation decoder, but got 1 channels.
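The failing check can be reproduced on a single batch before calling `train_sam`. The sketch below uses a hypothetical helper, `check_decoder_targets`, which is not part of micro_sam. It mirrors the condition visible in the traceback (4 target channels when training with a segmentation decoder) and assumes, based on the `y[:, 1:].min()` / `.max()` line, that channels 1-3 are expected to lie in [0, 1]:

```python
import numpy as np


def check_decoder_targets(y):
    """Hypothetical pre-flight check for one batch of targets (B, C, H, W).

    Mirrors the channel check from the traceback; the [0, 1] range check
    on channels 1-3 is an assumption.
    """
    y = np.asarray(y)
    if y.ndim != 4:
        raise ValueError(f"Expected targets of shape (B, C, H, W), got {y.shape}.")
    n_channels = y.shape[1]
    if n_channels != 4:
        raise ValueError(
            f"Expected 4 target channels for training with a segmentation decoder, "
            f"got {n_channels}. Was label_transform passed to the loader?"
        )
    # Assumption: channels 1-3 (foreground and normalized distances) lie in [0, 1].
    lo, hi = y[:, 1:].min(), y[:, 1:].max()
    if lo < 0 or hi > 1:
        raise ValueError(f"Expected channels 1-3 in [0, 1], got range [{lo}, {hi}].")
```

With a torch_em loader, one batch could be checked via `x, y = next(iter(train_loader))` followed by `check_decoder_targets(y.numpy())`; a (1, 1, 512, 512) target like the one reported here fails the channel check.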

When I check the shape of the data loader output, I'm getting 'torch.Size([1, 3, 512, 512])' and 'torch.Size([1, 1, 512, 512])' for the image and label respectively. I've also checked several times to see if it is pulling any labels that have no object, and it always has an object.
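Both shapes follow the (batch, channels, height, width) convention, so the images already arrive channel-first. For RGB .tif files stored channel-last, a minimal NumPy sketch of the conversion (the helper name `to_channel_first` is hypothetical) is:

```python
import numpy as np


def to_channel_first(image):
    """Move a trailing RGB axis to the front: (H, W, 3) -> (3, H, W).

    Arrays without a trailing RGB axis are returned unchanged.
    """
    image = np.asarray(image)
    if image.ndim == 3 and image.shape[-1] == 3 and image.shape[0] != 3:
        return np.moveaxis(image, -1, 0)
    return image


rgb = np.zeros((512, 512, 3), dtype=np.uint8)
print(to_channel_first(rgb).shape)  # (3, 512, 512)
```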

Originally posted by @jalexs82 in https://github.com/computational-cell-analytics/micro-sam/issues/701#issuecomment-2387209037

constantinpape commented 5 hours ago

Hi @jalexs82 ,

> When I check the shape of the data loader output, I'm getting 'torch.Size([1, 3, 512, 512])' and 'torch.Size([1, 1, 512, 512])' for the image and label respectively.

This is surprising. The labels should have the shape torch.Size([1, 4, 512, 512]). If PerObjectDistanceTransform is used as the label transform, the 4 channels are created automatically. Could you share a few sample images and labels from your data (3-4 would be enough)? That would be the easiest way to check what's going wrong.
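To make the expected 4 channels concrete, here is a toy NumPy sketch of a per-object distance target. It is not torch_em's `PerObjectDistanceTransform`: the channel order [instances, foreground, center distance, boundary distance] and the per-object normalization to [0, 1] are assumptions (channel 0 holding the instances is consistent with `check_instance_channel(y[:, 0])` in the traceback), the background fill value of 1.0 is a choice made for this sketch, and the brute-force boundary distance is only practical for tiny arrays.

```python
import numpy as np


def toy_four_channel_target(labels):
    """Toy stand-in for a per-object distance transform (not torch_em's code).

    labels: (H, W) integer instance mask, 0 = background.
    Returns a (4, H, W) float32 array with channels
    [instances, foreground, center distance, boundary distance].
    """
    labels = np.asarray(labels)
    h, w = labels.shape
    rows, cols = np.indices((h, w))
    coords = np.stack([rows.ravel(), cols.ravel()], axis=1).astype(float)

    # Background pixels keep a fill value of 1.0 in both distance channels (sketch choice).
    center_dist = np.ones((h, w), dtype="float32")
    boundary_dist = np.ones((h, w), dtype="float32")

    for obj_id in np.unique(labels):
        if obj_id == 0:
            continue
        mask = labels == obj_id
        inside = coords[mask.ravel()]
        outside = coords[~mask.ravel()]

        # Distance of each object pixel to the object's centroid, scaled to [0, 1].
        d_center = np.linalg.norm(inside - inside.mean(axis=0), axis=1)
        center_dist[mask] = d_center / max(d_center.max(), 1e-7)

        # Brute force: distance of each object pixel to the nearest pixel outside it.
        if len(outside) > 0:
            diff = inside[:, None, :] - outside[None, :, :]
            d_boundary = np.sqrt((diff ** 2).sum(axis=-1)).min(axis=1)
            boundary_dist[mask] = d_boundary / d_boundary.max()

    instances = labels.astype("float32")
    foreground = (labels > 0).astype("float32")
    return np.stack([instances, foreground, center_dist, boundary_dist])


lab = np.zeros((6, 6), dtype=int)
lab[1:4, 1:4] = 1
lab[4:6, 4:6] = 2
print(toy_four_channel_target(lab).shape)  # (4, 6, 6)
```

A single-channel (H, W) label goes in and a (4, H, W) target comes out; a loader that never applies such a transform hands the raw 1-channel labels to training instead.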

> I've also checked several times to see if it is pulling any labels that have no object, and it always has an object.

The problem you report is not related to empty labels.

anwai98 commented 4 hours ago

Hi @jalexs82,

Re: the issue with the output label channels: it looks like you are not passing the label_transform argument to the data loaders. You can pass it like this:

train_loader = torch_em.default_segmentation_loader(
    ...,
    label_transform=label_transform,
    sampler=MinInstanceSampler(min_size=25),  # I recommend a sampler that filters out small objects: the label transform removes them later on, which might cause trouble.
)
val_loader = torch_em.default_segmentation_loader(
    ...,
    label_transform=label_transform,
    sampler=MinInstanceSampler(min_size=25),  # same recommendation as above
)
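The reasoning behind `min_size=25` can be illustrated with a simplified sketch. Both helpers are hypothetical and are not torch_em's code: `accept_patch` assumes a sampler that counts label ids (background included) with at least `min_size` pixels and accepts a patch once there are `min_num_ids` of them, and `drop_small_objects` stands in for the `min_size=5` filtering done by the label transform in the original snippet.

```python
import numpy as np


def drop_small_objects(labels, min_size):
    """Set objects with fewer than min_size pixels to background (0)."""
    labels = labels.copy()
    ids, sizes = np.unique(labels, return_counts=True)
    for obj_id, size in zip(ids, sizes):
        if obj_id != 0 and size < min_size:
            labels[labels == obj_id] = 0
    return labels


def accept_patch(labels, min_size, min_num_ids=2):
    """Accept a patch if at least min_num_ids label ids (background included)
    cover at least min_size pixels each. Simplified instance-sampler rule."""
    _, sizes = np.unique(labels, return_counts=True)
    return int((sizes >= min_size).sum()) >= min_num_ids


patch = np.zeros((32, 32), dtype=int)
patch[0, 0:3] = 1  # a single 3-pixel object

print(accept_patch(patch, min_size=1))         # True: no size filter in the sampler
print(drop_small_objects(patch, 5).max())      # 0: the transform then removes the object
print(accept_patch(patch, min_size=25))        # False: a size-aware sampler rejects the patch
```

Because the sampler threshold (25) is larger than the transform threshold (5), every accepted patch keeps at least one object after the small ones are dropped.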

Let us know if this works. If you still have issues, I would recommend doing as @constantinpape suggested and sharing a few sample images and the respective labels of your data.