
The official implementation of MARS: Unleashing the Power of Variance Reduction for Training Large Models
https://github.com/AGI-Arena/MARS
Apache License 2.0

MARS: Unleashing the Power of Variance Reduction for Training Large Models

This repository contains the official code for the paper MARS: Unleashing the Power of Variance Reduction for Training Large Models.

Authors: Huizhuo Yuan*, Yifeng Liu*, Shuang Wu, Xun Zhou, Quanquan Gu

🔔 NEWS

About MARS

MARS (Make vAriance Reduction Shine) is a unified optimization framework designed to address the inherent challenges of training large models. Traditional adaptive gradient methods such as Adam and AdamW often suffer from high stochastic gradient variance, while variance reduction techniques have struggled to gain practical impact in deep learning. At its core, MARS comprises two major components: (1) a scaled stochastic recursive momentum, which provides a variance-reduced estimator of the full gradient for better gradient complexity; and (2) a preconditioned update, which approximates the second-order Newton method for better per-iteration complexity. By combining preconditioned gradient methods with variance reduction, MARS achieves the best of both worlds, accelerating the search for critical points in optimization.

The MARS framework is built on the following preconditioned variance-reduced updates:

$$ \mathbf{c}_t = \nabla f(\mathbf{x}_t, \mathbf{\xi}_t)+\underbrace{{\color{red}\gamma_t} \frac{\beta_{1}}{1-\beta_{1}} \left(\nabla f(\mathbf{x}_t, \mathbf{\xi}_t)-\nabla f(\mathbf{x}_{t-1}, \mathbf{\xi}_t)\right)}_{\text{scaled gradient correction}} $$

$$ \tilde{\mathbf{c}}_t = \text{Clip}(\mathbf{c}_t,1) = \begin{cases} \frac{\mathbf{c}_t}{\|\mathbf{c}_t\|_2} & \text{if } \|\mathbf{c}_t\|_2 > 1,\\ \mathbf{c}_t & \text{otherwise}. \end{cases} $$

$$ \mathbf{m}_t = \beta_1 \mathbf{m}_{t-1} + (1-\beta_{1})\tilde{\mathbf{c}}_t $$

$$ \mathbf{x}_{t+1} = \arg\min_{\mathbf{x} \in \mathbb{R}^d} \left\{\eta_t \left\langle \mathbf{m}_t, \mathbf{x} \right\rangle + \frac{1}{2} \|\mathbf{x} - \mathbf{x}_t \|_{\mathbf{H}_t}^2\right\} $$

Here ${\color{red}\gamma_t}$ is a scaling parameter that controls the strength of gradient correction.
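The four updates above can be traced on plain Python lists. The sketch below is a minimal illustration, not the repository's mars.py: the function names are hypothetical, and it assumes a diagonal, positive $\mathbf{H}_t$, so that the argmin has the closed form $\mathbf{x}_{t+1} = \mathbf{x}_t - \eta_t \mathbf{H}_t^{-1}\mathbf{m}_t$ (obtained by setting the gradient of the objective, $\eta_t \mathbf{m}_t + \mathbf{H}_t(\mathbf{x} - \mathbf{x}_t)$, to zero).

```python
import math


def scaled_correction(grad, prev_grad, gamma, beta1):
    """c_t = g_t + gamma * beta1 / (1 - beta1) * (g_t - g_prev)."""
    scale = gamma * beta1 / (1.0 - beta1)
    return [g + scale * (g - gp) for g, gp in zip(grad, prev_grad)]


def clip_unit_norm(c):
    """Clip(c, 1): rescale c to unit L2 norm only when its norm exceeds 1."""
    norm = math.sqrt(sum(v * v for v in c))
    if norm > 1.0:
        return [v / norm for v in c]
    return list(c)


def momentum_update(m_prev, c_tilde, beta1):
    """m_t = beta1 * m_{t-1} + (1 - beta1) * c~_t."""
    return [beta1 * m + (1.0 - beta1) * c for m, c in zip(m_prev, c_tilde)]


def preconditioned_step(x, m, h_diag, lr):
    """Closed-form argmin for a diagonal, positive H_t (assumption):
    x_{t+1} = x_t - lr * H_t^{-1} m_t, applied coordinate by coordinate."""
    return [xi - lr * mi / hi for xi, mi, hi in zip(x, m, h_diag)]
```

Setting gamma=0 removes the correction term, so c_t reduces to the plain stochastic gradient; the instantiations below differ only in the choice of H_t.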

Instantiations of MARS

Under the MARS framework, we provide three instantiations based on different Hessian matrix approximations: MARS-AdamW, MARS-Lion, and MARS-Shampoo. Please note that the hyperparameters in this framework are tuned on MARS-AdamW. When using other instantiations, it is essential to tune the hyperparameters—particularly the learning rates—for optimal performance.

MARS-AdamW

(Enable with mars_type="mars-adamw" in mars.py)

The Hessian matrix approximation is defined as:

$$ \mathbf{v}_t =\beta_2 \mathbf{v}_{t-1}+(1-\beta_2) \big(\nabla f(\mathbf{x}_t, \mathbf{\xi}_t)\big)^2 $$

$$ \mathbf{H}_t := \sqrt{\text{diag}\Big(\mathbf{v}_t\Big)}\cdot \frac{1 - \beta_1^t}{\sqrt{1 - \beta_2^t}}. $$
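Substituting this $\mathbf{H}_t$ into the closed-form step $\mathbf{x}_{t+1} = \mathbf{x}_t - \eta_t \mathbf{H}_t^{-1}\mathbf{m}_t$ gives the familiar Adam ratio $\big(\mathbf{m}_t/(1-\beta_1^t)\big) / \sqrt{\mathbf{v}_t/(1-\beta_2^t)}$. A minimal per-coordinate sketch on plain lists follows; the function name and the small eps guard are our assumptions (eps is not part of the formula), and the default betas follow the template in "Customized Training".

```python
import math


def mars_adamw_step(x, m, v, grad, c_tilde, t, lr, beta1=0.9, beta2=0.95, eps=1e-8):
    """One MARS-AdamW-style update on flat lists of floats (t starts at 1).

    m, v: running first and second moments.
    grad: raw stochastic gradient (v_t uses its square, as in the formula above).
    c_tilde: clipped, corrected gradient that feeds the momentum m_t.
    eps: assumed guard against division by zero; not part of the formula.
    """
    m = [beta1 * mi + (1 - beta1) * ci for mi, ci in zip(m, c_tilde)]
    v = [beta2 * vi + (1 - beta2) * gi * gi for vi, gi in zip(v, grad)]
    bc1 = 1 - beta1 ** t  # bias correction for m_t
    bc2 = 1 - beta2 ** t  # bias correction for v_t
    # H_t^{-1} m_t = (m_t / bc1) / sqrt(v_t / bc2)
    x = [xi - lr * (mi / bc1) / (math.sqrt(vi / bc2) + eps)
         for xi, mi, vi in zip(x, m, v)]
    return x, m, v
```

As with Adam, when c_tilde equals grad the very first step moves each coordinate by roughly lr, whatever the gradient's scale.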

MARS-Lion

(Enable with mars_type="mars-lion" in mars.py)

The Hessian matrix approximation is defined as:

$$ \mathbf{H}_t := \sqrt{\text{diag}(\mathbf{m}_t^2)}. $$
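With $\mathbf{H}_t = \text{diag}(|\mathbf{m}_t|)$, the step $\mathbf{H}_t^{-1}\mathbf{m}_t$ reduces to $\text{sign}(\mathbf{m}_t)$, matching the sign-based update used by Lion. A small sketch with a hypothetical function name; treating coordinates where $m_t = 0$ (where $\mathbf{H}_t$ is singular) as a zero update is our assumption.

```python
def sign(value):
    """Return 1, -1, or 0 (sign(0) = 0 by assumption)."""
    return (value > 0) - (value < 0)


def mars_lion_step(x, m, lr):
    """x_{t+1} = x_t - lr * H_t^{-1} m_t with H_t = diag(|m_t|), i.e. a sign update."""
    return [xi - lr * sign(mi) for xi, mi in zip(x, m)]
```

Every nonzero coordinate moves by exactly lr, so the learning rate directly sets the per-step change; this is one reason the learning rate needs retuning when switching from MARS-AdamW.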

MARS-Shampoo

(Enable with mars_type="mars-shampoo" in mars.py)

The preconditioner can be seen as an orthogonal mapping operator:

$$ \mathbf{U}_t, \mathbf{\Sigma}_t, \mathbf{V}_t = \text{SVD}(\mathbf{G}_t),\qquad \mathbf{x}_{t+1} =\mathbf{x}_t-\eta_t\mathbf{U}_t\mathbf{V}_t^\top. $$

In practice, we use the Newton-Schulz iteration to approximate the solution of the SVD problem efficiently.
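The idea behind the Newton-Schulz approximation can be illustrated with the classic cubic iteration X ← 1.5X - 0.5XXᵀX. After G_t is normalized by its Frobenius norm, every singular value lies in (0, 1], and each step maps a singular value s to 1.5s - 0.5s³, driving all nonzero singular values toward 1 while leaving U_t and V_t unchanged. This is a plain-Python illustration with hypothetical helper names, not the repository's implementation; practical implementations (e.g. in Muon) typically use a tuned higher-order polynomial with fewer steps.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a):
    return [list(row) for row in zip(*a)]


def newton_schulz_orthogonalize(g, steps=30):
    """Approximate U V^T (from the SVD of g) with the cubic Newton-Schulz iteration."""
    fro = sum(v * v for row in g for v in row) ** 0.5
    if fro == 0.0:
        return [row[:] for row in g]  # nothing to orthogonalize
    x = [[v / fro for v in row] for row in g]  # singular values now in (0, 1]
    for _ in range(steps):
        xxtx = matmul(matmul(x, transpose(x)), x)
        # X <- 1.5 X - 0.5 X X^T X
        x = [[1.5 * a - 0.5 * b for a, b in zip(ra, rb)] for ra, rb in zip(x, xxtx)]
    return x
```

For example, G = [[0, 2], [1, 0]] has polar factor U Vᵀ = [[0, 1], [1, 0]], which the iteration recovers; the resulting step is then x_{t+1} = x_t - η_t · (that matrix).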

Performance of MARS Compared to Baselines

Experimental results for MARS are based on the MARS-AdamW instantiation, unless otherwise stated. In our experiments, gradients are calculated once per sample and per update (MARS-approx in our paper). Performing exact gradient computation with two evaluations per update, as in the exact form of MARS, can slightly enhance performance but at the cost of doubling the computational expense. For more details, refer to our paper.
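The two variants differ only in the reference gradient used in the correction term. The sketch below reflects our reading of the paper and uses a hypothetical grad_fn(x, batch) as a stand-in for a forward/backward pass: MARS-exact re-evaluates ∇f(x_{t-1}, ξ_t) on the current batch, while MARS-approx reuses the gradient saved from the previous step, ∇f(x_{t-1}, ξ_{t-1}).

```python
def correction_inputs(grad_fn, x_t, x_prev, batch_t, last_grad, exact):
    """Return (g_t, g_ref) for the correction term g_t - g_ref.

    exact=True : g_ref = grad f(x_{t-1}, xi_t), a second evaluation on the same batch.
    exact=False: g_ref = last_grad, i.e. grad f(x_{t-1}, xi_{t-1}) saved from the
                 previous step, so only one evaluation per update is needed.
    """
    g_t = grad_fn(x_t, batch_t)
    g_ref = grad_fn(x_prev, batch_t) if exact else last_grad
    return g_t, g_ref


calls = []


def toy_grad(x, batch):
    """Hypothetical objective 0.5 * (x - batch)**2; records each evaluation."""
    calls.append((x, batch))
    return x - batch
```

Counting the entries in calls after one update shows the trade-off described above: two gradient evaluations for the exact form, one for the approximation.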

MARS consistently outperforms AdamW and the Muon optimizer across GPT-2 models:

[Figure panels: GPT-2 small, GPT-2 medium, GPT-2 large]
| Best Val Loss | GPT-2 Small (5B tokens) | GPT-2 Medium (5B tokens) | GPT-2 Large (5B tokens) | GPT-2 Small (20B tokens) | GPT-2 Medium (20B tokens) | GPT-2 Large (20B tokens) | GPT-2 Small (50B tokens) | GPT-2 Medium (50B tokens) | GPT-2 Large (50B tokens) |
|---|---|---|---|---|---|---|---|---|---|
| AdamW | 3.193 | 3.084 | 3.013 | 3.024 | 2.821 | 2.741 | 2.885 | 2.691 | 2.561 |
| Muon | 3.165 | 3.009 | 2.915 | 3.006 | 2.813 | 2.691 | 2.901 | 2.688 | 2.573 |
| MARS-exact | 3.107 | TBD | TBD | 2.980 | TBD | TBD | 2.847 | TBD | TBD |
| MARS-approx | 3.108 | 2.969 | 2.876 | 2.981 | 2.763 | 2.647 | 2.849 | 2.636 | 2.518 |

Efficiency of MARS

The MARS algorithm can achieve better performance not only within the same number of training steps, but also within the same training time:

[Figure panels: GPT-2 small, GPT-2 medium, GPT-2 large]

Training GPT-2 from Scratch

Install Dependencies

$ pip install torch==2.1.2 transformers==4.33.0 datasets tiktoken numpy==1.26.4 wandb

Data Preparation

Prepare the OpenWebText data following nanoGPT:

$ python data/openwebtext/prepare.py

Start Training

To train a model using the MARS optimizer, run the following command:

$ torchrun --standalone --nproc_per_node=8 MARS/train_mars.py config/${your_config_file}

This command initiates the training of a GPT-2 model on the OpenWebText dataset using the MARS optimizer. All relevant hyperparameters—training, model, and optimizer—are specified in the configuration file (${your_config_file}). These parameters can be adjusted directly in the configuration file or through the bash script.
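Assuming train_mars.py keeps nanoGPT's configurator.py convention (the config file is executed as Python first, then each --key=value argument is parsed with ast.literal_eval and must name an existing key of the same type), flags such as --batch_size=15 in the reproduction commands override the config file's values roughly as in this simplified, hypothetical helper:

```python
import ast


def apply_overrides(config, args):
    """Apply nanoGPT-style '--key=value' overrides to a dict of defaults.

    Values are parsed with ast.literal_eval when possible ('15' -> 15,
    '3e-3' -> 0.003) and otherwise kept as strings. Unknown keys and type
    changes are rejected, mirroring the checks in nanoGPT's configurator.
    """
    config = dict(config)  # leave the caller's defaults untouched
    for arg in args:
        if not arg.startswith("--") or "=" not in arg:
            raise ValueError(f"expected --key=value, got {arg!r}")
        key, raw = arg[2:].split("=", 1)
        if key not in config:
            raise KeyError(f"unknown config key: {key}")
        try:
            value = ast.literal_eval(raw)
        except (ValueError, SyntaxError):
            value = raw  # fall back to the raw string
        if type(value) is not type(config[key]):
            raise TypeError(f"type mismatch for {key}: got {type(value).__name__}")
        config[key] = value
    return config
```

Because types must match, an integer setting cannot be overridden with a float (for example --batch_size=1.5 is rejected).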

Hyperparameter Details

Model Hyperparameters:

Optimizer Hyperparameters:

Training Hyperparameters:

For more detailed hyperparameter examples, refer to:


Reproducing Our Results

Reproducing GPT-2 Small (125M) Results

Train with MARS using:

$ bash scripts/run_mars_small.sh

or

$ torchrun --standalone --nproc_per_node=8 \
      MARS/train_mars.py \
      config/train_gpt2_small_mars.py \
      --batch_size=15 \
      --gradient_accumulation_steps=4

Reproducing GPT-2 Medium (355M) Results

Train with MARS using:

$ bash scripts/run_mars_medium.sh

or

$ torchrun --standalone --nproc_per_node=8 \
      MARS/train_mars.py \
      config/train_gpt2_medium_mars.py \
      --batch_size=15 \
      --gradient_accumulation_steps=4

Reproducing GPT-2 Large (770M) Results

Train with MARS using:

$ bash scripts/run_mars_large.sh

or

$ torchrun --standalone --nproc_per_node=8 \
      MARS/train_mars.py \
      config/train_gpt2_large_mars.py \
      --batch_size=5 \
      --gradient_accumulation_steps=12

Reproducing Baseline Results

To reproduce the AdamW baseline:

$ bash scripts/run_adamw_small.sh    # or run_adamw_medium.sh / run_adamw_large.sh

To reproduce the Muon baseline following modded-nanogpt:

$ bash scripts/run_muon_small.sh    # or run_muon_medium.sh / run_muon_large.sh

If you use a different hardware setup, adjust nproc_per_node, batch_size, and gradient_accumulation_steps accordingly, making sure their product equals 480 (for example, 8 × 15 × 4 in the small and medium commands and 8 × 5 × 12 in the large command).
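The constraint is simple arithmetic: with nproc_per_node GPUs, each processing batch_size sequences per micro-step, gradient_accumulation_steps = 480 / (nproc_per_node × batch_size). A small helper (hypothetical, not part of the repository) that also checks divisibility:

```python
TOTAL_BATCH = 480  # sequences per optimizer step: nproc_per_node * batch_size * accumulation


def grad_accum_steps(nproc_per_node, batch_size, total=TOTAL_BATCH):
    """Return gradient_accumulation_steps so the three factors multiply to total."""
    per_micro_step = nproc_per_node * batch_size
    if total % per_micro_step:
        raise ValueError(f"{nproc_per_node} x {batch_size} does not divide {total}")
    return total // per_micro_step
```

For example, 4 GPUs with batch_size=15 need 8 accumulation steps. Assuming a block_size of 1024 tokens (nanoGPT's GPT-2 default), 480 sequences correspond to 491,520 tokens per optimizer step.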

Hyperparameters for GPT-2 models

| Model Name | Model Size | lr for AdamW | lr for Muon | lr for MARS | lr_1d for MARS | wd for AdamW | wd for Muon | wd for MARS |
|---|---|---|---|---|---|---|---|---|
| GPT-2 small | 125M | 6e-4 | 2e-2 | 6e-3 | 3e-3 | 1e-1 | 0.0 | 1e-2 |
| GPT-2 medium | 355M | 3e-4 | 1e-2 | 3e-3 | 1.5e-3 | 1e-1 | 0.0 | 1e-2 |
| GPT-2 large | 770M | 2e-4 | 6.67e-3 | 2e-3 | 1e-3 | 1e-1 | 0.0 | 1e-2 |
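The table lists a separate lr_1d for MARS, equal to half of lr in every row. Our reading, which should be confirmed against mars.py, is that lr_1d applies to one-dimensional parameters (biases, LayerNorm gains) while lr applies to weight matrices. A hypothetical sketch of that split over parameter shapes:

```python
def split_lr_by_ndim(param_shapes, lr, lr_1d):
    """Group parameters by dimensionality: lr for 2-D+ tensors, lr_1d for 1-D ones.

    param_shapes maps a (hypothetical) parameter name to its shape tuple.
    """
    groups = {"matrix": {"lr": lr, "params": []}, "vector": {"lr": lr_1d, "params": []}}
    for name, shape in param_shapes.items():
        key = "matrix" if len(shape) >= 2 else "vector"
        groups[key]["params"].append(name)
    return groups


# Example with the GPT-2 small row (lr=6e-3, lr_1d=3e-3); names are illustrative.
gpt2_small = split_lr_by_ndim(
    {"wte.weight": (50304, 768), "ln_1.weight": (768,), "attn.c_attn.bias": (2304,)},
    lr=6e-3, lr_1d=3e-3,
)
```

In gpt2_small, the embedding matrix lands in the lr=6e-3 group and the LayerNorm weight and bias vector land in the lr_1d=3e-3 group.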

Note that different hyperparameters may benefit different stages of training. For the GPT-2 Small model, our MARS optimizer is tuned to prioritize the best final validation performance. If faster progress in the earlier stages of training is desired, using wd=1e-3 may provide better results.

Customized Training

To build your own training pipeline on other architectures and datasets, use the following template as an example:

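import torch
from mars import MARS  # mars.py from this repository

# init model, loss function, and input data (placeholders: replace with your own)
model = Model()      # an nn.Module whose forward(X, Y) returns (logits, loss)
data_loader = ...    # an iterable yielding (X, Y) batches
block_size = 1024    # example sequence length; set to your model's context size
epochs = 1           # example number of passes over the data

# init the optimizer
optimizer = MARS(model.parameters(), lr=1e-3, betas=(0.9, 0.95), gamma=0.025)

total_bs = len(data_loader)
bs = total_bs * block_size
iter_num = -1

# training loop
for epoch in range(epochs):
    for X, Y in data_loader:
        # standard training code
        logits, loss = model(X, Y)
        loss.backward()
        optimizer.step(bs=bs)
        optimizer.zero_grad(set_to_none=True)
        # MARS-specific bookkeeping for the gradient-correction term (see mars.py)
        optimizer.update_last_grad()
        iter_num += 1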

Citation

If you find this repo useful for your research, please consider citing the paper:

@article{yuan2024mars,
  title={MARS: Unleashing the Power of Variance Reduction for Training Large Models},
  author={Yuan, Huizhuo and Liu, Yifeng and Wu, Shuang and Zhou, Xun and Gu, Quanquan},
  journal={arXiv preprint arXiv:2411.10438},
  year={2024}
}

Acknowledgements

This repo is built upon nanoGPT, levanter, and Sophia. We thank the authors for their great work!