mlexchange / mlex_dlsia_segmentation_prototype


Refactor `TiledDataset` into `TiledDataset` and `TiledMaskedDataset` #32

Closed: Wiebke closed this issue 3 weeks ago

Wiebke commented 3 weeks ago

The previous iteration of `TiledDataset` used the boolean parameters `is_full_segmentation` and `is_training`, and set the attributes `mask_client` and `mask_idx` to `None` if no Tiled client with mask information was provided.

This refactor changes `TiledDataset` so that it provides only the functionality previously associated with `is_full_segmentation`. It is defined based on a Tiled client with data; optionally, a set of indices can be provided. This is useful for crafting datasets from arbitrary subsets of indices, and it is also applied when mask information is given during training. A minimal sketch of the idea follows.
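The sketch below illustrates the refactored `TiledDataset` as described above. It is not the actual implementation: the `indices` parameter name and the assumption that the data client supports `len()` and integer indexing along the first axis are mine.

```python
from torch.utils.data import Dataset


class TiledDataset(Dataset):
    def __init__(self, data_client, indices=None):
        # Tiled client exposing the image stack; assumed to support len()
        # and integer indexing along the first axis.
        self.data_client = data_client
        # Optional subset of frame indices; defaults to the full stack.
        self.indices = list(indices) if indices is not None else list(range(len(data_client)))

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        # Returns the raw frame (a NumPy array) without any mask.
        return self.data_client[self.indices[idx]]
```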

`TiledMaskedDataset` is a subclass of `TiledDataset` that additionally requires a Tiled client with mask information. This client is expected to contain a list of integers within `.metadata["mask_idx"]` and the actual mask data under the key `"mask"`; the presence of both is asserted.
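Extending the sketch above, the subclass could look roughly like this; the exact assertion expressions and the assumption that the masks are stacked in the same order as `metadata["mask_idx"]` are not confirmed by the code.

```python
class TiledMaskedDataset(TiledDataset):
    def __init__(self, data_client, mask_client):
        # Both pieces of mask information must be present, as described above.
        assert "mask_idx" in mask_client.metadata, "mask_client lacks metadata['mask_idx']"
        assert "mask" in mask_client, "mask_client lacks a 'mask' entry"
        # Iterate only over the frames that actually have masks.
        super().__init__(data_client, indices=mask_client.metadata["mask_idx"])
        self.mask_client = mask_client

    def __getitem__(self, idx):
        image = super().__getitem__(idx)
        # Assumes masks are stored in the same order as metadata["mask_idx"].
        mask = self.mask_client["mask"][idx]
        return image, mask
```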

The parameter is_training is still in use for the function initialize_tiled_datasets (which has been moved from utils.py to tiled_dataset.py), but this may change in upcoming refactoring of IOParameters.
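For illustration, `initialize_tiled_datasets` might dispatch between the two classes along these lines; the signature and branching are an assumption inferred from the summary below, not the actual function body.

```python
def initialize_tiled_datasets(data_client, mask_client=None, is_training=False):
    if is_training:
        # Training requires mask information alongside the data.
        return TiledMaskedDataset(data_client, mask_client)
    if mask_client is not None:
        # Partial inference: iterate only over the annotated frame indices.
        return TiledDataset(data_client, indices=mask_client.metadata["mask_idx"])
    # Full segmentation over the entire stack.
    return TiledDataset(data_client)
```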

To summarize:

- `TiledDataset(data_client, mask_client=None, is_training=False, is_full_segmentation=True)` → `TiledDataset(data_client)`
- `TiledDataset(data_client, mask_client, is_training=False, is_full_segmentation=False)` → `TiledDataset(data_client, mask_client.metadata["mask_idx"])`
- `TiledDataset(data_client, mask_client, is_training=True, is_full_segmentation=False)` → `TiledMaskedDataset(data_client, mask_client)`

We may need to consider bringing back the `transform` parameter that used to convert items to `torch.Tensor`, and adapt downstream procedures to operate on tensors rather than `np.ndarray`.
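If that parameter comes back, a minimal transform could simply wrap `torch.from_numpy`; this is a sketch of the idea, not existing code.

```python
import numpy as np
import torch


def to_tensor(frame: np.ndarray) -> torch.Tensor:
    # Ensure a C-contiguous array before wrapping it as a float tensor.
    return torch.from_numpy(np.ascontiguousarray(frame)).float()
```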

Wiebke commented 3 weeks ago

Thanks for taking a look and testing the full pipeline on your end! I fixed the issue in `partial_inference` by correcting the initialization of the dataset in that function. The problem was that `is_training=True` was accidentally set, which caused a `TiledMaskedDataset` to be initialized. With `is_training=False`, a `TiledDataset` that uses the mask indices for iteration is created, so no tuple unpacking in `train.py` line 211 should be needed.
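In terms of the sketch above, the corrected call in `partial_inference` would look roughly like this (variable names are illustrative):

```python
# Yields a TiledDataset iterating over mask_idx frames only,
# returning plain frames rather than (image, mask) tuples.
dataset = initialize_tiled_datasets(
    data_client,
    mask_client=mask_client,
    is_training=False,
)
```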

I would like to defer the correct initialization of `io_parameters` and the changes to the `validate_parameter` function until the refactoring of the parameter setup and validation. This does indeed momentarily break the main function of `segment.py` (such that it runs partial inference), but this will be addressed then.