microsoft / torchgeo

TorchGeo: datasets, samplers, transforms, and pre-trained models for geospatial data
https://www.osgeo.org/projects/torchgeo/
MIT License

Easier way to use Data Processing steps outside of datamodule #1780

Open nilsleh opened 9 months ago

nilsleh commented 9 months ago

Summary

Normalization and augmentations are defined in the on_after_batch_transfer() method of the datamodules so that they run on the GPU, as recommended by Lightning. A downside is that you always have to pass the datamodule into .fit() and .test(). Especially for the latter, it can be convenient to test on separate dataloaders, but those are then just "raw" dataloaders without normalization etc. applied. It took me a minute to figure out that this was the reason for the funky test results. Currently, I am writing a custom collate_fn and setting it on the dataloader that I get from a datamodule, but it would be nice if this could be handled more easily. I'm open to hearing thoughts about this, or suggestions for easier ways to handle it than what I am doing at the moment.

Rationale

Sometimes I would like to test a model on different datasets, and if a TorchGeo datamodule is available, it is convenient to just retrieve a configured dataloader from it.

Implementation

Perhaps a flag could be added that makes a datamodule return a dataloader whose collate function applies the on_after_batch_transfer augmentations.
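
A minimal sketch of that idea, assuming the batch-level transform is available as a plain callable: wrap an existing dataloader so that every batch it yields passes through the transform before it is returned. BatchTransformLoader and scale_image are hypothetical names, not TorchGeo or Lightning API; with TorchGeo, the transform could be something like the datamodule's augmentation pipeline, assuming the version in use exposes one. Plain Python lists stand in for a DataLoader so the example runs without PyTorch.

```python
from typing import Any, Callable, Iterator


class BatchTransformLoader:
    """Hypothetical wrapper: applies a batch-level transform to every batch
    yielded by an existing loader (e.g. a torch DataLoader)."""

    def __init__(self, loader: Any, transform: Callable[[Any], Any]) -> None:
        self.loader = loader
        self.transform = transform

    def __iter__(self) -> Iterator[Any]:
        # Transform each already-collated batch as it is yielded.
        for batch in self.loader:
            yield self.transform(batch)

    def __len__(self) -> int:
        # Delegate so len() and progress bars keep working when the
        # wrapped loader defines __len__ (DataLoader does).
        return len(self.loader)


def scale_image(batch: dict) -> dict:
    # Stand-in for normalization: divide every pixel value by 255.
    return {**batch, "image": [v / 255 for v in batch["image"]]}


# Plain lists of batches stand in for a DataLoader here.
raw_loader = [
    {"image": [0, 255], "mask": [0, 1]},
    {"image": [51, 102], "mask": [1, 0]},
]
wrapped = BatchTransformLoader(raw_loader, scale_image)
batches = list(wrapped)
```

Unlike the on_after_batch_transfer hook, the transform here runs before the Trainer moves the batch to the device, so it executes on the CPU, and it runs in the main process rather than in the worker processes that execute a collate_fn. Whether trainer.validate() accepts such a wrapper in place of a DataLoader depends on the Lightning version in use.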

Alternatives

Currently I am doing something like this:

import torch
from torchgeo.datamodules import ETCI2021DataModule

datamodule = ETCI2021DataModule(root=".", download=True, num_workers=4, batch_size=32)
datamodule.setup("fit")

def collate(batch: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]:
    """Collate fn that also applies the datamodule's batch-level transforms."""
    # ETCI2021 samples store their targets under "mask", not "label"
    images = torch.stack([item["image"] for item in batch])
    masks = torch.stack([item["mask"] for item in batch])

    # on_after_batch_transfer also requires a dataloader index; 0 = the only loader
    return datamodule.on_after_batch_transfer({"image": images, "mask": masks}, dataloader_idx=0)

val_dataloader = datamodule.val_dataloader()
val_dataloader.collate_fn = collate

adamjstewart commented 9 months ago

I can understand why you would want to be able to use a dataset if a data module doesn't exist, but why would you want to use a dataset if a data module does exist?

nilsleh commented 9 months ago

So that I can call trainer.validate(model, dataloaders=datamodule.val_dataloader()) without having to implement my own normalization scheme as a collate_fn for every datamodule dataloader I want to use. For example, say I train one model and want to validate it on a bunch of datasets: I could then pass multiple dataloaders from different datasets or datamodules to trainer.validate().
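
Framework aside, validating one model on several datasets boils down to the loop below, where each loader is paired with its own dataset-specific normalization. Passing several dataloaders to trainer.validate() takes care of the looping, but not of that per-dataset normalization. evaluate_on_many, the (loader, transform) pairing, and per-pixel accuracy as the metric are illustrative assumptions; plain lists stand in for dataloaders and 1-D lists for images.

```python
from typing import Callable, Iterable, Mapping

Batch = dict[str, list]


def evaluate_on_many(
    predict: Callable[[list], list],
    loaders: Mapping[str, tuple[Iterable[Batch], Callable[[Batch], Batch]]],
) -> dict[str, float]:
    """Hypothetical helper: per-loader pixel accuracy for a single model.

    Each entry maps a name to (loader, transform), where transform is the
    normalization belonging to that loader's dataset.
    """
    results: dict[str, float] = {}
    for name, (loader, transform) in loaders.items():
        correct = total = 0
        for batch in loader:
            batch = transform(batch)  # dataset-specific normalization first
            preds = predict(batch["image"])
            correct += sum(int(p == t) for p, t in zip(preds, batch["mask"]))
            total += len(batch["mask"])
        results[name] = correct / total if total else float("nan")
    return results


def threshold_model(pixels: list) -> list:
    # Toy "model": predict 1 where the normalized value exceeds 0.5.
    return [int(v > 0.5) for v in pixels]


loaders = {
    "dataset_a_val": (
        [{"image": [25, 230], "mask": [0, 1]}],
        lambda b: {**b, "image": [v / 255 for v in b["image"]]},  # 8-bit input
    ),
    "dataset_b_val": (
        [{"image": [0.2, 0.7], "mask": [1, 1]}],
        lambda b: b,  # already scaled to [0, 1]
    ),
}
scores = evaluate_on_many(threshold_model, loaders)
```

Keeping each transform next to its loader is what prevents the "raw dataloader" mistake: a loader cannot be evaluated without the normalization that belongs to it.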

adamjstewart commented 9 months ago

But why not use trainer.validate(model, datamodule=datamodule) for all data modules?

nilsleh commented 9 months ago

If you pass a datamodule, Lightning only selects the predefined validation dataloader and validates on that. But maybe I would like to validate on both the train set and the validation set, for example when taking a pre-trained model and checking its performance without training. This might also be relevant for something like cross-validation, where you split your train/val sets. In my case, I am trying conformal prediction, where you need to take a subset of the validation set to create a separate calibration set and then use the model with that, so you need to control which split's dataloader validation is run on.
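
The calibration split described above can be sketched as a seeded split of dataset indices into two disjoint sets. calibration_split and the 20% fraction in the usage line are illustrative assumptions; in PyTorch, each index list could then be wrapped with torch.utils.data.Subset(dataset, indices) and given its own DataLoader, so the calibration and evaluation loaders can be validated separately.

```python
import random


def calibration_split(
    n: int, calib_fraction: float, seed: int = 0
) -> tuple[list[int], list[int]]:
    """Split indices 0..n-1 into disjoint (calibration, evaluation) lists.

    A fixed seed makes the split reproducible across runs, so the same
    samples are used for calibration every time.
    """
    if not 0.0 < calib_fraction < 1.0:
        raise ValueError("calib_fraction must be in (0, 1)")
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # local RNG: global state untouched
    n_calib = round(n * calib_fraction)
    return sorted(indices[:n_calib]), sorted(indices[n_calib:])


calib_idx, eval_idx = calibration_split(100, 0.2, seed=42)
```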

robmarkcole commented 3 months ago

I think at a minimum we should improve the documentation to state when augmentations are and are not applied. For example, I assumed they were performed in the dataset's __getitem__, but they are not.

isaaccorley commented 3 months ago

Just want to clarify that for the majority of datamodules there are no augmentations applied, only a normalization of the images. We try not to prescribe which augmentations should be used for what dataset as this should be left to the user.

There are 2 options:

robmarkcole commented 3 months ago

Thanks for the clarification. It makes sense that there are basic augmentations that should always be applied during training and that we don't need to inspect (i.e. normalisation), and others to experiment with and visualise for sanity checking. Therefore, is it a reasonable workflow to pass the latter kind to the dataset in setup, and still have the former applied at the data module level?
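
The proposed split of responsibilities can be sketched as follows: experiment-specific augmentations go into a per-sample transforms callable that the dataset applies in __getitem__ (many TorchGeo datasets accept a transforms argument), while the always-on normalization runs once per batch, as the datamodule does. ToyDataset, hflip, collate, and normalize_batch are hypothetical stand-ins, images are 1-D lists of pixel values, and the mean/std defaults are illustrative only.

```python
from typing import Callable, Optional


class ToyDataset:
    """Hypothetical dataset whose __getitem__ applies an optional
    per-sample `transforms` callable."""

    def __init__(
        self, samples: list[dict], transforms: Optional[Callable[[dict], dict]] = None
    ) -> None:
        self.samples = samples
        self.transforms = transforms

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, index: int) -> dict:
        sample = dict(self.samples[index])  # copy so the stored sample is untouched
        if self.transforms is not None:
            sample = self.transforms(sample)
        return sample


def hflip(sample: dict) -> dict:
    # Experiment-specific augmentation: flip image and mask together.
    return {"image": sample["image"][::-1], "mask": sample["mask"][::-1]}


def collate(samples: list[dict]) -> dict:
    # Group per-sample values into per-key lists (a minimal stand-in collate).
    return {key: [s[key] for s in samples] for key in samples[0]}


def normalize_batch(batch: dict, mean: float = 0.0, std: float = 255.0) -> dict:
    # Always-on batch-level step, analogous to the datamodule normalization.
    images = [[(v - mean) / std for v in img] for img in batch["image"]]
    return {**batch, "image": images}


ds = ToyDataset([{"image": [0, 255], "mask": [0, 1]}], transforms=hflip)
batch = normalize_batch(collate([ds[i] for i in range(len(ds))]))
```

Because the sample-level augmentation lives inside the dataset, it can be inspected by indexing the dataset directly, which fits the sanity-checking use case, while normalization stays out of the way until batches are formed.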