If you set the `grokfast_after_step` argument of `GrokFastAdamW` to anything other than `None` or `0`, the optimizer will crash at exactly that step. For example, if you set `grokfast_after_step=512`, training will crash at step 511.
To Reproduce
OS : Linux
PyTorch version : 2.4.1
Python version : 3.12.6
Log
Error occurred at: 2024-10-11 08:34:34
Script: /home/crow/repos/praxis/run.py
Exception Type: KeyError
Exception Value: 'grok_exp_avg'
Traceback:
File "/home/crow/repos/praxis/run.py", line 1202, in <module>
trainer.fit(
File "/home/crow/repos/praxis/venv/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 538, in fit
call._call_and_handle_interrupt(
File "/home/crow/repos/praxis/venv/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py", line 47, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/crow/repos/praxis/venv/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 574, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/home/crow/repos/praxis/venv/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 981, in _run
results = self._run_stage()
^^^^^^^^^^^^^^^^^
File "/home/crow/repos/praxis/venv/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 1025, in _run_stage
self.fit_loop.run()
File "/home/crow/repos/praxis/venv/lib/python3.12/site-packages/lightning/pytorch/loops/fit_loop.py", line 205, in run
self.advance()
File "/home/crow/repos/praxis/venv/lib/python3.12/site-packages/lightning/pytorch/loops/fit_loop.py", line 363, in advance
self.epoch_loop.run(self._data_fetcher)
File "/home/crow/repos/praxis/venv/lib/python3.12/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 140, in run
self.advance(data_fetcher)
File "/home/crow/repos/praxis/venv/lib/python3.12/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 250, in advance
batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/crow/repos/praxis/venv/lib/python3.12/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 190, in run
self._optimizer_step(batch_idx, closure)
File "/home/crow/repos/praxis/venv/lib/python3.12/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 268, in _optimizer_step
call._call_lightning_module_hook(
File "/home/crow/repos/praxis/venv/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py", line 167, in _call_lightning_module_hook
output = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/crow/repos/praxis/venv/lib/python3.12/site-packages/lightning/pytorch/core/module.py", line 1306, in optimizer_step
optimizer.step(closure=optimizer_closure)
File "/home/crow/repos/praxis/venv/lib/python3.12/site-packages/lightning/pytorch/core/optimizer.py", line 153, in step
step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/crow/repos/praxis/venv/lib/python3.12/site-packages/lightning/pytorch/strategies/strategy.py", line 238, in optimizer_step
return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/crow/repos/praxis/venv/lib/python3.12/site-packages/lightning/pytorch/plugins/precision/precision.py", line 122, in optimizer_step
return optimizer.step(closure=closure, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/crow/repos/praxis/venv/lib/python3.12/site-packages/torch/optim/lr_scheduler.py", line 130, in wrapper
return func.__get__(opt, opt.__class__)(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/crow/repos/praxis/venv/lib/python3.12/site-packages/torch/optim/optimizer.py", line 484, in wrapper
out = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/crow/repos/praxis/venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/crow/repos/praxis/venv/lib/python3.12/site-packages/pytorch_optimizer/optimizer/grokfast.py", line 217, in step
grok_exp_avg = state['grok_exp_avg']
~~~~~^^^^^^^^^^^^^^^^
Expected behavior
I would expect EMA calculation to begin at step 512. I would not expect the optimizer to crash.
Describe the bug
If you set the `grokfast_after_step` argument of `GrokFastAdamW` to anything other than `None` or `0`, the optimizer will crash at exactly that step. For example, if you set `grokfast_after_step=512`, training will crash at step 511.
To Reproduce
Log
Expected behavior
I would expect EMA calculation to begin at step 512. I would not expect the optimizer to crash.