by Eric Hartford
Figure: quick divergence in loss when training gemma-2-2b with adamw-fused vs. GrokAdamW.
GrokAdamW is a novel optimizer designed to enhance AI training by combining the strengths of Grokfast (a technique for accelerating "grokking" in deep learning models) with the robustness and efficiency of the AdamW optimizer. It's particularly useful for models exhibiting delayed generalization, where performance on validation data improves significantly after a period of overfitting to the training data.
Update: This optimizer was used to train the awesome tiny model nisten/Biggie-SmoLlm-0.15B-Base
This implementation was inspired by the following papers:
Grokfast: Accelerated Grokking by Amplifying Slow Gradients. Lee, J., Kang, B. G., Kim, K., & Lee, K. M. (2024). arXiv:2405.20233 [cs.LG]. https://doi.org/10.48550/arXiv.2405.20233
Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets. Power, A., Burda, Y., Edwards, H., Babuschkin, I., & Misra, V. (2022). arXiv:2201.02177 [cs.LG]. https://doi.org/10.48550/arXiv.2201.02177
Decoupled Weight Decay Regularization. Loshchilov, I., & Hutter, F. (2019). arXiv:1711.05101 [cs.LG]. https://doi.org/10.48550/arXiv.1711.05101
Grokking is a phenomenon where deep learning models achieve sudden generalization after a long period of overfitting. Research suggests that this delayed generalization is related to the slow-varying components of gradients during training. Grokfast, inspired by this research, accelerates grokking by amplifying these slow-varying gradients.
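To make "amplifying slow-varying gradients" concrete, here is a minimal sketch of an EMA gradient filter in the spirit of the Grokfast paper. This is an illustration, not GrokAdamW's internal code; the function name and the ema_grads buffer are ours. It would be called after loss.backward() and before optimizer.step():

import torch

def grokfast_ema_filter(model, ema_grads, alpha=0.98, lamb=2.0):
    # Amplify the slow-varying (low-frequency) component of each gradient.
    # ema_grads maps parameter names to gradient EMAs; pass {} on the first call.
    for name, param in model.named_parameters():
        if param.grad is None:
            continue
        if name not in ema_grads:
            ema_grads[name] = torch.zeros_like(param.grad)
        # Slow component: exponential moving average of past gradients.
        ema_grads[name].mul_(alpha).add_(param.grad, alpha=1 - alpha)
        # Add the amplified slow component back onto the raw gradient in place.
        param.grad.add_(ema_grads[name], alpha=lamb)
    return ema_grads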
GrokAdamW builds upon this concept by integrating Grokfast's adaptive frequency amplification into the AdamW optimization algorithm. It introduces several key innovations:
Core AdamW Updates: For each layer l, parameter p, and training step t, GrokAdamW applies the standard AdamW update:

m_t = beta1 * m_(t-1) + (1 - beta1) * g_t
v_t = beta2 * v_(t-1) + (1 - beta2) * g_t^2
m_hat = m_t / (1 - beta1^t),  v_hat = v_t / (1 - beta2^t)
theta_t = theta_(t-1) - lr * (m_hat / (sqrt(v_hat) + eps) + weight_decay * theta_(t-1))

Grokfast Integration: Each layer maintains an exponential moving average (EMA) of its gradients, which captures their slow-varying component. The EMA momentum starts at alpha_init, decays layer-wise at rate gamma, and is adjusted each step according to the scalar signal returned by the grokking_signal_fns (scaled by grokking_signal_decay_rate). The filtered slow gradient, amplified by lamb, is added back to the raw gradient before the AdamW update (a minimal sketch of this combined step follows below).

Optional Gradient Clipping: If gradient_clipping > 0, gradients are clipped before the update with torch.nn.utils.clip_grad_norm_(parameters, gradient_clipping).
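To make the combination concrete, here is a hedged, self-contained sketch of one update for a single parameter tensor. It illustrates the technique under the assumptions above; it is not the package's actual implementation, and the fixed alpha stands in for the adaptive, layer-wise alpha described in the Grokfast Integration step:

import torch

def grokadamw_like_step(param, grad, state, *, lr=1e-3, betas=(0.9, 0.999),
                        eps=1e-8, weight_decay=1e-2, alpha=0.98, lamb=2.0):
    # Illustrative single-tensor step: Grokfast EMA filter followed by AdamW.
    step = state["step"] = state.get("step", 0) + 1
    beta1, beta2 = betas

    # Grokfast: EMA of the gradient tracks its slow-varying component.
    grok_ema = state.setdefault("grok_ema", torch.zeros_like(grad))
    grok_ema.mul_(alpha).add_(grad, alpha=1 - alpha)
    grad = grad + lamb * grok_ema  # amplify the slow component

    # Standard AdamW first and second moments with bias correction.
    m = state.setdefault("m", torch.zeros_like(grad))
    v = state.setdefault("v", torch.zeros_like(grad))
    m.mul_(beta1).add_(grad, alpha=1 - beta1)
    v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    m_hat = m / (1 - beta1 ** step)
    v_hat = v / (1 - beta2 ** step)

    # Decoupled weight decay, then the Adam update.
    param.mul_(1 - lr * weight_decay)
    param.addcdiv_(m_hat, v_hat.sqrt().add_(eps), value=-lr)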
You can easily install GrokAdamW using pip:
pip install grokadamw
import torch
import torch.nn as nn
from grokadamw import GrokAdamW
# Define your model
model = nn.Linear(10, 1)
# Define your grokking signal function(s)
def grokking_signal_fn(training_loss: float, validation_loss: float) -> float:
    if training_loss == 0:
        return 0.0  # Avoid division by zero
    return (validation_loss - training_loss) / training_loss

# Initialize GrokAdamW optimizer
optimizer = GrokAdamW(model.parameters(), lr=1e-3, grokking_signal_fns=[grokking_signal_fn])

# Training loop
for epoch in range(num_epochs):
    # ... [Your training code] ...
    # Calculate validation loss (val_loss)

    # Perform optimization step
    loss = optimizer.step(closure=lambda: your_loss_function(model, data))
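Because grokking_signal_fns (documented below) accepts a list, you can combine several indicators of grokking. A hedged sketch, reusing the signature from the example above; both function names are illustrative, and how multiple signals are aggregated is left to the optimizer's implementation:

# Two complementary signals, both following the example signature above.
def relative_gap_signal(training_loss: float, validation_loss: float) -> float:
    if training_loss == 0:
        return 0.0
    return (validation_loss - training_loss) / training_loss

def absolute_gap_signal(training_loss: float, validation_loss: float) -> float:
    return max(0.0, validation_loss - training_loss)

optimizer = GrokAdamW(
    model.parameters(),
    lr=1e-3,
    grokking_signal_fns=[relative_gap_signal, absolute_gap_signal],
)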
GrokAdamW supports the standard AdamW parameters (lr, betas, eps, weight_decay) and additional parameters for Grokfast:

alpha_init: Initial momentum for the EMA filter (default: 0.98)
lamb: Amplification factor for the filtered gradients (default: 2.0)
gamma: Layer-wise momentum decay rate (default: 0.1)
grokking_signal_fns: A list of functions that each return a scalar grokking signal (optional)
grokking_signal_decay_rate: Decay rate for adjusting alpha based on the grokking signal (default: 0.1)
gradient_clipping: Maximum norm for gradient clipping (default: 1.0, set to 0 to disable)
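For reference, a configuration sketch that spells out every Grokfast-specific knob next to its documented default. The AdamW values shown (betas, eps, weight_decay) are typical choices rather than defaults claimed here, and model / grokking_signal_fn come from the example above:

optimizer = GrokAdamW(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=1e-2,
    alpha_init=0.98,                  # initial EMA momentum
    lamb=2.0,                         # amplification of the filtered gradient
    gamma=0.1,                        # layer-wise momentum decay
    grokking_signal_fns=[grokking_signal_fn],
    grokking_signal_decay_rate=0.1,   # how quickly alpha adapts to the signal
    gradient_clipping=1.0,            # set to 0 to disable clipping
)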
Troubleshooting:

Grokking Signal Functions Not Providing Useful Signals: Make sure each function returns a meaningful scalar, for example the relative gap between validation and training loss shown above.
Issues with Gradient Clipping: Adjust gradient_clipping to a norm appropriate for your model, or set it to 0 to disable clipping entirely.
Unexpected Behavior with Layer-wise Momentum Decay: Adjust gamma or individual layer hyperparameters accordingly.
Monitoring Grokking Signal and Alpha Values: Log these values during training to verify that Grokfast is actually engaging.
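One hedged way to follow the last point without relying on optimizer internals is to log the value your own signal function produces each epoch (num_epochs, train_loss, and val_loss below are placeholders for your training loop):

# Hypothetical monitoring loop: track the signal your own function produces
# so you can plot it against the loss curves later.
num_epochs = 3
signal_history = []

for epoch in range(num_epochs):
    train_loss, val_loss = 0.5, 0.7   # replace with real losses from training/eval
    signal = grokking_signal_fn(train_loss, val_loss)
    signal_history.append(signal)
    print(f"epoch {epoch}: train={train_loss:.4f} val={val_loss:.4f} signal={signal:.4f}")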
GrokAdamW is an ongoing research project. Your feedback and contributions are welcome! Please feel free to submit issues, feature requests, or pull requests. For more details, see our CONTRIBUTING.md file.
GrokAdamW is licensed under the Apache 2.0 License. See the LICENSE file for more details.