This PR decouples the CorrDiff dataloader code from the training loop in the example, making it easier to add new datasets for CorrDiff training and inference.
Changes
Add a DownscalingDataset base class in datasets/base.py that defines an interface that datasets should follow.
Move training/dataset.py to datasets/cwb.py and modify it to conform to the interface.
Remove dataset creation from training/training_loop.py. The dataset is instead created in train.py and passed to the training loop function.
Remove references to specifics of the CWB dataset from train.py, training/training_loop.py and generate.py. Use the functionality of the DownscalingDataset interface instead.
Move the dataloader initialization keywords into the dataset section of the config files. These are passed to the dataset constructor and can be different for different datasets.
Make a few minor cleanup changes, such as removing unused code and variables.
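The pattern described in the changes above can be sketched roughly as follows. This is a minimal illustration, not the code in this PR: the class `DownscalingDatasetSketch`, its methods `input_channels` and `output_channels`, the `ToyDataset` class, the `DATASET_REGISTRY` mapping, the `type` config key, and the `build_dataset` helper are all hypothetical names chosen for the example. It also uses plain Python lists instead of tensors so that it runs without any dependencies.

```python
from abc import ABC, abstractmethod


class DownscalingDatasetSketch(ABC):
    """Hypothetical stand-in for the base class in datasets/base.py."""

    @abstractmethod
    def __len__(self):
        """Return the number of samples."""

    @abstractmethod
    def __getitem__(self, idx):
        """Return one (input, target) pair."""

    @abstractmethod
    def input_channels(self):
        """Return the input channel names (hypothetical method)."""

    @abstractmethod
    def output_channels(self):
        """Return the output channel names (hypothetical method)."""


class ToyDataset(DownscalingDatasetSketch):
    """Toy dataset; a real one (e.g. datasets/cwb.py) would read files."""

    def __init__(self, n_samples=4, in_names=("u", "v"), out_names=("t2m",)):
        self.n_samples = n_samples
        self._in = list(in_names)
        self._out = list(out_names)

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        if not 0 <= idx < self.n_samples:
            raise IndexError(idx)
        # One dummy value per channel instead of real image data.
        return [float(idx)] * len(self._in), [float(idx)] * len(self._out)

    def input_channels(self):
        return self._in

    def output_channels(self):
        return self._out


# Hypothetical mapping from a config "type" key to a dataset class.
DATASET_REGISTRY = {"toy": ToyDataset}


def build_dataset(dataset_cfg):
    """Create a dataset from the dataset section of a config."""
    cfg = dict(dataset_cfg)  # copy so the caller's config is not mutated
    cls = DATASET_REGISTRY[cfg.pop("type")]
    # Remaining keys are dataset-specific constructor keywords.
    return cls(**cfg)


def training_loop(dataset, steps=2):
    """Dataset-agnostic loop: relies only on the base-class interface."""
    n_in = len(dataset.input_channels())
    n_out = len(dataset.output_channels())
    visited = []
    for step in range(steps):
        idx = step % len(dataset)
        x, y = dataset[idx]
        assert len(x) == n_in and len(y) == n_out
        visited.append(idx)
    return n_in, n_out, visited


# The dataset is created outside the loop (as in train.py) and passed in.
dataset = build_dataset({"type": "toy", "n_samples": 3})
result = training_loop(dataset, steps=4)
```

Because the loop only calls methods declared on the base class, adding a new dataset in this sketch means writing one subclass and one registry entry, with no changes to the loop itself.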
Testing
Verified that the training and generation loops run without errors.
Not verified: that the training loop still produces convergent training runs. Whether this check should be done needs discussion.
Modulus Pull Request
Description
This enhancement is listed in issue https://github.com/NVIDIA/modulus/issues/353.