clessig / atmorep

AtmoRep model code
MIT License

temporal interpolation #17

Open iluise opened 4 months ago

iluise commented 4 months ago

Check if temporal interpolation still works, by:

clessig commented 3 months ago

Temporal interpolation works in the sense that the model can be run in this mode. However, the output still needs to be verified.

clessig commented 3 months ago

Please use GitHub issues to track your work, as discussed, not private chats. Thanks.

sbAsma commented 3 months ago

Current status:

sbAsma commented 3 months ago

New update:

clessig commented 3 months ago

New update:

  • (+1) Pulled changes from iluise/head, temporal_interpolation inference works now on JUWELS Booster
  • (-1) Cannot run the code on HDFML anymore. Error message:

    ...
    0:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    0: [rank0]:   File "/p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/transformer/transformer_attention.py", line 256, in forward
    0: [rank0]:     outs = self.att( qs, ks, vs).transpose( -3, -2).flatten( -2, -1).reshape(s)
    0: [rank0]:            ^^^^^^^^^^^^^^^^^^^^^
    0: [rank0]: RuntimeError: No available kernel. Aborting execution.
    
    ...

Which torch version are you using? You need >= 2.3.
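
For reference, a minimal sketch of how the ">= 2.3" requirement could be checked before launching a job. The helper names `version_tuple` and `meets_minimum` are hypothetical; in practice the string would come from `torch.__version__`, which may carry a build suffix such as `+cu121`.

```python
import re


def version_tuple(version):
    """Extract the leading numeric components, e.g. '2.4.0+cu121' -> (2, 4, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"no numeric version prefix in {version!r}")
    return tuple(int(part) for part in match.group().split("."))


def meets_minimum(version, minimum=(2, 3)):
    """Compare numerically, so that '2.10' counts as newer than '2.3'."""
    return version_tuple(version)[:len(minimum)] >= minimum


# Example strings only; replace with torch.__version__ on the target system.
for candidate in ("2.2.1", "2.3.0", "2.4.0+cu121", "2.10.0"):
    print(candidate, meets_minimum(candidate))
```

The comparison is done on integer tuples rather than on strings because a plain string comparison would rank "2.10" below "2.3".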

sbAsma commented 3 months ago

I am using 2.4.

clessig commented 3 months ago

Usually there is an error message explaining why no kernel is available. Can you send a more complete stack trace?

sbAsma commented 3 months ago

+ export SRUN_CPUS_PER_TASK=12
+ SRUN_CPUS_PER_TASK=12
+ CONFIG_DIR=/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/atmorep/run/atmorep_eval_69268
+ mkdir /p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/atmorep/run/atmorep_eval_69268
+ cp /p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/atmorep/atmorep/core/evaluate.py /p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/atmorep/run/atmorep_eval_69268
+ echo /p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/atmorep/run/atmorep_eval_69268/evaluate.py
+ srun --label --cpu-bind=v /p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/atmorep/../virtual_envs/venv_hdfml/bin/python -u /p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/atmorep/run/atmorep_eval_69268/evaluate.py
1: cpu_bind=THREADS - hdfmlc15, task  1  1 [2103]: mask 0xfff000 set
0: cpu_bind=THREADS - hdfmlc15, task  0  0 [2153]: mask 0xfff set
1: /p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/core/trainer.py:31: DeprecationWarning: `TorchScript` support for functional optimizers is deprecated and will be removed in a future PyTorch release. Consider using the `torch.compile` optimizer instead.
1:   from torch.distributed.optim import ZeroRedundancyOptimizer
0: /p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/core/trainer.py:31: DeprecationWarning: `TorchScript` support for functional optimizers is deprecated and will be removed in a future PyTorch release. Consider using the `torch.compile` optimizer instead.
0:   from torch.distributed.optim import ZeroRedundancyOptimizer
0: [W809 09:57:13.478940966 socket.cpp:697] [c10d] The client socket cannot be initialized to connect to [hdfmlc15.hdfml]:1345 (errno: 97 - Address family not supported by protocol).
1: [W809 09:57:13.532948173 socket.cpp:697] [c10d] The client socket cannot be initialized to connect to [hdfmlc15.hdfml]:1345 (errno: 97 - Address family not supported by protocol).
0: wandb: Tracking run with wandb version 0.17.6
0: wandb: W&B syncing is set to `offline` in this directory.  
0: wandb: Run `wandb online` or set WANDB_MODE=online to enable cloud syncing.
1: /p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/virtual_envs/venv_hdfml/lib/python3.11/site-packages/torch/nn/init.py:453: UserWarning: Initializing zero-element tensors is a no-op
1:   warnings.warn("Initializing zero-element tensors is a no-op")
0: /p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/virtual_envs/venv_hdfml/lib/python3.11/site-packages/torch/nn/init.py:453: UserWarning: Initializing zero-element tensors is a no-op
0:   warnings.warn("Initializing zero-element tensors is a no-op")
1: /p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/core/atmorep_model.py:441: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
1:   mloaded = torch.load( utils.get_model_filename( model, model_id, epoch) )
0: /p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/core/atmorep_model.py:441: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
0:   mloaded = torch.load( utils.get_model_filename( model, model_id, epoch) )
1: /p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/transformer/transformer_attention.py:256: UserWarning: Memory efficient kernel not used because: (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:718.)
1:   outs = self.att( qs, ks, vs).transpose( -3, -2).flatten( -2, -1).reshape(s)
1: /p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/transformer/transformer_attention.py:256: UserWarning: Memory Efficient attention has been runtime disabled. (Triggered internally at ../aten/src/ATen/native/transformers/sdp_utils_cpp.h:495.)
1:   outs = self.att( qs, ks, vs).transpose( -3, -2).flatten( -2, -1).reshape(s)
1: /p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/transformer/transformer_attention.py:256: UserWarning: Flash attention kernel not used because: (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:720.)
1:   outs = self.att( qs, ks, vs).transpose( -3, -2).flatten( -2, -1).reshape(s)
1: /p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/transformer/transformer_attention.py:256: UserWarning: Flash attention only supports gpu architectures in the range [sm80, sm90]. Attempting to run on a sm 7.0 gpu. (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:201.)
1:   outs = self.att( qs, ks, vs).transpose( -3, -2).flatten( -2, -1).reshape(s)
1: /p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/transformer/transformer_attention.py:256: UserWarning: CuDNN attention kernel not used because: (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:722.)
1:   outs = self.att( qs, ks, vs).transpose( -3, -2).flatten( -2, -1).reshape(s)
1: /p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/transformer/transformer_attention.py:256: UserWarning: The CuDNN backend needs to be enabled by setting the enviornment variable`TORCH_CUDNN_SDPA_ENABLED=1` (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:496.)
1:   outs = self.att( qs, ks, vs).transpose( -3, -2).flatten( -2, -1).reshape(s)
1: [rank1]: Traceback (most recent call last):
1: [rank1]:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/atmorep/run/atmorep_eval_69268/evaluate.py", line 76, in <module>
1: [rank1]:     Evaluator.evaluate( mode, model_id, file_path, options)
1: [rank1]:   File "/p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/core/evaluator.py", line 118, in evaluate
1: [rank1]:     func( cf, model_id, model_epoch, devices, args)
1: [rank1]:   File "/p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/core/evaluator.py", line 223, in temporal_interpolation
1: [rank1]:     Evaluator.run( cf, model_id, model_epoch, devices)
1: [rank1]:   File "/p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/core/evaluator.py", line 66, in run
1: [rank1]:     evaluator.validate( 0, cf.BERT_strategy)
1: [rank1]:   File "/p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/core/trainer.py", line 380, in validate
1: [rank1]:     preds, atts = self.model( batch_data)
1: [rank1]:                   ^^^^^^^^^^^^^^^^^^^^^^^
1: [rank1]:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/virtual_envs/venv_hdfml/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
1: [rank1]:     return self._call_impl(*args, **kwargs)
1: [rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
1: [rank1]:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/virtual_envs/venv_hdfml/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
1: [rank1]:     return forward_call(*args, **kwargs)
1: [rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
1: [rank1]:   File "/p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/core/atmorep_model.py", line 166, in forward
1: [rank1]:     pred = self.net.forward( xin)
1: [rank1]:            ^^^^^^^^^^^^^^^^^^^^^^
1: [rank1]:   File "/p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/core/atmorep_model.py", line 499, in forward
1: [rank1]:     fields_embed, att = self.forward_encoder_block( ib, fields_embed) 
1: [rank1]:                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
1: [rank1]:   File "/p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/core/atmorep_model.py", line 533, in forward_encoder_block
1: [rank1]:     y, att = self.checkpoint( self.encoders[ifield].heads[iblock], *fields_in)
1: [rank1]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
1: [rank1]:   File "/p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/transformer/transformer_base.py", line 26, in checkpoint_wrapper
1: [rank1]:     return cmodule(*kwargs)
1: [rank1]:            ^^^^^^^^^^^^^^^^
1: [rank1]:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/virtual_envs/venv_hdfml/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
1: [rank1]:     return self._call_impl(*args, **kwargs)
1: [rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
1: [rank1]:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/virtual_envs/venv_hdfml/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
1: [rank1]:     return forward_call(*args, **kwargs)
1: [rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
1: [rank1]:   File "/p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/transformer/transformer_attention.py", line 256, in forward
1: [rank1]:     outs = self.att( qs, ks, vs).transpose( -3, -2).flatten( -2, -1).reshape(s)
1: [rank1]:            ^^^^^^^^^^^^^^^^^^^^^
1: [rank1]: RuntimeError: No available kernel. Aborting execution.
0: /p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/transformer/transformer_attention.py:256: UserWarning: Memory efficient kernel not used because: (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:718.)
0:   outs = self.att( qs, ks, vs).transpose( -3, -2).flatten( -2, -1).reshape(s)
0: /p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/transformer/transformer_attention.py:256: UserWarning: Memory Efficient attention has been runtime disabled. (Triggered internally at ../aten/src/ATen/native/transformers/sdp_utils_cpp.h:495.)
0:   outs = self.att( qs, ks, vs).transpose( -3, -2).flatten( -2, -1).reshape(s)
0: /p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/transformer/transformer_attention.py:256: UserWarning: Flash attention kernel not used because: (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:720.)
0:   outs = self.att( qs, ks, vs).transpose( -3, -2).flatten( -2, -1).reshape(s)
0: /p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/transformer/transformer_attention.py:256: UserWarning: Flash attention only supports gpu architectures in the range [sm80, sm90]. Attempting to run on a sm 7.0 gpu. (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:201.)
0:   outs = self.att( qs, ks, vs).transpose( -3, -2).flatten( -2, -1).reshape(s)
0: /p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/transformer/transformer_attention.py:256: UserWarning: CuDNN attention kernel not used because: (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:722.)
0:   outs = self.att( qs, ks, vs).transpose( -3, -2).flatten( -2, -1).reshape(s)
0: /p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/transformer/transformer_attention.py:256: UserWarning: The CuDNN backend needs to be enabled by setting the enviornment variable`TORCH_CUDNN_SDPA_ENABLED=1` (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:496.)
0:   outs = self.att( qs, ks, vs).transpose( -3, -2).flatten( -2, -1).reshape(s)
0: Traceback (most recent call last):
0:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/atmorep/run/atmorep_eval_69268/evaluate.py", line 76, in <module>
0:     Evaluator.evaluate( mode, model_id, file_path, options)
0:   File "/p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/core/evaluator.py", line 118, in evaluate
0:     func( cf, model_id, model_epoch, devices, args)
0:   File "/p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/core/evaluator.py", line 223, in temporal_interpolation
0:     Evaluator.run( cf, model_id, model_epoch, devices)
0:   File "/p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/core/evaluator.py", line 66, in run
0:     evaluator.validate( 0, cf.BERT_strategy)
0:   File "/p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/core/trainer.py", line 380, in validate
0:     preds, atts = self.model( batch_data)
0:                   ^^^^^^^^^^^^^^^^^^^^^^^
0:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/virtual_envs/venv_hdfml/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
0:     return self._call_impl(*args, **kwargs)
0:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/virtual_envs/venv_hdfml/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
0:     return forward_call(*args, **kwargs)
0:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0:   File "/p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/core/atmorep_model.py", line 166, in forward
0:     pred = self.net.forward( xin)
0:            ^^^^^^^^^^^^^^^^^^^^^^
0:   File "/p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/core/atmorep_model.py", line 499, in forward
0:     fields_embed, att = self.forward_encoder_block( ib, fields_embed)
0:                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0:   File "/p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/core/atmorep_model.py", line 533, in forward_encoder_block
0:     y, att = self.checkpoint( self.encoders[ifield].heads[iblock], *fields_in)
0:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0:   File "/p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/transformer/transformer_base.py", line 26, in checkpoint_wrapper
0:     return cmodule(*kwargs)
0:            ^^^^^^^^^^^^^^^^
0:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/virtual_envs/venv_hdfml/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
0:     return self._call_impl(*args, **kwargs)
0:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/virtual_envs/venv_hdfml/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
0:     return forward_call(*args, **kwargs)
0:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0:   File "/p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/transformer/transformer_attention.py", line 256, in forward
0:     outs = self.att( qs, ks, vs).transpose( -3, -2).flatten( -2, -1).reshape(s)
0:            ^^^^^^^^^^^^^^^^^^^^^
0: RuntimeError: No available kernel. Aborting execution.
0: [rank0]: Traceback (most recent call last):
0: [rank0]:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/atmorep/run/atmorep_eval_69268/evaluate.py", line 76, in <module>
0: [rank0]:     Evaluator.evaluate( mode, model_id, file_path, options)
0: [rank0]:   File "/p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/core/evaluator.py", line 118, in evaluate
0: [rank0]:     func( cf, model_id, model_epoch, devices, args)
0: [rank0]:   File "/p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/core/evaluator.py", line 223, in temporal_interpolation
0: [rank0]:     Evaluator.run( cf, model_id, model_epoch, devices)
0: [rank0]:   File "/p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/core/evaluator.py", line 66, in run
0: [rank0]:     evaluator.validate( 0, cf.BERT_strategy)
0: [rank0]:   File "/p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/core/trainer.py", line 380, in validate
0: [rank0]:     preds, atts = self.model( batch_data)
0: [rank0]:                   ^^^^^^^^^^^^^^^^^^^^^^^
0: [rank0]:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/virtual_envs/venv_hdfml/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
0: [rank0]:     return self._call_impl(*args, **kwargs)
0: [rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0: [rank0]:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/virtual_envs/venv_hdfml/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
0: [rank0]:     return forward_call(*args, **kwargs)
0: [rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0: [rank0]:   File "/p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/core/atmorep_model.py", line 166, in forward
0: [rank0]:     pred = self.net.forward( xin)
0: [rank0]:            ^^^^^^^^^^^^^^^^^^^^^^
0: [rank0]:   File "/p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/core/atmorep_model.py", line 499, in forward
0: [rank0]:     fields_embed, att = self.forward_encoder_block( ib, fields_embed) 
0: [rank0]:                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0: [rank0]:   File "/p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/core/atmorep_model.py", line 533, in forward_encoder_block
0: [rank0]:     y, att = self.checkpoint( self.encoders[ifield].heads[iblock], *fields_in)
0: [rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0: [rank0]:   File "/p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/transformer/transformer_base.py", line 26, in checkpoint_wrapper
0: [rank0]:     return cmodule(*kwargs)
0: [rank0]:            ^^^^^^^^^^^^^^^^
0: [rank0]:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/virtual_envs/venv_hdfml/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
0: [rank0]:     return self._call_impl(*args, **kwargs)
0: [rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0: [rank0]:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/virtual_envs/venv_hdfml/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
0: [rank0]:     return forward_call(*args, **kwargs)
0: [rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0: [rank0]:   File "/p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/transformer/transformer_attention.py", line 256, in forward
0: [rank0]:     outs = self.att( qs, ks, vs).transpose( -3, -2).flatten( -2, -1).reshape(s)
0: [rank0]:            ^^^^^^^^^^^^^^^^^^^^^
0: [rank0]: RuntimeError: No available kernel. Aborting execution.
0: wandb: You can sync this run to the cloud by running:
0: wandb: wandb sync /p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/atmorep/wandb/offline-run-20240809_095714-gxo1grfj
0: wandb: Find logs at: ./wandb/offline-run-20240809_095714-gxo1grfj/logs
0: wandb: WARNING The new W&B backend becomes opt-out in version 0.18.0; try it out with `wandb.require("core")`! See https://wandb.me/wandb-core for more information.
srun: error: hdfmlc15: task 1: Exited with exit code 1
srun: error: hdfmlc15: task 0: Terminated
srun: Force Terminated StepId=69268.0
+ echo 'Finished job.'
+ date

FYI, the job runs without any problem on JUWELS Booster.

clessig commented 3 months ago

The warning says

UserWarning: Flash attention only supports gpu architectures in the range [sm80, sm90]. Attempting to run on a sm 7.0 gpu. (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:201.)

Are these V100 GPUs?

clessig commented 3 months ago

You can try replacing

with torch.nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION):

by

with torch.nn.attention.sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]):

Then it should fall back to an attention implementation that works on all GPUs.
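
The reasoning behind this change can be illustrated with a toy model of an ordered fallback. This is a sketch only and not PyTorch's dispatcher: the names `pick_backend`, `SUPPORT`, `flash_supported` and `efficient_supported` are hypothetical, the flash range is taken from the warning quoted above, and the assumption that the memory-efficient kernel runs on older GPUs is hard-coded. The list order is the preference order of this toy model; it makes no claim about how PyTorch prioritizes the enabled backends.

```python
FLASH = "flash"
EFFICIENT = "efficient"


def flash_supported(capability):
    """Range quoted in the warning: [sm80, sm90]."""
    major, minor = capability
    return 80 <= major * 10 + minor <= 90


def efficient_supported(capability):
    """Assumption for this sketch: usable on older GPUs such as sm70 as well."""
    return True


SUPPORT = {FLASH: flash_supported, EFFICIENT: efficient_supported}


def pick_backend(allowed, capability):
    """Return the first allowed backend that supports the given (major, minor)."""
    for backend in allowed:
        if SUPPORT[backend](capability):
            return backend
    # Mirrors the situation in the log: nothing in the allowed set is usable.
    raise RuntimeError("No available kernel.")


print(pick_backend([FLASH], (8, 0)))
print(pick_backend([FLASH, EFFICIENT], (7, 0)))
```

With only the flash backend allowed, a (7, 0) device leaves nothing to choose from and the function raises, which corresponds to the "No available kernel" error above. Allowing a second backend gives the selection something to fall back to.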

sbAsma commented 3 months ago

I think it worked, but I have a new error:

0: Traceback (most recent call last):
0:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/atmorep/run/atmorep_eval_69308/evaluate.py", line 76, in <module>
0:     Evaluator.evaluate( mode, model_id, file_path, options)
0:   File "/p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/core/evaluator.py", line 118, in evaluate
0:     func( cf, model_id, model_epoch, devices, args)
0:   File "/p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/core/evaluator.py", line 223, in temporal_interpolation
0:     Evaluator.run( cf, model_id, model_epoch, devices)
0:   File "/p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/core/evaluator.py", line 66, in run
0:     evaluator.validate( 0, cf.BERT_strategy)
0:   File "/p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/core/trainer.py", line 401, in validate
0:     self.log_validate( epoch, it, log_sources, log_preds)
0:   File "/p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/core/trainer.py", line 617, in log_validate
0:     self.log_validate_BERT( epoch, bidx, log_sources, log_preds)
0:   File "/p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/core/trainer.py", line 828, in log_validate_BERT
0:     write_BERT( cf.wandb_id, epoch, batch_idx,
0:   File "/p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/datasets/data_writer.py", line 112, in write_BERT
0:     write_item(ds_field, sample, field[1][bidx], levels[fidx], sources_coords[fidx][bidx] )
0:   File "/p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/datasets/data_writer.py", line 24, in write_item
0:     ds_batch_item.create_dataset( 'data', data=data)
0:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/virtual_envs/venv_hdfml/lib/python3.11/site-packages/zarr/hierarchy.py", line 1111, in create_dataset
0:     return self._write_op(self._create_dataset_nosync, name, **kwargs)
0:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/virtual_envs/venv_hdfml/lib/python3.11/site-packages/zarr/hierarchy.py", line 952, in _write_op
0:     return f(*args, **kwargs)
0:            ^^^^^^^^^^^^^^^^^^
0:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/virtual_envs/venv_hdfml/lib/python3.11/site-packages/zarr/hierarchy.py", line 1126, in _create_dataset_nosync
0:     a = array(data, store=self._store, path=path, chunk_store=self._chunk_store, **kwargs)
0:         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/virtual_envs/venv_hdfml/lib/python3.11/site-packages/zarr/creation.py", line 441, in array
0:     z = create(**kwargs)
0:         ^^^^^^^^^^^^^^^^
0:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/virtual_envs/venv_hdfml/lib/python3.11/site-packages/zarr/creation.py", line 227, in create
0:     z = Array(
0:         ^^^^^^
0:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/virtual_envs/venv_hdfml/lib/python3.11/site-packages/zarr/core.py", line 170, in __init__
0:     self._load_metadata()
0:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/virtual_envs/venv_hdfml/lib/python3.11/site-packages/zarr/core.py", line 193, in _load_metadata
0:     self._load_metadata_nosync()
0:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/virtual_envs/venv_hdfml/lib/python3.11/site-packages/zarr/core.py", line 202, in _load_metadata_nosync
0:     meta_bytes = self._store[mkey]
0:                  ~~~~~~~~~~~^^^^^^
0:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/virtual_envs/venv_hdfml/lib/python3.11/site-packages/zarr/storage.py", line 1863, in __getitem__
0:     with self.zf.open(key) as f:  # will raise KeyError
0:          ^^^^^^^^^^^^^^^^^
0:   File "/p/software/hdfml/stages/2024/software/Python/3.11.3-GCCcore-12.3.0/lib/python3.11/zipfile.py", line 1567, in open
0:     raise BadZipFile("Bad magic number for file header")
0: zipfile.BadZipFile: Bad magic number for file header
0: [rank0]: Traceback (most recent call last):
0: [rank0]:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/atmorep/run/atmorep_eval_69308/evaluate.py", line 76, in <module>
0: [rank0]:     Evaluator.evaluate( mode, model_id, file_path, options)
0: [rank0]:   File "/p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/core/evaluator.py", line 118, in evaluate
0: [rank0]:     func( cf, model_id, model_epoch, devices, args)
0: [rank0]:   File "/p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/core/evaluator.py", line 223, in temporal_interpolation
0: [rank0]:     Evaluator.run( cf, model_id, model_epoch, devices)
0: [rank0]:   File "/p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/core/evaluator.py", line 66, in run
0: [rank0]:     evaluator.validate( 0, cf.BERT_strategy)
0: [rank0]:   File "/p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/core/trainer.py", line 401, in validate
0: [rank0]:     self.log_validate( epoch, it, log_sources, log_preds)
0: [rank0]:   File "/p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/core/trainer.py", line 617, in log_validate
0: [rank0]:     self.log_validate_BERT( epoch, bidx, log_sources, log_preds)
0: [rank0]:   File "/p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/core/trainer.py", line 828, in log_validate_BERT
0: [rank0]:     write_BERT( cf.wandb_id, epoch, batch_idx, 
0: [rank0]:     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0: [rank0]:   File "/p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/datasets/data_writer.py", line 112, in write_BERT
0: [rank0]:     write_item(ds_field, sample, field[1][bidx], levels[fidx], sources_coords[fidx][bidx] )
0: [rank0]:   File "/p/home/jusers/semcheddine1/hdfml/dev_atmorep/atmorep/atmorep/datasets/data_writer.py", line 24, in write_item
0: [rank0]:     ds_batch_item.create_dataset( 'data', data=data)
0: [rank0]:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/virtual_envs/venv_hdfml/lib/python3.11/site-packages/zarr/hierarchy.py", line 1111, in create_dataset
0: [rank0]:     return self._write_op(self._create_dataset_nosync, name, **kwargs)
0: 
0: [rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0: [rank0]:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/virtual_envs/venv_hdfml/lib/python3.11/site-packages/zarr/hierarchy.py", line 952, in _write_op
0: [rank0]:     return f(*args, **kwargs)
0: [rank0]:            ^^^^^^^^^^^^^^^^^^
0: [rank0]:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/virtual_envs/venv_hdfml/lib/python3.11/site-packages/zarr/hierarchy.py", line 1126, in _create_dataset_nosync
0: [rank0]:     a = array(data, store=self._store, path=path, chunk_store=self._chunk_store, **kwargs)
0: [rank0]:         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0: [rank0]:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/virtual_envs/venv_hdfml/lib/python3.11/site-packages/zarr/creation.py", line 441, in array
0: [rank0]:     z = create(**kwargs)
0: [rank0]:         ^^^^^^^^^^^^^^^^
0: [rank0]:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/virtual_envs/venv_hdfml/lib/python3.11/site-packages/zarr/creation.py", line 227, in create
0: [rank0]:     z = Array(
0: [rank0]:         ^^^^^^
0: [rank0]:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/virtual_envs/venv_hdfml/lib/python3.11/site-packages/zarr/core.py", line 170, in __init__
0: [rank0]:     self._load_metadata()
0: [rank0]:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/virtual_envs/venv_hdfml/lib/python3.11/site-packages/zarr/core.py", line 193, in _load_metadata
0: [rank0]:     self._load_metadata_nosync()
0: [rank0]:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/virtual_envs/venv_hdfml/lib/python3.11/site-packages/zarr/core.py", line 202, in _load_metadata_nosync
0: [rank0]:     meta_bytes = self._store[mkey]
0: [rank0]:                  ~~~~~~~~~~~^^^^^^
0: [rank0]:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/virtual_envs/venv_hdfml/lib/python3.11/site-packages/zarr/storage.py", line 1863, in __getitem__
0: [rank0]:     with self.zf.open(key) as f:  # will raise KeyError
0: [rank0]:          ^^^^^^^^^^^^^^^^^
0: [rank0]:   File "/p/software/hdfml/stages/2024/software/Python/3.11.3-GCCcore-12.3.0/lib/python3.11/zipfile.py", line 1567, in open
0: [rank0]:     raise BadZipFile("Bad magic number for file header")
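
The trace shows zarr reading from a zip-backed store (`self.zf.open(key)` in zarr/storage.py), and `zipfile` rejecting a member's local file header. One way to narrow this down is to check which members of the output archive can still be opened. The sketch below uses only the standard library; `first_unreadable_member` and `make_demo_archive` are hypothetical helpers, and the in-memory archive is demo data, not the real output file. Why the real file was damaged is not established in this thread.

```python
import io
import zipfile


def first_unreadable_member(archive):
    """Return the name of the first zip member that fails to open or read, else None.

    `archive` may be a path or a binary file-like object.
    """
    with zipfile.ZipFile(archive) as zf:
        for info in zf.infolist():
            try:
                with zf.open(info) as member:
                    member.read()  # reading to the end also triggers the CRC check
            except zipfile.BadZipFile:
                return info.filename
    return None


def make_demo_archive():
    """Build a small uncompressed in-memory zip with two members (demo data only)."""
    buf = io.BytesIO()
    with zipfile.ZipFile(buf, "w", zipfile.ZIP_STORED) as zf:
        zf.writestr("a/.zarray", "{}")
        zf.writestr("b/.zarray", "{}")
    return buf.getvalue()


healthy = make_demo_archive()
# Overwrite the signature of the first local file header to simulate damage.
damaged = b"XXXX" + healthy[4:]

print(first_unreadable_member(io.BytesIO(healthy)))
print(first_unreadable_member(io.BytesIO(damaged)))
```

On a real run, the function would be called with the path of the output file written by the failing job.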

sbAsma commented 3 months ago

When running inference in temporal_interpolation mode, the written data (target and pred) is not detokenized, i.e. the output has shape (216, 3, 9, 9) per level (similar to the BERT data writing). Should the detokenization be moved before the data writing, or kept as part of the pre-processing before plotting the data?

clessig commented 3 months ago

We don't want to have a custom detokenize in the core code for every application. It should then be done in the analysis. That said, reconstructing individual time slices is needed for forecast and probably for other cases. So if one factors it with the detokenization of forecast then it might be ok to have it in the core code.
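As a rough illustration of what such a detokenization in the analysis could look like, here is a minimal NumPy sketch. It assumes that the 216 tokens form a 3 x 6 x 12 grid (time, lat, lon) stored in row-major order, which is consistent with the (9, 54, 108) fields mentioned later in this thread but is not confirmed by the code; `detokenize` and its `grid` parameter are hypothetical names, not part of atmorep.

```python
import numpy as np

def detokenize(tokens, grid=(3, 6, 12)):
    """Reassemble tokens of shape (num_tokens, t, h, w) into one (T, H, W) field.

    `grid` = (tokens in time, lat, lon) is an assumption, as is the
    row-major (time, lat, lon) ordering of the token axis.
    """
    nt, ny, nx = grid
    n, t, h, w = tokens.shape
    assert n == nt * ny * nx, "token count does not match the assumed grid"
    x = tokens.reshape(nt, ny, nx, t, h, w)
    # Interleave token-grid axes with within-token axes: (nt, t, ny, h, nx, w)
    x = x.transpose(0, 3, 1, 4, 2, 5)
    return x.reshape(nt * t, ny * h, nx * w)

# Dummy data with the per-level shape reported above
tokens = np.arange(216 * 3 * 9 * 9, dtype=np.float32).reshape(216, 3, 9, 9)
field = detokenize(tokens)
print(field.shape)  # (9, 54, 108)
```

If the actual token ordering differs, only the `reshape`/`transpose` arguments would need to change.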

sbAsma commented 3 months ago

update: Training works fine. I have been able to restart training without any issue.

sbAsma commented 3 months ago

@clessig @iluise is global_forecast mode working for you? I wanted to check what is happening in the logging function so that I could adapt it for temporal_interpolation. The code is crashing at line 215 in multifield_data_sampler.py

clessig commented 3 months ago

Global forecast is working. What is the error?

sbAsma commented 3 months ago

0:   mloaded = torch.load( utils.get_model_filename( model, model_id, epoch) )
0: Traceback (most recent call last):
0:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/atmorep/run/atmorep_eval_10169185/evaluate.py", line 76, in <module>
0:     Evaluator.evaluate( mode, model_id, file_path, options)
0:   File "/p/home/jusers/semcheddine1/juwels/dev_atmorep/atmorep/atmorep/core/evaluator.py", line 118, in evaluate
0:     func( cf, model_id, model_epoch, devices, args)
0:   File "/p/home/jusers/semcheddine1/juwels/dev_atmorep/atmorep/atmorep/core/evaluator.py", line 176, in global_forecast
0:     evaluator.validate( 0, cf.BERT_strategy)
0:   File "/p/home/jusers/semcheddine1/juwels/dev_atmorep/atmorep/atmorep/core/trainer.py", line 370, in validate
0:     batch_data = self.model.next()
0:                  ^^^^^^^^^^^^^^^^^
0:   File "/p/home/jusers/semcheddine1/juwels/dev_atmorep/atmorep/atmorep/core/atmorep_model.py", line 160, in next
0:     return next(self.data_loader_iter)
0:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
0:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/virtual_envs/venv_jwc/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 630, in __next__
0:     data = self._next_data()
0:            ^^^^^^^^^^^^^^^^^
0:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/virtual_envs/venv_jwc/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1344, in _next_data
0:     return self._process_data(data)
0:            ^^^^^^^^^^^^^^^^^^^^^^^^
0:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/virtual_envs/venv_jwc/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1370, in _process_data
0:     data.reraise()
0:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/virtual_envs/venv_jwc/lib/python3.11/site-packages/torch/_utils.py", line 706, in reraise
0:     raise exception
0: RuntimeError: Caught RuntimeError in DataLoader worker process 11.
0: Original Traceback (most recent call last):
0:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/virtual_envs/venv_jwc/lib/python3.11/site-packages/torch/utils/data/_utils/worker.py", line 309, in _worker_loop
0:     data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
0:            ^^^^^^^^^^^^^^^^^^^^
0:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/virtual_envs/venv_jwc/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 42, in fetch
0:     data = next(self.dataset_iter)
0:            ^^^^^^^^^^^^^^^^^^^^^^^
0:   File "/p/home/jusers/semcheddine1/juwels/dev_atmorep/atmorep/atmorep/datasets/multifield_data_sampler.py", line 215, in __iter__
0:     sources = [torch.stack(sources_field).transpose(1,0) for sources_field in sources]
0:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0:   File "/p/home/jusers/semcheddine1/juwels/dev_atmorep/atmorep/atmorep/datasets/multifield_data_sampler.py", line 215, in <listcomp>
0:     sources = [torch.stack(sources_field).transpose(1,0) for sources_field in sources]
0:                ^^^^^^^^^^^^^^^^^^^^^^^^^^
0: RuntimeError: stack expects each tensor to be equal size, but got [5, 12, 15, 48, 3, 9, 9] at entry 0 and [5, 12, 21, 48, 3, 9, 9] at entry 14
0: 
0: [rank0]: Traceback (most recent call last):
0: [rank0]:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/atmorep/run/atmorep_eval_10169185/evaluate.py", line 76, in <module>
0: [rank0]:     Evaluator.evaluate( mode, model_id, file_path, options)
0: [rank0]:   File "/p/home/jusers/semcheddine1/juwels/dev_atmorep/atmorep/atmorep/core/evaluator.py", line 118, in evaluate
0: [rank0]:     func( cf, model_id, model_epoch, devices, args)
0: [rank0]:   File "/p/home/jusers/semcheddine1/juwels/dev_atmorep/atmorep/atmorep/core/evaluator.py", line 176, in global_forecast
0: [rank0]:     evaluator.validate( 0, cf.BERT_strategy)
0: [rank0]:   File "/p/home/jusers/semcheddine1/juwels/dev_atmorep/atmorep/atmorep/core/trainer.py", line 370, in validate
0: [rank0]:     batch_data = self.model.next()
0: [rank0]:                  ^^^^^^^^^^^^^^^^^
0: [rank0]:   File "/p/home/jusers/semcheddine1/juwels/dev_atmorep/atmorep/atmorep/core/atmorep_model.py", line 160, in next
0: [rank0]:     return next(self.data_loader_iter)
0: [rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
0: [rank0]:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/virtual_envs/venv_jwc/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 630, in __next__
0: [rank0]:     data = self._next_data()
0: [rank0]:            ^^^^^^^^^^^^^^^^^
0: [rank0]:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/virtual_envs/venv_jwc/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1344, in _next_data
0: [rank0]:     return self._process_data(data)
0: [rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^
0: [rank0]:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/virtual_envs/venv_jwc/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1370, in _process_data
0: [rank0]:     data.reraise()
0: [rank0]:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/virtual_envs/venv_jwc/lib/python3.11/site-packages/torch/_utils.py", line 706, in reraise
0: [rank0]:     raise exception
0: [rank0]: RuntimeError: Caught RuntimeError in DataLoader worker process 11.
0: [rank0]: Original Traceback (most recent call last):
0: [rank0]:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/virtual_envs/venv_jwc/lib/python3.11/site-packages/torch/utils/data/_utils/worker.py", line 309, in _worker_loop
0: [rank0]:     data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
0: [rank0]:            ^^^^^^^^^^^^^^^^^^^^
0: [rank0]:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/virtual_envs/venv_jwc/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 42, in fetch
0: [rank0]:     data = next(self.dataset_iter)
0: [rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^
0: [rank0]:   File "/p/home/jusers/semcheddine1/juwels/dev_atmorep/atmorep/atmorep/datasets/multifield_data_sampler.py", line 215, in __iter__
0: [rank0]:     sources = [torch.stack(sources_field).transpose(1,0) for sources_field in sources]
0: [rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0: [rank0]:   File "/p/home/jusers/semcheddine1/juwels/dev_atmorep/atmorep/atmorep/datasets/multifield_data_sampler.py", line 215, in <listcomp>
0: [rank0]:     sources = [torch.stack(sources_field).transpose(1,0) for sources_field in sources]
0: [rank0]:                ^^^^^^^^^^^^^^^^^^^^^^^^^^
0: [rank0]: RuntimeError: stack expects each tensor to be equal size, but got [5, 12, 15, 48, 3, 9, 9] at entry 0 and [5, 12, 21, 48, 3, 9, 9] at entry 14
0: 
1: [rank1]: Traceback (most recent call last):
1: [rank1]:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/atmorep/run/atmorep_eval_10169185/evaluate.py", line 76, in <module>
1: [rank1]:     Evaluator.evaluate( mode, model_id, file_path, options)
1: [rank1]:   File "/p/home/jusers/semcheddine1/juwels/dev_atmorep/atmorep/atmorep/core/evaluator.py", line 118, in evaluate
1: [rank1]:     func( cf, model_id, model_epoch, devices, args)
1: [rank1]:   File "/p/home/jusers/semcheddine1/juwels/dev_atmorep/atmorep/atmorep/core/evaluator.py", line 176, in global_forecast
1: [rank1]:     evaluator.validate( 0, cf.BERT_strategy)
1: [rank1]:   File "/p/home/jusers/semcheddine1/juwels/dev_atmorep/atmorep/atmorep/core/trainer.py", line 370, in validate
1: [rank1]:     batch_data = self.model.next()
1: [rank1]:                  ^^^^^^^^^^^^^^^^^
1: [rank1]:   File "/p/home/jusers/semcheddine1/juwels/dev_atmorep/atmorep/atmorep/core/atmorep_model.py", line 160, in next
1: [rank1]:     return next(self.data_loader_iter)
1: [rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
1: [rank1]:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/virtual_envs/venv_jwc/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 630, in __next__
1: [rank1]:     data = self._next_data()
1: [rank1]:            ^^^^^^^^^^^^^^^^^
1: [rank1]:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/virtual_envs/venv_jwc/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1344, in _next_data
1: [rank1]:     return self._process_data(data)
1: [rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^
1: [rank1]:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/virtual_envs/venv_jwc/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1370, in _process_data
1: [rank1]:     data.reraise()
1: [rank1]:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/virtual_envs/venv_jwc/lib/python3.11/site-packages/torch/_utils.py", line 706, in reraise
1: [rank1]:     raise exception
1: [rank1]: RuntimeError: Caught RuntimeError in DataLoader worker process 11.
1: [rank1]: Original Traceback (most recent call last):
1: [rank1]:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/virtual_envs/venv_jwc/lib/python3.11/site-packages/torch/utils/data/_utils/worker.py", line 309, in _worker_loop
1: [rank1]:     data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
1: [rank1]:            ^^^^^^^^^^^^^^^^^^^^
1: [rank1]:   File "/p/project1/deepacf/atmo-rep/semcheddine1/dev_atmorep/virtual_envs/venv_jwc/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 42, in fetch
1: [rank1]:     data = next(self.dataset_iter)
1: [rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^
1: [rank1]:   File "/p/home/jusers/semcheddine1/juwels/dev_atmorep/atmorep/atmorep/datasets/multifield_data_sampler.py", line 215, in __iter__
1: [rank1]:     sources = [torch.stack(sources_field).transpose(1,0) for sources_field in sources]
1: [rank1]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
1: [rank1]:   File "/p/home/jusers/semcheddine1/juwels/dev_atmorep/atmorep/atmorep/datasets/multifield_data_sampler.py", line 215, in <listcomp>
1: [rank1]:     sources = [torch.stack(sources_field).transpose(1,0) for sources_field in sources]
1: [rank1]:                ^^^^^^^^^^^^^^^^^^^^^^^^^^
1: [rank1]: RuntimeError: stack expects each tensor to be equal size, but got [5, 12, 15, 48, 3, 9, 9] at entry 0 and [5, 12, 21, 48, 3, 9, 9] at entry 14
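
The error above says that torch.stack received tensors whose third dimension differs (15 vs. 21). A minimal, hypothetical helper for locating the offending entries before stacking could look like the sketch below. `group_by_shape` is not part of atmorep, and the demo objects are stand-ins that only carry the shapes from the error message (that entries 0 to 13 all share the first shape is an illustrative assumption).

```python
from collections import defaultdict
from types import SimpleNamespace

def group_by_shape(tensors):
    """Map each distinct shape to the list of entry indices that have it."""
    groups = defaultdict(list)
    for i, t in enumerate(tensors):
        groups[tuple(t.shape)].append(i)
    return dict(groups)

# Stand-ins for tensors: only the .shape attribute matters here.
ok = SimpleNamespace(shape=(5, 12, 15, 48, 3, 9, 9))
bad = SimpleNamespace(shape=(5, 12, 21, 48, 3, 9, 9))
groups = group_by_shape([ok] * 14 + [bad])
for shape, idxs in groups.items():
    print(shape, idxs)
```

The same function should work on real torch tensors, since `tuple(t.shape)` is valid for them as well; more than one key in the result means the list cannot be stacked.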
iluise commented 3 months ago

Hi @sbAsma,

I just tried running global_forecast as well and it works for me. Are you sure it's not something you added afterwards? Maybe rebasing onto main helps in this case.

Best,

Ilaria

sbAsma commented 3 months ago

It works fine with the previous version that was in main; I didn't try the new one. It didn't work with the code in the develop branch.

sbAsma commented 3 months ago

While trying to adapt log_validate_forecast() to handle temporal_interpolation logging, I noticed that source and target have very different shapes after detokenization (aside from the temporal steps). I have the following:

source.shape = (14, 5, 36, 216, 432)
target.shape = (224, 5, 9, 54, 108)

Should I generalize to the source or the target shape? The mismatch prevents looping over batch elements. FYI, I set batch_size = 14 because I am running into a memory issue.

clessig commented 3 months ago

It seems there is some reordering missing at this point in the code:

224 / 16 = 14, 54 * 4 = 216, 108 * 4 = 432

so that the batch dimension collects the different spatial neighborhoods. This should be reordered later in the code.
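
The reordering hinted at by these numbers can be sketched as follows. This is only an illustration under stated assumptions: each sample is assumed to consist of 16 neighborhoods arranged as a 4 x 4 (lat, lon) grid, with the batch index running as sample * 16 + iy * 4 + ix. `merge_neighborhoods` is a hypothetical name, and the demo uses small shapes to keep memory low.

```python
import numpy as np

def merge_neighborhoods(x, n_lat=4, n_lon=4):
    """(B * n_lat * n_lon, lev, t, h, w) -> (B, lev, t, n_lat * h, n_lon * w).

    Assumes the leading axis enumerates neighborhoods in row-major
    (sample, lat block, lon block) order.
    """
    b, lev, t, h, w = x.shape
    per_sample = n_lat * n_lon
    assert b % per_sample == 0, "batch size is not a multiple of the grid size"
    x = x.reshape(b // per_sample, n_lat, n_lon, lev, t, h, w)
    # Move the block axes next to the spatial axes they extend:
    # (sample, lev, t, n_lat, h, n_lon, w)
    x = x.transpose(0, 3, 4, 1, 5, 2, 6)
    return x.reshape(b // per_sample, lev, t, n_lat * h, n_lon * w)

# Small demo: 2 samples x 16 neighborhoods of size 5 x 7
x = np.arange(32 * 2 * 3 * 5 * 7).reshape(32, 2, 3, 5, 7)
y = merge_neighborhoods(x)
print(y.shape)  # (2, 2, 3, 20, 28)
```

By the same arithmetic, an input of shape (224, 5, 9, 54, 108) would come out as (14, 5, 9, 216, 432).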

sbAsma commented 3 months ago

While waiting for feedback, I went with the (14, 5, 36, 216, 432) shape. Additionally, I had to lower the batch_size once again (batch_size=10 now). Some randomly picked results can be found here, where you can see regions that are missing.

For each date, the data has shape (10, 5, 9, 216, 432), which doesn't translate to (5, 9, 721, 1440). I'll investigate further while waiting for feedback.

EDIT1: I tried extracting nonzero elements from the target arrays. I have different values for different dates:

All these values are smaller than 10 x 216 x 432, which leads to the conclusion that there are duplicated data points and that not all data is being processed. I'd appreciate some tips on what to do next.

EDIT2: My understanding is that in temporal_interpolation mode, 36-hour tokens are randomly picked, and masking is done along the time axis. At first, I thought that the data would be processed in a global manner. This leads me to conclude that patching tokens together at the end of the process might not make sense after all.

clessig commented 3 months ago

Are you running global_forecast with masking_mode temporal interpolation?

sbAsma commented 3 months ago

I ran temporal_interpolation mode and used log_validate_forecast() with some tweaking (instead of log_validate_BERT())

sbAsma commented 2 months ago

@iluise @clessig I wanted to implement a function that tests target data vs. ERA5 data. This seems to be implemented in the function test_datetime() in atmorep/atmorep/tests/validation_test.py. Meanwhile, the datetime values are being tested in the function test_coordinates(). Am I getting it right? FYI, I am working in the branch develop.

iluise commented 2 months ago

Hi @sbAsma, yes it's exactly as you said. test_datetime tests target vs ERA5 (watch out for the outcome: it skips the check if the ERA5 dataset is not found, though it throws a warning), while test_coordinates tests that the coordinates (lat, lon and time) match between target and predictions.

As we suggest in the guidelines you should pull the develop branch but then call it asma/origin and then put your changes on asma/head so you are always able to isolate your changes, even if develop gets new commits.

sbAsma commented 2 months ago

@iluise thanks for the reminder. I kept procrastinating on this. I finally set the branch names as suggested. I have a question now: if develop is the branch that is up to date with the remote develop, and asma/head is my working branch, what does asma/origin contain?

iluise commented 2 months ago

@sbAsma asma/origin is your stable copy of develop, so you know at which point you started wrt the develop branch. Then you have asma/head, which is a copy of asma/origin + your changes.

This is useful when develop receives new changes that you would otherwise miss because multiple people work on different parts of the code: you can do cleaner comparisons and see where you started from without mixing your changes with those in develop.

sbAsma commented 2 months ago

@iluise I am not sure I fully grasped the idea. Does this mean that each time I want to push something to the remote repo, I first need to merge asma/head with asma/origin, then push asma/origin?

iluise commented 2 months ago

@sbAsma yes exactly, you first update asma/origin to be in sync with develop and then you merge back asma/head into asma/origin so the commit history is cleaner.

sbAsma commented 2 months ago

@iluise When I pull the new updates from the remote develop branch (as the last step before pushing my changes), do I pull them in the local develop branch, or the asma/origin branch?

clessig commented 2 months ago

You need to rebase origin and head.

sbAsma commented 2 months ago

As far as I understand rebase, this is what I did. I am asking about the steps to follow when I want to push the changes that are in asma/origin.

sbAsma commented 2 months ago

@clessig I have the following questions:

1 - I wrote a test function that checks whether the right time steps are masked by comparing the time steps in target with the masked time steps in source. Is it enough to do it for target, or should I include pred too?
2 - I noticed that I didn't fully understand what was meant by "computing RMSE", which was discussed during our first meeting. Can you please elaborate on this?
3 - We also talked about making a video out of the predicted frames of the masked time steps. Currently, temporal interpolation is done over randomly selected tokens, which don't amount to a frame. I might not be getting the full picture here. Can you please explain this further?
4 - I was wondering whether it would be useful to have the option of doing temporal interpolation on either randomly selected tokens or global data. What do you think?

clessig commented 2 months ago
  1. Please include target and pred
  2. torch.sqrt( torch.nn.functional.mse_loss( pred, target))
  3. Temporal interpolation should be for a window of tokens in time of a specified length and covering an entire neighborhood, with perhaps the position of the window in time being sampled. Even when the window length is 1 token, this corresponds to 3 x 1h time steps, which can be used to produce a video.
  4. For an application one wants to apply temporal interpolation to a global frame.
sbAsma commented 2 months ago
  1. Will do
  2. There is already a function that computes the RMSE and checks whether it exceeds a certain threshold. I don't understand what exactly I should add to it, or whether I should include something specific to temporal interpolation. Can you please elaborate on this?
  3. Should I make changes so that the size and position of said window can be set in the temporal interpolation options in evaluate.py?
  4. Would you like changes in the code to make "temporal interpolation applied to a global frame" possible?
clessig commented 2 months ago
  2. I think the point was to plot the MSE as a function of the time step that is masked; it should increase away from the available information
  3. Yes, and please also add/make sure we can mask multiple consecutive tokens (and then we can also look at the error under 2. for one, two, three masked tokens)
  4. I think global_forecast with masking_mode = temporal_interpolation should do but please double-check that it is working
sbAsma commented 2 months ago
  3. Masking multiple consecutive tokens works well. I'll take care of the rest of the to-dos here
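
The error-versus-masked-time-step diagnostic discussed in point 2 can be sketched in a few lines. This is a minimal NumPy illustration, not the existing RMSE test in the repository; `rmse_per_time_step`, the `time_axis` argument and the synthetic data are all assumptions made for the example.

```python
import numpy as np

def rmse_per_time_step(pred, target, time_axis=0):
    """RMSE over all axes except `time_axis`, giving one value per time step."""
    pred = np.asarray(pred, dtype=np.float64)
    target = np.asarray(target, dtype=np.float64)
    err2 = (pred - target) ** 2
    t_ax = time_axis % err2.ndim
    other_axes = tuple(a for a in range(err2.ndim) if a != t_ax)
    return np.sqrt(err2.mean(axis=other_axes))

# Synthetic example: 9 masked hourly steps on a 4 x 4 grid, with an
# error that grows linearly with the time index.
target = np.zeros((9, 4, 4))
pred = target + np.arange(9).reshape(9, 1, 1)
rmse = rmse_per_time_step(pred, target)
print(rmse)
```

Plotting `rmse` against the masked time-step index for one, two and three masked tokens would give the curves mentioned above.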
sbAsma commented 1 month ago

When trying to get temporal interpolation on a global scale, I set the following mode and options:

  mode, options = 'global_forecast', { 
                                      'dates' : [[2021, 2, 10, 12]] ,
                                      'masking_mode': 'temporal_interpolation',
                                      'idx_time_mask': [5,6,7],
                                      'token_overlap' : [0, 0],
                                      'forecast_num_tokens' : 2,
                                      'with_pytest' : True }

I am not getting temporal interpolation; I am only getting a global forecast. The current branch is rebased on iluise/head. @iluise @clessig, any idea what to change?
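
For reference, the hourly time steps that should end up masked for a given 'idx_time_mask' can be worked out as below. The sketch assumes 0-based token indices and 3 hourly steps per token (the latter as stated earlier in this thread); `masked_time_steps` is a hypothetical helper, not an atmorep function.

```python
def masked_time_steps(idx_time_mask, steps_per_token=3):
    """Expand masked token indices into the hourly time-step indices they cover."""
    steps = []
    for idx in sorted(idx_time_mask):
        start = idx * steps_per_token
        steps.extend(range(start, start + steps_per_token))
    return steps

steps = masked_time_steps([5, 6, 7])
print(steps)  # [15, 16, 17, 18, 19, 20, 21, 22, 23]
```

A test on the written output could then compare these indices with the time steps actually present in target and pred.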

iluise commented 1 month ago

global_forecast sets data from here: https://github.com/clessig/atmorep/blob/main/atmorep/datasets/multifield_data_sampler.py#L280

but I think the problem is here: https://github.com/clessig/atmorep/blob/main/atmorep/training/bert.py#L35-L40. You might need to create a new mode called something like global_temporal_interpolation and call the proper masking option here.