kozistr / pytorch_optimizer

optimizer & lr scheduler & loss function collections in PyTorch
https://pytorch-optimizers.readthedocs.io/en/latest/
Apache License 2.0
242 stars, 21 forks

'grokfast_after_step' argument causes the optimizer to crash #279

Closed. Vectorrent closed this issue 1 month ago

Vectorrent commented 1 month ago

Describe the bug

If you set the grokfast_after_step argument of GrokFastAdamW to anything other than None or 0, the optimizer crashes when training reaches that step. For example, with grokfast_after_step=512, training crashes at step 511.

To Reproduce

Log

Error occurred at: 2024-10-11 08:34:34
Script: /home/crow/repos/praxis/run.py
Exception Type: KeyError
Exception Value: 'grok_exp_avg'
Traceback:
  File "/home/crow/repos/praxis/run.py", line 1202, in <module>
    trainer.fit(
  File "/home/crow/repos/praxis/venv/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 538, in fit
    call._call_and_handle_interrupt(
  File "/home/crow/repos/praxis/venv/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py", line 47, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/venv/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 574, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/crow/repos/praxis/venv/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 981, in _run
    results = self._run_stage()
              ^^^^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/venv/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 1025, in _run_stage
    self.fit_loop.run()
  File "/home/crow/repos/praxis/venv/lib/python3.12/site-packages/lightning/pytorch/loops/fit_loop.py", line 205, in run
    self.advance()
  File "/home/crow/repos/praxis/venv/lib/python3.12/site-packages/lightning/pytorch/loops/fit_loop.py", line 363, in advance
    self.epoch_loop.run(self._data_fetcher)
  File "/home/crow/repos/praxis/venv/lib/python3.12/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 140, in run
    self.advance(data_fetcher)
  File "/home/crow/repos/praxis/venv/lib/python3.12/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 250, in advance
    batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/venv/lib/python3.12/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 190, in run
    self._optimizer_step(batch_idx, closure)
  File "/home/crow/repos/praxis/venv/lib/python3.12/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 268, in _optimizer_step
    call._call_lightning_module_hook(
  File "/home/crow/repos/praxis/venv/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py", line 167, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/venv/lib/python3.12/site-packages/lightning/pytorch/core/module.py", line 1306, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/home/crow/repos/praxis/venv/lib/python3.12/site-packages/lightning/pytorch/core/optimizer.py", line 153, in step
    step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/venv/lib/python3.12/site-packages/lightning/pytorch/strategies/strategy.py", line 238, in optimizer_step
    return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/venv/lib/python3.12/site-packages/lightning/pytorch/plugins/precision/precision.py", line 122, in optimizer_step
    return optimizer.step(closure=closure, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/venv/lib/python3.12/site-packages/torch/optim/lr_scheduler.py", line 130, in wrapper
    return func.__get__(opt, opt.__class__)(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/venv/lib/python3.12/site-packages/torch/optim/optimizer.py", line 484, in wrapper
    out = func(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/crow/repos/praxis/venv/lib/python3.12/site-packages/pytorch_optimizer/optimizer/grokfast.py", line 217, in step
    grok_exp_avg = state['grok_exp_avg']
                   ~~~~~^^^^^^^^^^^^^^^^

Expected behavior

I would expect EMA calculation to begin at step 512. I would not expect the optimizer to crash.
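
For illustration, the KeyError in the log is consistent with the 'grok_exp_avg' entry being created only when the per-parameter state is first initialized, and only if the EMA is already active at that moment. That is an assumption drawn from the traceback, not something confirmed from the library source. The sketch below is a toy model in plain Python: buggy_step, fixed_step, first_failing_step and the dict-based state are all hypothetical names, not pytorch_optimizer code. It shows how a delayed start can hit a missing key, and how creating the EMA entry lazily on first use avoids it.

```python
from typing import Callable, Dict, Optional

State = Dict[str, float]


def buggy_step(state: State, grad: float, step: int, after_step: int,
               alpha: float = 0.98, lamb: float = 2.0) -> float:
    """Toy step: the EMA slot is created only during first-time state init."""
    use_ema = step > after_step
    if len(state) == 0:
        state["exp_avg"] = 0.0
        if use_ema:
            # Only reached when the EMA is active on the very first step.
            state["grok_exp_avg"] = grad
    if use_ema:
        ema = state["grok_exp_avg"]  # KeyError if the EMA starts later
        ema = alpha * ema + (1.0 - alpha) * grad
        state["grok_exp_avg"] = ema
        grad = grad + lamb * ema
    return grad


def fixed_step(state: State, grad: float, step: int, after_step: int,
               alpha: float = 0.98, lamb: float = 2.0) -> float:
    """Toy step: the EMA slot is created lazily, the first time it is needed."""
    use_ema = step > after_step
    if len(state) == 0:
        state["exp_avg"] = 0.0
    if use_ema:
        if "grok_exp_avg" not in state:
            state["grok_exp_avg"] = grad
        ema = alpha * state["grok_exp_avg"] + (1.0 - alpha) * grad
        state["grok_exp_avg"] = ema
        grad = grad + lamb * ema
    return grad


def first_failing_step(step_fn: Callable[..., float], total_steps: int,
                       after_step: int) -> Optional[int]:
    """Run steps 1..total_steps; return the step that raised KeyError, if any."""
    state: State = {}
    for step in range(1, total_steps + 1):
        try:
            step_fn(state, 1.0, step, after_step)
        except KeyError:
            return step
    return None


if __name__ == "__main__":
    print("buggy, after_step=4:", first_failing_step(buggy_step, 8, 4))
    print("buggy, after_step=0:", first_failing_step(buggy_step, 8, 0))
    print("fixed, after_step=4:", first_failing_step(fixed_step, 8, 4))
```

In this toy, a threshold of 0 never fails because the EMA slot is created on the first step, which matches the report that only values other than None or 0 crash. The step number at which the toy fails depends on its own step counting and is not meant to reproduce the 511 seen in the report.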

kozistr commented 1 month ago

@Vectorrent thanks for reporting. I'm now working on fixing it.