# Fused DoRA Kernels

Fused DoRA layer implementation that reduces the number of individual kernels from ~10 to 5.
## Contents

- [Background](#background)
- [Optimization](#optimization)
- [Key Contributions](#key-contributions)
- [Usage](#usage)
- [Tests](#tests)
- [Benchmarks](#benchmarks)
- [Profiling](#profiling)
- [Next Steps](#next-steps)

## Background
DoRA (weight-decomposed low-rank adaptation) is a variant of LoRA that decomposes the LoRA update into magnitude and vector components.
The DoRA layer is roughly as follows:
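```python
# Sketch assembled from the formulas later in this README; the reference
# implementation lives in dora.dora_layer.DoRALinear.
lora_out = lora_B(lora_A(x))
magnitude_scale = magnitude_vector / (base_weight + lora_B.weight @ lora_A.weight).norm(p=2, dim=1)
dora_out = (x @ base_weight.T + lora_out) * magnitude_scale
```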
where:

- `lora_A` and `lora_B` are `linear` layers with weight shapes `rank x in_features` and `out_features x rank`, respectively
- `base_weight` is the weight of the frozen `linear` layer, of shape `out_features x in_features`
- `magnitude_vector` is initialized as the columnwise `2-norm` of the frozen weight (shape `out_features`)
- `x` are the inputs, of shape `batch_size x seqlen x in_features`
## Optimization

After initial profiling, and as outlined above, the DoRA update layer requires multiple kernels. In order of compute intensity:

- 4 GEMMs:
  - `x @ base_weight`
  - `lora_B(lora_A(x))` (2 GEMMs)
  - `lora_B.weight @ lora_A.weight`
- 1 Reduction: `2-norm`
- 4 Elementwise: matrix-matrix additions (2) and broadcasted matrix-vector multiplications (2)
While `torch.compile` (and CUDA graphs) can partially mitigate the overhead of multiple small kernels and improve the compute efficiency of individual kernels, there remains room for additional optimization: reordering the computations to facilitate fusions and, more importantly, exploiting the unique shapes of the GEMMs, thereby decreasing the number of kernel launches and increasing the compute intensity of each kernel.

## Key Contributions
### 1 - Small K Fused Kernel
Note that `lora_B.weight @ lora_A.weight` is a GEMM with a specific shape, where `K << {M, N}`. That is, `lora_B.weight` is `out_features x lora_rank` and `lora_A.weight` is `lora_rank x in_features`.

Since `lora_rank` is typically `< 64` while `in_features` and `out_features` are typically `> 4096` (e.g., Llama MLP / QKV projections), this GEMM is inefficient: each CTA loads a block of the operands only to perform a few MAC iterations, given the small `K`.

Moreover, the result of this GEMM is not itself needed -- we only need the `2-norm` of this computation.

Combining these two observations, we can write a fused kernel where:
- Each CTA computes an entire row of the output matrix, with the key assumption that `BLOCK_K = K`. That is, each CTA does a single MAC iteration to compute a `BLOCK_M x BLOCK_N` output tile, then iterates across dimension `N`.
- Since each block processes an entire row, we can additionally fuse a grid-wise reduction along `axis=1` into the kernel. In this case, we can directly fold the `2-norm` computation into the GEMM.
- As an added bonus, we can also include the `base_weight` elementwise addition and the `magnitude_vector` multiplication in the GEMM epilogue.

Altogether, this allows us to fuse the following computation into a single kernel:
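```python
# The entire right-hand side is produced by the fused kernel
# (cf. the `dora-colnorm` option in the benchmarks below):
magnitude_scale = magnitude_vector / (base_weight + lora_B.weight @ lora_A.weight).norm(p=2, dim=1)
```

For illustration only, a simplified Triton sketch of this design (one program per output row rather than a `BLOCK_M x BLOCK_N` tile; the kernel name, signature, and launch details are assumptions and not the actual kernel in this repo):

```python
import triton
import triton.language as tl

@triton.jit
def small_k_colnorm_scale_kernel(
    b_ptr, a_ptr, w_ptr, mag_ptr, scale_ptr,  # lora_B.weight (M, K), lora_A.weight (K, N), base_weight (M, N), magnitude_vector (M,), output scale (M,)
    M, N, K,
    stride_bm, stride_bk, stride_ak, stride_an, stride_wm, stride_wn,
    BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,  # BLOCK_K = K, so a single MAC step covers all of K
):
    row = tl.program_id(0)  # one CTA per output row
    k = tl.arange(0, BLOCK_K)
    lb = tl.load(b_ptr + row * stride_bm + k * stride_bk, mask=k < K, other=0.0)
    acc = tl.zeros((BLOCK_N,), dtype=tl.float32)
    for n0 in range(0, N, BLOCK_N):  # iterate across dimension N
        n = n0 + tl.arange(0, BLOCK_N)
        la = tl.load(a_ptr + k[:, None] * stride_ak + n[None, :] * stride_an,
                     mask=(k[:, None] < K) & (n[None, :] < N), other=0.0)
        w = tl.load(w_ptr + row * stride_wm + n * stride_wn, mask=n < N, other=0.0)
        # epilogue: fold the base_weight addition into the GEMM
        out = (tl.sum(lb[:, None] * la, axis=0) + w).to(tl.float32)
        acc += out * out  # fold the 2-norm reduction into the same kernel
    col_norm = tl.sqrt(tl.sum(acc, axis=0))
    # epilogue: magnitude_vector multiplication
    tl.store(scale_ptr + row, tl.load(mag_ptr + row) / col_norm)
```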
### 2 - Fused Epilogue GEMM
Additionally, instead of computing the base layer output before the DoRA / LoRA updates, we can compute the latter (the LoRA layer and `magnitude_scale`) first, and fold these into the epilogue of the base layer GEMM:
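```python
# DoRA / LoRA updates
lora_out = lora_B(lora_A(x))
magnitude_scale = magnitude_vector / (base_weight + lora_B.weight @ lora_A.weight).norm(p=2, dim=1)

# This is now a single kernel
final_out = (x @ base_weight.T + lora_out) * magnitude_scale
```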
## Usage

The fused kernels can be used to implement DoRA / QDoRA layers.

A reference implementation is provided in `dora.dora_layer.DoRALinear`, which defines a base QDoRA linear layer (with a stub `dequantize` method), along with corresponding `BNBDoRALinear` and `HQQDoRALinear` subclasses, which override `dequantize` with their respective methods.

### Example
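A minimal sketch of constructing and calling one of these layers (the constructor arguments shown here are assumptions, not the exact API):

```python
# Hypothetical usage sketch -- argument names are assumptions; see
# test/test_dora_layer.py for the actual API.
import torch
from bitsandbytes.nn import Linear4bit
from dora.dora_layer import BNBDoRALinear

base_layer = Linear4bit(4096, 4096, bias=False).cuda()  # frozen, 4-bit quantized base linear
dora_linear = BNBDoRALinear(base_layer, lora_rank=16)   # assumed constructor signature
x = torch.randn(1, 512, 4096, device="cuda", dtype=torch.float16)
out = dora_linear(x)
```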
See `test/test_dora_layer.py` and `benchmarks/dora_bench.py` for more detailed usage.

Also note that these are reference implementations and are not fully optimized; see Next Steps for follow-up plans.
## Tests

See `test/dora/test*` for correctness checks of the fused kernels and layers.

## Benchmarks
See `benchmarks/dora_bench.py`.

Run with the flag `--kernel` set to one of `{dora-colnorm, dora-mm-epilogue}` to benchmark the respective fused kernels against a reference `torch` / `torch.compile` implementation, or with `--kernel=dora-full` to benchmark the entire DoRA computation.

Additionally, passing `--kernel=dora-bnb` or `--kernel=dora-hqq` will benchmark a reference QDoRA layer against its fused implementation.
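For example (assuming the script is invoked directly with Python; check `benchmarks/dora_bench.py` for any additional options):

```sh
python benchmarks/dora_bench.py --kernel=dora-mm-epilogue
```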
## Profiling

The reference `DoRALinear` layer described above also has an instrumented forward pass with annotated regions for each of the DoRA ops.

An example script for running a profiled forward pass is provided in `dora/dora_profile.py`.

To run with `torch.profiler`:
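```sh
# Assumed invocation -- check dora/dora_profile.py for the exact flags.
python dora/dora_profile.py
```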
which outputs a chrome trace to the default folder `dora_profiles`.

To run with `nsys`:
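```sh
# Assumed invocation -- `...` stands for other desired nsys options; check
# dora/dora_profile.py for the exact script flags.
nsys profile --capture-range=cudaProfilerApi ... python dora/dora_profile.py
```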
where `...` are other desired `nsys` options. Note that `--capture-range=cudaProfilerApi` is required.

## Next Steps
- Further optimization: `torch.compile`, re-ordering computations, etc.
- `torch.autograd.Function`
- `FSDP-LoRA`
- `triton` autotuner
- `galore`, `hqq`, and `dora` can now be refactored into a single module. Separate PR?