
[WIP] Fused Adam Triton Kernels #29

jeromeku commented 3 months ago

Fused GaLore Adam (WIP)

Various fused implementations of the Adam update step for GaLore (Gradient Low-Rank Projection).

This is an initial attempt at optimizing the update step of the GaLore Adam optimizer.

Overview

The GaLore Adam optimizer introduces additional ops on top of the traditional Adam update step.

Specifically:

  1. grad is projected to low rank --> additional matmul
  2. adam states are updated with grad elementwise (same as Adam except in low-rank)
  3. normalized grad is projected to full rank --> additional matmul
  4. params are updated with the normalized full rank grad
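
As a point of reference, steps 1-4 in plain (non-fused) PyTorch look roughly like the sketch below. Tensor names, the projection convention, and the omission of bias correction are illustrative simplifications, not the exact reference implementation:

```python
import torch

def galore_adam_step(param, grad, proj, exp_avg, exp_avg_sq,
                     lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    # 1. Project grad to low rank: [rank, m] @ [m, n] -> [rank, n]
    low_rank_grad = proj @ grad

    # 2. Update Adam states elementwise in the low-rank space
    exp_avg.mul_(beta1).add_(low_rank_grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(low_rank_grad, low_rank_grad, value=1 - beta2)
    normalized_grad = exp_avg / (exp_avg_sq.sqrt() + eps)

    # 3. Project normalized grad back to full rank: [m, rank] @ [rank, n] -> [m, n]
    full_rank_update = proj.t() @ normalized_grad

    # 4. Update params with the normalized full-rank grad
    param.add_(full_rank_update, alpha=-lr)
```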

Implementation

Various fusions were attempted across two kernel implementations, hybrid and fused, which correspond to the hybrid and fused columns in the benchmarks below.
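
For orientation, below is a minimal Triton sketch of the elementwise Adam portion (Step 2 above), which is the natural candidate for fusion with the surrounding matmuls. It is an illustrative sketch, not the actual kernel in this PR, and omits bias correction, autotuning, and the matmul fusion itself:

```python
import triton
import triton.language as tl

@triton.jit
def adam_elementwise_kernel(grad_ptr, exp_avg_ptr, exp_avg_sq_ptr, out_ptr,
                            n_elements, beta1, beta2, eps,
                            BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    g = tl.load(grad_ptr + offsets, mask=mask)
    m = tl.load(exp_avg_ptr + offsets, mask=mask)
    v = tl.load(exp_avg_sq_ptr + offsets, mask=mask)

    # Running mean / variance of the (low-rank) grad
    m = beta1 * m + (1.0 - beta1) * g
    v = beta2 * v + (1.0 - beta2) * g * g

    # Normalized grad, to be projected back to full rank
    norm_g = m / (tl.sqrt(v) + eps)

    tl.store(exp_avg_ptr + offsets, m, mask=mask)
    tl.store(exp_avg_sq_ptr + offsets, v, mask=mask)
    tl.store(out_ptr + offsets, norm_g, mask=mask)

# Launched over a 1D grid, e.g. grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
```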

Performance

Below are benchmarks for various kernels:

The config for each benchmark is the grad (param) shape, the dtype of the grad and Adam states, and allow_tf32, which controls whether torch and triton matmuls are allowed to use TF32 tensor cores (see Discussion).

Grad shape: 4096x4096, dtype: torch.float32, allow_tf32: False

Median times (ms):
    rank     torch    hybrid     fused  compiled
0   32.0  0.560128  0.347136  0.505856  0.534528
1   64.0  0.627712  0.404480  0.600960  0.615424
2  128.0  0.825232  0.583168  0.985072  0.833536
3  256.0  1.378304  1.126400  1.489920  1.375232
4  512.0  2.286080  2.101760  2.969600  2.302976

Grad shape: 4096x4096, dtype: torch.float32, allow_tf32: True

Median times (ms):
    rank     torch    hybrid     fused  compiled
0   32.0  0.540672  0.321536  0.316416  0.508928
1   64.0  0.612240  0.337728  0.345024  0.538624
2  128.0  0.640000  0.395264  0.393216  0.693248
3  256.0  0.777216  0.489472  0.548784  1.102848
4  512.0  1.216512  0.864256  0.960512  1.968128

Grad shape: 4096x11008, dtype: torch.float32, allow_tf32: False

Median times (ms):
    rank     torch    hybrid     fused  compiled
0   32.0  1.538672  0.915456  0.835584  1.364032
1   64.0  1.546240  0.940032  1.022976  1.486848
2  128.0  2.116608  1.498112  1.613312  2.098176
3  256.0  3.423744  2.719744  2.881536  3.227136
4  512.0  5.499904  5.036544  5.450752  5.508096

Grad shape: 4096x11008, dtype: torch.float32, allow_tf32: True

Median times (ms):
    rank     torch    hybrid     fused  compiled
0   32.0  1.413120  0.871424  0.817152  1.353184
1   64.0  1.489920  0.916480  0.854016  1.389568
2  128.0  1.679360  0.996352  1.005568  1.563648
3  256.0  2.152448  1.415168  1.470464  2.185216
4  512.0  3.210240  2.460672  2.580480  3.477504

Accuracy

Comparison to reference torch implementation:

Running with 4096 x 4096 grad (param) shape, GaLore orthogonal matrix [128, 4096], dtype torch.float32, and allow_tf32 True
Kernel: hybrid
Accuracy:
-> adam state - running grad mean:
  Max err: 0.000000 Relative err: 0.000001
-> adam state - running grad var:
  Max err: 0.000002 Relative err: 0.000002
-> params (after update):
  Max err: 0.000000 Relative err: 0.000001
Running with 4096 x 4096 grad (param) shape, GaLore orthogonal matrix [128, 4096], dtype torch.float32 and allow_tf32 False
Kernel: hybrid
Accuracy:
-> adam state - running grad mean:
  Max err: 0.000000 Relative err: 0.000000
-> adam state - running grad var:
  Max err: 0.000002 Relative err: 0.000002
-> params (after update):
  Max err: 0.000000 Relative err: 0.000000
Running with 4096 x 4096 grad (param) shape, GaLore orthogonal matrix [128, 4096], dtype torch.float32 and allow_tf32 True
Kernel: fused
Accuracy:
-> adam state - running grad mean:
  Max err: 0.000845 Relative err: 0.001152
-> adam state - running grad var:
  Max err: 0.000162 Relative err: 0.000161
-> params (after update):
  Max err: 0.000000 Relative err: 0.000001
Running with 4096 x 4096 grad (param) shape, GaLore orthogonal matrix [128, 4096], dtype torch.float32 and allow_tf32 False
Kernel: fused
Accuracy:
-> adam state - running grad mean:
  Max err: 0.000003 Relative err: 0.000004
-> adam state - running grad var:
  Max err: 0.000002 Relative err: 0.000002
-> params (after update):
  Max err: 0.000000 Relative err: 0.000000
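
For reference, the max and relative errors above can be computed with a small helper along these lines (hypothetical; the exact reduction used for the relative error in the test script may differ):

```python
import torch

def max_and_relative_err(ref: torch.Tensor, test: torch.Tensor, eps: float = 1e-12):
    # Hypothetical helper mirroring the "Max err / Relative err" readout above
    abs_err = (ref - test).abs()
    max_err = abs_err.max().item()
    rel_err = (abs_err / (ref.abs() + eps)).mean().item()
    return max_err, rel_err
```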

Discussion

Down Projection GEMM Shape

The motivation for the hybrid approach is the unconventional matrix shape of the down projection (Step 1): the GaLore projection matrix has shape [rank, m] (e.g., [128, 4096] in the accuracy runs above), so the GEMM is short and wide rather than roughly square.
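
Concretely, with the shapes used in the benchmarks above (sizes are illustrative):

```python
import torch

m, n, rank = 4096, 4096, 128
grad = torch.randn(m, n, device="cuda")
proj = torch.randn(rank, m, device="cuda")  # GaLore orthogonal matrix, shape [rank, m]

# Down projection (Step 1): [rank, m] @ [m, n] -> [rank, n], a short-and-wide GEMM
# rather than the square-ish shapes GEMM kernels are typically tuned for.
low_rank_grad = proj @ grad
```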

Effect of TF32 tensor cores

allow_tf32 has a significant impact on the relative performance of triton vs torch matmuls; compare the allow_tf32: False and allow_tf32: True benchmarks above.

See this torch note for more details on this feature.
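
On the torch side, this corresponds to the standard PyTorch switch:

```python
import torch

# Allow (or disallow) TF32 tensor cores for float32 matmuls in PyTorch
torch.backends.cuda.matmul.allow_tf32 = True
```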

Note: This might be less of a concern given this incoming triton PR, which implements a fast TF32 trick that improves both performance and accuracy.

Repro

tests/test_fused_kernels.py is a CLI with two modes: one for testing kernel accuracy and one for benchmarking across a number of configs.

Examples

Accuracy

Benchmark

Additional options

  python tests/test_fused_kernels.py --help

Note: Passing the additional flag --verbose will show triton autotuning logs -- I customized the triton autotuner to print out configs and other details.

Test Env

Next Steps