Fused GaLore Adam (WIP)

Various fused implementations of the Adam update step per Gradient Low-Rank Projection (GaLore).

This is an initial attempt at optimizing the update step of the GaLore Adam optimizer.
Overview
The GaLore Adam optimizer introduces additional ops to the traditional Adam update step.
Specifically:
1. grad is projected to low rank --> additional matmul
2. adam states are updated with the low-rank grad elementwise (same as Adam, except in low rank)
3. the normalized grad is projected back to full rank --> additional matmul
4. params are updated with the normalized full-rank grad
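For reference, a minimal unfused sketch of these four steps in plain PyTorch might look like the following. This is an illustration rather than the repo's reference implementation: the function name galore_adam_step, the (rank x N) orientation of the projection matrix, and the lr/beta defaults are assumptions, and bias correction and GaLore's per-projection scale factor are omitted.

```python
import torch

def galore_adam_step(grad, proj, exp_avg, exp_avg_sq, params,
                     lr=1e-2, beta1=0.9, beta2=0.999, eps=1e-8):
    # grad, params: (M, N); proj: (rank, N) GaLore orthogonal matrix;
    # exp_avg, exp_avg_sq: (M, rank) low-rank Adam states (M >= N case assumed).

    # Step 1: down-project grad to low rank --> additional matmul
    low_rank_grad = grad @ proj.T                     # (M, N) x (N, rank) -> (M, rank)

    # Step 2: elementwise Adam state update, carried out in low-rank space
    exp_avg.mul_(beta1).add_(low_rank_grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(low_rank_grad, low_rank_grad, value=1 - beta2)
    norm_grad = exp_avg / (exp_avg_sq.sqrt() + eps)

    # Step 3: up-project the normalized grad back to full rank --> additional matmul
    full_rank_grad = norm_grad @ proj                 # (M, rank) x (rank, N) -> (M, N)

    # Step 4: parameter update with the normalized full-rank grad
    params.add_(full_rank_grad, alpha=-lr)
    return params
```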
Implementation
Various fusions were attempted across 2 kernel implementations:
Fused
Steps 1 & 2 are fused: the Adam state updates are loaded and updated in place during the first matmul
Steps 3 & 4 are fused: the param update is folded as an epilogue into the second matmul
Hybrid
Step 1 is performed with a standard torch matmul (i.e., cuBLAS)
Step 2 is fused as an elementwise kernel (see the sketch below)
Steps 3 & 4 are handled as in Fused
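To illustrate the hybrid variant's Step 2, a minimal Triton elementwise kernel fusing the Adam state updates could look like this. It is only a sketch: the names low_rank_adam_update_kernel and hybrid_step2 are hypothetical, contiguous float32 tensors are assumed, and the repo's actual kernels (with autotuning, dtype handling, etc.) differ in detail.

```python
import torch
import triton
import triton.language as tl

@triton.jit
def low_rank_adam_update_kernel(
    grad_ptr, exp_avg_ptr, exp_avg_sq_ptr, norm_grad_ptr,
    n_elements, beta1, beta2, eps,
    BLOCK_SIZE: tl.constexpr,
):
    # Each program handles BLOCK_SIZE contiguous elements of the low-rank grad.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    g = tl.load(grad_ptr + offsets, mask=mask)
    m = tl.load(exp_avg_ptr + offsets, mask=mask)
    v = tl.load(exp_avg_sq_ptr + offsets, mask=mask)

    # Adam moment updates (bias correction omitted), done in low-rank space.
    m = beta1 * m + (1.0 - beta1) * g
    v = beta2 * v + (1.0 - beta2) * g * g
    norm_g = m / (tl.sqrt(v) + eps)

    # States are updated in place; the normalized grad feeds the up projection (Step 3).
    tl.store(exp_avg_ptr + offsets, m, mask=mask)
    tl.store(exp_avg_sq_ptr + offsets, v, mask=mask)
    tl.store(norm_grad_ptr + offsets, norm_g, mask=mask)


def hybrid_step2(low_rank_grad, exp_avg, exp_avg_sq, beta1=0.9, beta2=0.999, eps=1e-8):
    norm_grad = torch.empty_like(low_rank_grad)
    n = low_rank_grad.numel()
    grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),)
    low_rank_adam_update_kernel[grid](
        low_rank_grad, exp_avg, exp_avg_sq, norm_grad,
        n, beta1, beta2, eps, BLOCK_SIZE=1024,
    )
    return norm_grad
```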
Performance
Below are benchmarks for various kernels:
torch - reference torch implementation where each of the steps above is implemented verbatim
hybrid - see above
fused - see above
compiled - the torch reference implementation compiled using torch.compile with fullgraph=True and mode="max-autotune"
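For illustration, the compiled variant can be constructed and timed roughly as follows. This is a sketch that reuses the galore_adam_step function from the Overview sketch above and triton.testing.do_bench; the actual harness in tests/test_fused_kernels.py may differ.

```python
import torch
import triton.testing

M, N, rank = 4096, 4096, 128
grad = torch.randn(M, N, device="cuda")
params = torch.randn(M, N, device="cuda")
proj = torch.randn(rank, N, device="cuda")          # stand-in for the GaLore orthogonal matrix
exp_avg = torch.zeros(M, rank, device="cuda")
exp_avg_sq = torch.zeros(M, rank, device="cuda")

# `galore_adam_step` is the reference sketch from the Overview above.
compiled_step = torch.compile(galore_adam_step, fullgraph=True, mode="max-autotune")

ms = triton.testing.do_bench(lambda: compiled_step(grad, proj, exp_avg, exp_avg_sq, params))
print(f"compiled: {ms:.3f} ms")
```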
Configs for each benchmark are the grad (param) shape, the dtype of the grad and Adam states, and allow_tf32, i.e., whether torch and triton matmuls are allowed to use TF32 tensor cores (see Discussion). The benchmarked grad shapes are 4096x4096 and 4096x11008, with dtype torch.float32 and allow_tf32 both False and True.
Accuracy
Comparison to the reference torch implementation:
Running with 4096 x 4096 grad (param) shape, GaLore orthogonal matrix [128, 4096], dtype torch.float32, and allow_tf32 True
Kernel: hybrid
Accuracy:
-> adam state - running grad mean:
Max err: 0.000000 Relative err: 0.000001
-> adam state - running grad var:
Max err: 0.000002 Relative err: 0.000002
-> params (after update):
Max err: 0.000000 Relative err: 0.000001
Running with 4096 x 4096 grad (param) shape, GaLore orthogonal matrix [128, 4096], dtype torch.float32 and allow_tf32 False
Kernel: hybrid
Accuracy:
-> adam state - running grad mean:
Max err: 0.000000 Relative err: 0.000000
-> adam state - running grad var:
Max err: 0.000002 Relative err: 0.000002
-> params (after update):
Max err: 0.000000 Relative err: 0.000000
Running with 4096 x 4096 grad (param) shape, GaLore orthogonal matrix [128, 4096], dtype torch.float32 and allow_tf32 True
Kernel: fused
Accuracy:
-> adam state - running grad mean:
Max err: 0.000845 Relative err: 0.001152
-> adam state - running grad var:
Max err: 0.000162 Relative err: 0.000161
-> params (after update):
Max err: 0.000000 Relative err: 0.000001
Running with 4096 x 4096 grad (param) shape, GaLore orthogonal matrix [128, 4096], dtype torch.float32 and allow_tf32 False
Kernel: fused
Accuracy:
-> adam state - running grad mean:
Max err: 0.000003 Relative err: 0.000004
-> adam state - running grad var:
Max err: 0.000002 Relative err: 0.000002
-> params (after update):
Max err: 0.000000 Relative err: 0.000000
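For context, the max and relative errors above can be computed along these lines; this is a hypothetical helper for illustration only, and the exact metric used by tests/test_fused_kernels.py may differ.

```python
import torch

def max_and_relative_err(test: torch.Tensor, ref: torch.Tensor, eps: float = 1e-12):
    # Hypothetical helper: max absolute error and max relative error vs. the reference.
    abs_err = (test - ref).abs()
    rel_err = abs_err / ref.abs().clamp_min(eps)
    return abs_err.max().item(), rel_err.max().item()
```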
Discussion
Down Projection GEMM Shape
The motivation for the hybrid approach is the unconventional matrix shapes of the down projection (Step 1):
The projection is always done such that the larger dimension of the grad matrix is maintained while the other is projected to low rank, per the GaLore algorithm
E.g., if M >= N, the GEMM is of shape (M x N) x (N x rank) = (M x rank); otherwise it is (rank x M) x (M x N) = (rank x N)
Since {M, N} >> rank by definition, this results in a large reduction dimension relative to one of the output dimensions (the output matrix is either fat or skinny)
This does not fit cleanly into the split-k / parallel reduction GEMM paradigm, which is better suited to shapes where both output dims are smaller than the reduction dimension.
Consequently, I had trouble finding an optimal kernel config using the triton autotuner for the down projection step, despite tuning across many compute- and IO-bound configs (see fused.triton_utils.kernels.matmul.py).
Benchmarking the triton-tuned matmul against the default torch.matmul for these shapes showed worse performance for torch.float32.
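To make the shape concrete, the 4096x4096, rank-128 down projection used in the benchmarks looks like this; a shape illustration only, with a random matrix standing in for the true GaLore orthogonal projection.

```python
import torch

M, N, rank = 4096, 4096, 128
grad = torch.randn(M, N, device="cuda")
proj = torch.randn(rank, N, device="cuda")   # GaLore orthogonal matrix shape [128, 4096]

# M >= N case: the reduction dimension (N=4096) is much larger than one output dim (rank=128),
# so the output is a skinny (M x rank) matrix.
low_rank_grad = grad @ proj.T                # (4096, 4096) x (4096, 128) -> (4096, 128)
print(low_rank_grad.shape)                   # torch.Size([4096, 128])
```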
Effect of TF32 tensor cores
allow_tf32: this has a significant impact on the relative performance of triton vs. torch matmuls.
Quick benchmarks of the down projection matmul show that:
with allow_tf32=True for both, triton exhibits a ~1.30x performance improvement over torch.
with allow_tf32=False, triton's performance degrades significantly, to ~0.50x of torch.
See this torch note for more details on this feature.
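On the torch side, the allow_tf32 setting corresponds to the standard TF32 switches shown below (a sketch; the benchmark CLI may toggle these differently, and on the triton side the equivalent at the kernel level is the allow_tf32 argument to tl.dot).

```python
import torch

# TF32 tensor cores for fp32 matmuls (what `allow_tf32` toggles for the torch kernels).
torch.backends.cuda.matmul.allow_tf32 = True   # set False to force full fp32 matmul paths
torch.backends.cudnn.allow_tf32 = True

# Roughly equivalent newer API for matmuls:
torch.set_float32_matmul_precision("high")     # "highest" disables TF32 for matmuls
```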
Note: This might be less of a concern given this incoming triton PR, which implements a fast TF32 trick that improves both performance and accuracy.
Repro
tests/test_fused_kernels.py is a CLI with two modes: one for testing kernel accuracy, the other for benchmarking across a number of configs.
Examples
Accuracy
Test accuracy of torch vs hybrid for M=4096, N=4096, rank=128, and tf32 switched on:
Benchmark
Benchmark across all kernels without tf32:
Additional options
Note: Passing the additional flag --verbose will show triton autotuning logs -- I customized the triton autotuner to print out the configs and other details.
Test Env
GPU Device Props:
Name: NVIDIA RTX A6000
CC: 86
Total_memory: 48676MB
SM count: 84
Torch: 2.3.0.dev20240310+cu118
Triton: 3.0.0
Next Steps
[ ] Implement FusedGaLoreOptimizer
[ ] Cutlass - given fixed GEMM shape, experiment with Cutlass GEMMs (split-k, stream-k, fast tensorops). Interestingly, profiling torch.matmul for down projection shows that cuBlas dispatches to a Cutlass kernel of shape 128x128x16.
[ ] Repeat with AdamW8bit
[ ] More detailed analysis of torch.compile performance