m2lines / gz21_ocean_momentum

Stochastic-Deep Learning Parameterization of Ocean Momentum Forcing

Training fails with `IndexError` when performing train/test dataset splitting #104

Open raehik opened 9 months ago

raehik commented 9 months ago

Running a simple training command with some low-resolution training data (only 100 samples) gives me an error triggered in `train_for_one_epoch`, when we `enumerate(dataloader)`. On the main branch, run:

mlflow run . --experiment-name raehik -e train --env-manager=local \
-P forcing_data_path=<path> \
-P learning_rate=0/5e-4/15/5e-5/30/5e-6 -P n_epochs=200 -P weight_decay=0.00 -P train_split=0.8 \
-P test_split=0.85 -P model_module_name=models.models1 -P model_cls_name=FullyCNN -P batchsize=4 \
-P transformation_cls_name=SoftPlusTransform -P submodel=transform3 \
-P loss_cls_name=HeteroskedasticGaussianLossV2

Or, on main as of the merging of #97 (early December 2023):

python src/gz21_ocean_momentum/cli/train.py \
--in-train-data-dir <path> --subdomains-file examples/cli-configs/training-subdomains-paper.yaml \
--initial-learning-rate 5.0e-4 --decay-at-epoch-milestones 15 --decay-at-epoch-milestones 30 --decay-factor 0.00 \
--train-split 0.8 --test-split 0.85 --batch-size 4 --epochs 200

Either gives `IndexError: index <x> is out of bounds for axis 0 with size 80`. The offending index appears to lie between 80 and 320 (I've certainly seen low 80s and high 300s).

There are 320 samples across all training subdomains combined (4 spatial subdomains, with 80% of each used for training). We batch with a size of 4. I've tried investigating and tinkering with these numbers, but I haven't managed to resolve it.
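For illustration only (this is a guess at the failure mode, not the repository's actual code): indices drawn from the combined 320-sample training range but applied to a single 80-sample subdomain array would produce exactly the reported error, including the observed 80–319 index range.

```python
# Hypothetical illustration (not the repository's code): the reported error is
# what you get if indices computed against the combined 320-sample training
# range are applied to a single subdomain array holding only 80 samples.
import numpy as np

n_per_subdomain = 80          # 80% of 100 samples, per spatial subdomain
n_subdomains = 4
n_combined = n_per_subdomain * n_subdomains   # 320 training samples in total

single_subdomain = np.zeros(n_per_subdomain)  # one subdomain's training slice
for idx in (79, 80, n_combined - 1):
    try:
        single_subdomain[idx]
    except IndexError as err:
        print(err)  # "index 80 is out of bounds for axis 0 with size 80", etc.
```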

Doing either of these prevents the issue from occurring:

It would seem the problem lies somewhere in `Subset_` or related code, or in the xarray dataset generated by my data step.
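A generic diagnostic sketch that might help narrow this down, using `torch.utils.data.Subset` as a stand-in for the repository's `Subset_` (whose internals may differ); `check_subset_indices` is a hypothetical helper, not existing repo code. The idea is to check that every index held by a subset is valid for its underlying dataset before it reaches a `DataLoader`, so the failure surfaces at split time rather than mid-epoch.

```python
# Generic diagnostic sketch: validate subset indices against the base dataset.
import torch
from torch.utils.data import Subset, TensorDataset

def check_subset_indices(subset: Subset) -> None:
    """Raise early if any subset index falls outside the base dataset."""
    n = len(subset.dataset)
    bad = [i for i in subset.indices if not 0 <= i < n]
    if bad:
        raise IndexError(
            f"{len(bad)} subset indices out of range for dataset of size {n}; "
            f"first few: {bad[:5]}"
        )

# Example: 80 samples in the base dataset, but indices computed on the
# 320-sample combined range -- this check fails immediately, before training.
base = TensorDataset(torch.zeros(80, 1))
check_subset_indices(Subset(base, list(range(256, 320))))  # raises IndexError
```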

dorchard commented 8 months ago

@CemGultekin1 I wonder if you came across something similar to this when looking at the gz code?