clessig / atmorep

AtmoRep model code
MIT License

[BUG] restore with_mixed_precision = False option #48

Open iluise opened 1 day ago

iluise commented 1 day ago

Describe the bug

When setting the flag with_mixed_precision = False, the code crashes when computing the attention here.

I tried to work around it by wrapping the attention computation in:

 with torch.cuda.amp.autocast():

but then I get stuck later when calling backpropagation; see below.
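
For reference, a minimal sketch of what that workaround looks like; attention_fn here is just a placeholder for the actual attention call in the transformer module, not the real function name:

    import torch

    def attention_with_autocast(attention_fn, *inputs):
        # Run the (flash) attention in half precision even though
        # with_mixed_precision = False, since flash attention needs fp16/bf16.
        with torch.cuda.amp.autocast():
            out = attention_fn(*inputs)
        # Under autocast the output can come back as fp16, which then meets
        # the fp32 rest of the graph during backward.
        return out

If part of the forward ends up in fp16 while the loss and the rest of the graph stay in fp32, a dtype mismatch like the one in the traceback below seems plausible.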

Error

Loaded model id = 2e3160tt at epoch = 32.
Loaded run '2e3160tt' at epoch 32.
Number of trainable parameters: 530,787,856
33 : 14:06:32 :: batch_size = 96, lr = 1e-05
velocity_u: 0.15534581243991852; velocity_v: 0.16358940303325653; specific_humidity: 0.15310746431350708; velocity_z: 0.3288305997848511; temperature: 0.23762983083724976; 
Traceback (most recent call last):
  File "/gpfs/home/cern/cern890661/atmorep/atmorep_May24/atmorep/core/train_multi.py", line 303, in <module>
    train_continue( wandb_id, epoch, Trainer, epoch_continue)
  File "/gpfs/home/cern/cern890661/atmorep/atmorep_May24/atmorep/core/train_multi.py", line 81, in train_continue
    trainer.run( epoch_continue)
  File "/gpfs/home/cern/cern890661/atmorep/atmorep_May24/atmorep/core/trainer.py", line 187, in run
    self.train( epoch)
  File "/gpfs/home/cern/cern890661/atmorep/atmorep_May24/atmorep/core/trainer.py", line 237, in train
    self.grad_scaler.scale(loss).backward()
  File "/gpfs/scratch/ehpc03/pyenvs/pyenv/lib64/python3.9/site-packages/torch/_tensor.py", line 525, in backward
    torch.autograd.backward(
  File "/gpfs/scratch/ehpc03/pyenvs/pyenv/lib64/python3.9/site-packages/torch/autograd/__init__.py", line 267, in backward
    _engine_run_backward(
  File "/gpfs/scratch/ehpc03/pyenvs/pyenv/lib64/python3.9/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Found dtype Float but expected Half
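
For context (not a claim about where the bug is): in the usual AMP recipe the GradScaler is enabled and disabled together with autocast, roughly like this generic sketch (not the trainer code):

    import torch

    use_amp = False  # corresponds to with_mixed_precision
    device = "cuda"

    model = torch.nn.Linear(16, 1).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    # enabled=False makes scale()/step()/update() behave like plain calls
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    x = torch.randn(8, 16, device=device)
    y = torch.randn(8, 1, device=device)

    optimizer.zero_grad()
    with torch.cuda.amp.autocast(enabled=use_amp):
        pred = model(x)
        loss = torch.nn.functional.mse_loss(pred, y)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

Mixing precision only in part of the forward pass (e.g. just around the attention) is one way to end up with this kind of dtype mismatch in backward.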

Hardware and environment:

clessig commented 19 hours ago

Yes, flash attention only works with half precision. You can try casting the inputs explicitly to half and casting the output back afterwards. But can you also try setting flash_attention = False?
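
In case it helps, a sketch of the explicit casting, assuming the activations are fp32 at that point (the names below are placeholders, not the actual module API):

    import torch

    def flash_attention_fp16(attn_module, q, k, v):
        # flash attention kernels only support fp16/bf16, so cast the inputs down...
        out = attn_module(q.half(), k.half(), v.half())
        # ...and cast the result back to fp32 for the rest of the network
        return out.float()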