clessig / atmorep

AtmoRep model code
MIT License

[BUG] training fails with increased number of levels #66

Open jpolz opened 4 days ago

jpolz commented 4 days ago

Describe the bug Training fails with an IndexError (list index out of range) raised in a DataLoader worker when the number of levels is increased.

To Reproduce Steps to reproduce the behavior:

  1. Start from develop
  2. Change number of levels to 10 using an appropriate dataset (here: /p/scratch/hclimrep/polz1/data/era5_1deg/months/era5_y2014_2016_res025_chunk8.zarr/)
  3. Activate virtual env and run train.py
  4. See error

Expected behavior Training proceeds as usual.

Screenshots

0: Traceback (most recent call last):
0:   File "/p/project1/hclimrep/polz1/atmorep/atmorep_train_13290946/train.py", line 229, in <module>
0:     train()
0:   File "/p/project1/hclimrep/polz1/atmorep/atmorep_train_13290946/train.py", line 222, in train
0:     trainer.run()
0:   File "/p/project1/hclimrep/polz1/atmorep/atmorep/core/trainer.py", line 188, in run
0:     self.train( epoch)
0:   File "/p/project1/hclimrep/polz1/atmorep/atmorep/core/trainer.py", line 230, in train
0:     batch_data = self.model.next()
0:                  ^^^^^^^^^^^^^^^^^
0:   File "/p/project1/hclimrep/polz1/atmorep/atmorep/core/atmorep_model.py", line 161, in next
0:     return next(self.data_loader_iter)
0:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
0:   File "/p/project1/hclimrep/polz1/atmorep/pyenv-jureca/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 630, in __next__
0:     data = self._next_data()
0:            ^^^^^^^^^^^^^^^^^
0:   File "/p/project1/hclimrep/polz1/atmorep/pyenv-jureca/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1344, in _next_data
0:     return self._process_data(data)
0:            ^^^^^^^^^^^^^^^^^^^^^^^^
0:   File "/p/project1/hclimrep/polz1/atmorep/pyenv-jureca/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1370, in _process_data
0:     data.reraise()
0:   File "/p/project1/hclimrep/polz1/atmorep/pyenv-jureca/lib/python3.11/site-packages/torch/_utils.py", line 706, in reraise
0:     raise exception
0: IndexError: Caught IndexError in DataLoader worker process 0.
0: Original Traceback (most recent call last):
0:   File "/p/project1/hclimrep/polz1/atmorep/pyenv-jureca/lib/python3.11/site-packages/torch/utils/data/_utils/worker.py", line 309, in _worker_loop
0:     data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
0:            ^^^^^^^^^^^^^^^^^^^^
0:   File "/p/project1/hclimrep/polz1/atmorep/pyenv-jureca/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 42, in fetch
0:     data = next(self.dataset_iter)
0:            ^^^^^^^^^^^^^^^^^^^^^^^
0:   File "/p/project1/hclimrep/polz1/atmorep/atmorep/datasets/multifield_data_sampler.py", line 235, in __iter__
0:     sources = self.pre_batch( sources, token_infos )
0:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0:   File "/p/project1/hclimrep/polz1/atmorep/atmorep/training/bert.py", line 51, in prepare_batch_BERT_multifield
0:     ret = bert_f( cf, ifield, field_data, token_info, rngs[rng_idx])
0:                                                       ~~~~^^^^^^^^^
0: IndexError: list index out of range

Hardware and environment:

Additional context I created a new dataset containing model levels 23, 29, 41, 53, 60, 96, 105, 114, 123, 137 and used all of them to train a singleformer. A maximum number of levels is hardcoded in trainer.py; increasing it resolves this issue. We should discuss whether a flexible version, e.g. using len(cf.fields[0][2]), is what we want. That version only looks at the first field, so it could raise another error if someone later wants different numbers of levels for different fields. In principle, using the maximum number of levels over all fields would work.

from trainer.py:

[image not preserved]
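
For illustration, here is a minimal sketch of how a fixed-size list of random generators can produce this kind of IndexError. It assumes the generators are kept in a list whose length comes from a hardcoded constant and that the list is indexed by level position. MAX_LEVELS_HARDCODED, make_rngs, and sample_per_level are hypothetical names and are not taken from trainer.py or bert.py.

```python
import random

# Hypothetical stand-in for the hardcoded limit; the actual value and
# variable name in trainer.py are not shown in this thread.
MAX_LEVELS_HARDCODED = 5


def make_rngs(num_rngs, seed=0):
    """Create one independent random generator per slot."""
    return [random.Random(seed + i) for i in range(num_rngs)]


def sample_per_level(levels, rngs):
    """Draw one number per level, indexing the generator list by level position."""
    return [rngs[idx].random() for idx, _ in enumerate(levels)]


levels = [23, 29, 41, 53, 60, 96, 105, 114, 123, 137]

try:
    sample_per_level(levels, make_rngs(MAX_LEVELS_HARDCODED))
    failed = False
except IndexError:
    # Ten levels are requested, but only five generators exist.
    failed = True

# Sizing the list from the configured levels avoids the error.
samples = sample_per_level(levels, make_rngs(len(levels)))
```

With ten levels and five generators the first call fails, while a list sized from the configuration succeeds.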

iluise commented 4 days ago

Hi, can't you take the max over all len(cf.fields[i][2]) instead? In principle we designed the architecture to be flexible with respect to the number of levels per field, but this needs to be checked and validated again. Do you mind doing a quick check and, if it fails, reporting the stack trace?

thanks!

jpolz commented 4 days ago

The required setup would be running train_multi.py with two or more fields that have different numbers of levels, right? From my testing, I also believe I'd need to set it to max([ len(f[2]) for f in cf.fields])+1, although I don't understand why the +1 is necessary.

jpolz commented 4 days ago

I used max([ len(f[2]) for f in cf.fields])+1 now and running train_multi.py with this config works (run on wandb):

  cf.fields = [
                [ 'velocity_u', [ 1, 1024, ['velocity_v', 'temperature'], 0 ],
                                [ 96, 105, 114, 123, 137 ],
                                [12, 3, 6], [3, 18, 18], [0.5, 0.9, 0.2, 0.05] ],
                [ 'velocity_v', [ 1, 1024, ['velocity_u', 'temperature'], 1 ],
                                [ 96 ],
                                [12, 3, 6], [3, 18, 18], [0.5, 0.9, 0.2, 0.05] ],
                [ 'specific_humidity', [ 1, 1024, ['velocity_u', 'velocity_v', 'temperature'], 2 ],
                                       [ 96, 105, 114 ],
                                       [12, 3, 6], [3, 18, 18], [0.5, 0.9, 0.2, 0.05] ],
                [ 'velocity_z', [ 1, 1024, ['velocity_u', 'velocity_v', 'temperature'], 3 ],
                                [ 96, 105, 114, 123 ],
                                [12, 3, 6], [3, 18, 18], [0.5, 0.9, 0.2, 0.05] ],
                [ 'temperature', [ 1, 1024, ['velocity_u', 'velocity_v', 'specific_humidity'], 3 ],
                                 [ 96, 105 ],
                                 [12, 3, 6], [3, 18, 18], [0.5, 0.9, 0.2, 0.05], 'local' ],
              ]
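
As a quick sanity check of the max expression against the multi-field config, the sketch below evaluates it on a shortened copy of the field list (only the first three entries of each field are kept). The helper names levels_per_field and max_num_levels are hypothetical. In this config the levels are the third entry of each field (index 2), while the second entry (index 1) is a 4-element list for every field.

```python
# Shortened copy of the multi-field config: name, 4-element settings list,
# and the list of levels (index 2). Remaining entries are omitted here.
fields = [
    ['velocity_u', [1, 1024, ['velocity_v', 'temperature'], 0],
     [96, 105, 114, 123, 137]],
    ['velocity_v', [1, 1024, ['velocity_u', 'temperature'], 1],
     [96]],
    ['specific_humidity', [1, 1024, ['velocity_u', 'velocity_v', 'temperature'], 2],
     [96, 105, 114]],
    ['velocity_z', [1, 1024, ['velocity_u', 'velocity_v', 'temperature'], 3],
     [96, 105, 114, 123]],
    ['temperature', [1, 1024, ['velocity_u', 'velocity_v', 'specific_humidity'], 3],
     [96, 105]],
]


def levels_per_field(fields):
    """Map each field name to its number of levels (length of entry 2)."""
    return {f[0]: len(f[2]) for f in fields}


def max_num_levels(fields):
    """Largest number of levels over all fields."""
    return max(len(f[2]) for f in fields)


counts = levels_per_field(fields)
# The +1 mirrors the expression used in this thread; its reason is not known here.
num_slots = max_num_levels(fields) + 1
```

For this config the per-field counts are 5, 1, 3, 4, and 2, so the maximum is 5 and the expression with +1 evaluates to 6.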

Running train.py with this config also works (https://wandb.ai/atmorep/stratorep/runs/ugkgfp2f/overview):

  cf.fields = [ [ 'temperature', [ 1, 1024, [ ], 0 ],
                                 [ 23, 29, 41, 53, 60, 96, 105, 114, 123, 137 ],
                                 [12, 2, 4], [3, 27, 27], [0.5, 0.9, 0.2, 0.05], 'local' ] ]
  cf.fields_prediction = [ [cf.fields[0][0], 1.] ]

clessig commented 4 days ago

Great! Can you open a PR for the fix? Thanks!

iluise commented 4 days ago

Super! Thanks a lot! Can you check the evaluation step as well (e.g. global forecasting for a single date), so that we can also include any fixes needed there in the MR?

jpolz commented 4 days ago

Yes, I can do that too. In that case the PR should stay open. Thanks for the support.