jeromeku opened 8 months ago
@jeromeku I will get to reviewing GPTQ - sorry on the delay!!
@danielhanchen
Thanks -- would be helpful to have a step-by-step breakdown of where the memory savings are coming from, i.e., an ablation study.
Is there interest in faster inference kernels, or is the focus primarily on the training side?
@jeromeku For Mistral itself: https://unsloth.ai/blog/mistral-benchmark
Gemma's VRAM reduction should be similar to our breakdown for Mistral.
For inference for Gemma - I did make it 2x faster, but it's mainly cobbling together ideas from vLLM and other packages, so I only spent 1 week on it :) The goal is to merge GPT Fast and other ideas like EAGLE to make inference faster :)
@danielhanchen
I'd be interested in contributing on the inference front -- shall we create a priority list of ideas for implementation?
@jeromeku That'll be cool!! :) We can collab either via Github or async on our Discord - whatever suits you :)
@danielhanchen
Looking forward to it!
What's top of mind currently? Perhaps we can draw up a roadmap (if one doesn't already exist).
@jeromeku Oh ye a roadmap would be nice - don't actually have one for inference specifically :)
@danielhanchen
You mentioned integrating ideas from fastGPT and EAGLE -- what other ideas did you have in mind?
What's on the roadmap for fine-tuning / training -- architectures, algorithms, etc.? Asking so I know what literature / code to review.
@jeromeku In terms of inference specifically:
I might have more, but those are off the top of my head.
For training / finetuning:
Just a brain dump!
@danielhanchen
Love it.
Inference:
Training:
Let me know what I should prioritize.
Also, can you expand more on Triton GEMV? What kind of horizontal / vertical fusions to target?
Oh so GEMV is generally OK I guess - the issue is the dequant step merged in (i.e. what you were doing with GPTQ, except it's not a matrix-matrix mult but a matrix-vector mult). This allows different optimizations - i.e. is blocked mm better, or is column-wise or row-wise mv better? It depends on the cache footprint.
But the goal is: can we somehow merge X @ Wq, X @ Wk, X @ Wv together with RoPE and attention and everything into 1 large kernel?
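As a minimal sketch of the first step of that fusion -- folding the three projections into a single matmul in plain PyTorch (the Wq/Wk/Wv tensors and shapes below are placeholder assumptions; fusing RoPE and attention into the same kernel would still need a handwritten Triton/CUDA kernel):

```python
import torch

# Sketch only: merge X @ Wq, X @ Wk, X @ Wv into one GEMM by concatenating the
# weights column-wise. Shapes are illustrative (e.g. grouped-query attention
# with smaller K/V projections); RoPE + attention are NOT fused here.
def fused_qkv(X, Wq, Wk, Wv):
    W_qkv = torch.cat([Wq, Wk, Wv], dim=1)   # (hidden, q_dim + k_dim + v_dim)
    qkv = X @ W_qkv                          # one matmul instead of three
    return qkv.split([Wq.shape[1], Wk.shape[1], Wv.shape[1]], dim=-1)

X  = torch.randn(4, 4096)
Wq = torch.randn(4096, 4096)
Wk = torch.randn(4096, 1024)
Wv = torch.randn(4096, 1024)
q, k, v = fused_qkv(X, Wq, Wk, Wv)
print(q.shape, k.shape, v.shape)  # (4, 4096) (4, 1024) (4, 1024)
```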
If I understand correctly:
Can you point me to the current GEMV implementation? Need a minimal implementation / testbed for benchmarking purposes.
Oh for inference, your method of fusing the dequant step inside the kernel is actually ideal! For training it's not, since CUBLAS is relatively smart about data movements.
An ideal kernel for GEMV, i.e. a vector * matrix kernel, is normally done via:
However, a more optimal procedure is to split the reductions into 4 blocks by using atomic_add. In fact it can be, say, reduction columns of 4 but blocks of 24, cycling through them using the modulus function. A final reduction will need to be made at the end.
The current GEMV implementation will probably be the one in Fast-GPT, although I haven't inspected it myself yet.
The hardest part is folding in the Bitsandbytes int4 format, since the blocksize is lopsided, i.e. not a whole integer multiple, which is a nightmare for cache optimality.
Another approach people take is row-wise, which again can be done in parallel with a reduction as I described above.
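A rough Triton sketch of that split-reduction GEMV -- the 2D grid layout, block sizes, and fp32 atomic accumulation below are assumptions for illustration, not the actual Unsloth kernel:

```python
import torch
import triton
import triton.language as tl

# Each program handles a (BLOCK_M rows) x (BLOCK_N columns) tile of W, computes
# partial dot products against x, and atomically accumulates into y, so the
# "final reduction" happens via atomic_add instead of a second kernel.
@triton.jit
def gemv_split_kernel(W_ptr, x_ptr, y_ptr, M, N, stride_wm, stride_wn,
                      BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    pid_m = tl.program_id(0)                      # block of output rows
    pid_n = tl.program_id(1)                      # block of the reduction dim
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    mask_m = rm < M
    mask_n = rn < N
    w = tl.load(W_ptr + rm[:, None] * stride_wm + rn[None, :] * stride_wn,
                mask=mask_m[:, None] & mask_n[None, :], other=0.0)
    x = tl.load(x_ptr + rn, mask=mask_n, other=0.0)
    partial = tl.sum(w.to(tl.float32) * x.to(tl.float32)[None, :], axis=1)
    tl.atomic_add(y_ptr + rm, partial, mask=mask_m)

def gemv(W, x, BLOCK_M=64, BLOCK_N=256):
    M, N = W.shape
    y = torch.zeros(M, device=W.device, dtype=torch.float32)  # atomics need a zeroed output
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
    gemv_split_kernel[grid](W, x, y, M, N, W.stride(0), W.stride(1),
                            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N)
    return y
```

The blocked vs column-wise vs row-wise trade-off discussed above is just a choice of grid/block shape in a sketch like this.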
@danielhanchen Ok - so I'm clear on objectives:
For training / finetuning:
@danielhanchen Obligatory request for Multi GPU XD
@jeromeku Extremely sorry on the delay - yep sounds right! :) @nivibilla Yep!
@danielhanchen
Is the issue with the existing bitsandbytes gemv the fact that it's CUDA only?
@jeromeku Yes that can be one of the main issues - the other is folding it inside other kernels, i.e. a single kernel can become too complex to write.
The main issue I still see with 1 kernel (so maybe I'm overthinking) is that every new op requires synchronization, so maybe we should rather rely on torch.compile with CUDAGraphs to reduce the CPU overhead in between.
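A quick sketch of that route, with a toy module standing in for the real decode path -- torch.compile's "reduce-overhead" mode wraps the compiled region in CUDA graphs, which is what removes the per-op CPU launch cost:

```python
import torch

# Placeholder module standing in for a decoder forward; the point is only the
# compile mode, which captures the region into CUDA graphs after warmup.
model = torch.nn.Sequential(
    torch.nn.Linear(4096, 11008), torch.nn.SiLU(), torch.nn.Linear(11008, 4096)
).cuda().half()

compiled = torch.compile(model, mode="reduce-overhead", fullgraph=True)

x = torch.randn(1, 4096, device="cuda", dtype=torch.float16)
with torch.no_grad():
    for _ in range(3):      # warmup iterations so the CUDA graph gets captured
        _ = compiled(x)
    out = compiled(x)       # later calls replay the captured graph
```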
I'd imagine there is an optimization spectrum:
torch.compile
entire graph with appropriate inductor settings to maximize fusion / reduce overhead torch
cudagraph
APIs to glue things togetherWill make a quick pass at implementing bnb
dequant gemv in triton to see how performance compares.
Cutlass also enables some flexibility with bespoke gemm
and fusions but is again cuda
only. Let me know if this is of interest.
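For reference, here is a simplified shape of what such a fused dequant GEMV could look like in Triton -- note this uses a plain int8-with-per-block-absmax layout as a stand-in, not bitsandbytes' actual 4-bit/NF4 format with nested scales:

```python
import torch
import triton
import triton.language as tl

# One program per output row: load quantized weights, dequantize with the
# per-block scale, and multiply-accumulate against x in one pass. Assumes a
# row-major int8 weight and N divisible by QUANT_BLOCK, purely for illustration.
@triton.jit
def dequant_gemv_kernel(Wq_ptr, scale_ptr, x_ptr, y_ptr, M, N,
                        BLOCK_N: tl.constexpr, QUANT_BLOCK: tl.constexpr):
    pid_m = tl.program_id(0)
    rn = tl.arange(0, BLOCK_N)
    acc = tl.zeros((BLOCK_N,), dtype=tl.float32)
    for n0 in range(0, N, BLOCK_N):
        cols = n0 + rn
        mask = cols < N
        wq = tl.load(Wq_ptr + pid_m * N + cols, mask=mask, other=0).to(tl.float32)
        s = tl.load(scale_ptr + pid_m * (N // QUANT_BLOCK) + cols // QUANT_BLOCK,
                    mask=mask, other=0.0)
        xv = tl.load(x_ptr + cols, mask=mask, other=0.0).to(tl.float32)
        acc += wq * s * xv                 # dequant + GEMV fused in one pass
    tl.store(y_ptr + pid_m, tl.sum(acc, axis=0))

def dequant_gemv(Wq, scales, x, QUANT_BLOCK=64):
    M, N = Wq.shape
    y = torch.empty(M, device=Wq.device, dtype=torch.float32)
    dequant_gemv_kernel[(M,)](Wq, scales, x, y, M, N,
                              BLOCK_N=256, QUANT_BLOCK=QUANT_BLOCK)
    return y
```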
@jeromeku Oh ye let's try to be device agnostic :)) compile is OK, but I guess handwriting kernels is best :) We can then use CUDAGraphs manually.
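And a sketch of the manual CUDAGraphs route, following the pattern from the PyTorch CUDA graphs docs (`step` here is a placeholder for whatever fused decode step we end up with):

```python
import torch

def step(x, W):
    # Placeholder for the real decode step / handwritten kernels.
    return torch.nn.functional.silu(x @ W)

W = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
static_x = torch.zeros(1, 4096, device="cuda", dtype=torch.float16)

# Warm up on a side stream so capture doesn't record one-time init work.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        static_out = step(static_x, W)
torch.cuda.current_stream().wait_stream(s)

# Capture once into static buffers, then replay with near-zero CPU launch cost.
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out = step(static_x, W)

static_x.copy_(torch.randn(1, 4096, device="cuda", dtype=torch.float16))
graph.replay()                     # reruns the whole captured graph in place
print(static_out.float().sum())
```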
@danielhanchen
A few updates:
- GaLore -- ran some initial experiments to fuse the GaLore Adam update step -- see PR (rough sketch of the projected update after this list)
- Triton 4-bit bnb dequant kernel of interest?
- Mixtral.
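A rough plain-PyTorch sketch of the projected Adam step being fused (the projector P, rank r, and hyperparameters are illustrative; the PR does this in Triton, and GaLore also refreshes P periodically):

```python
import torch

# GaLore-style step: project the gradient into a rank-r subspace, run the Adam
# moment update there, then project the update back and apply it to the weight.
def galore_adam_step(param, grad, P, exp_avg, exp_avg_sq, step,
                     lr=1e-4, beta1=0.9, beta2=0.999, eps=1e-8):
    g_low = P.T @ grad                                    # (r, n) low-rank gradient
    exp_avg.mul_(beta1).add_(g_low, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(g_low, g_low, value=1 - beta2)
    m_hat = exp_avg / (1 - beta1 ** step)
    v_hat = exp_avg_sq / (1 - beta2 ** step)
    param.add_(P @ (m_hat / (v_hat.sqrt() + eps)), alpha=-lr)  # back to full rank

m, n, r = 4096, 4096, 128
param = torch.randn(m, n)
grad = torch.randn(m, n)
P = torch.linalg.qr(torch.randn(m, r)).Q   # fixed orthonormal projector for the sketch
exp_avg = torch.zeros(r, n)
exp_avg_sq = torch.zeros(r, n)
galore_adam_step(param, grad, P, exp_avg, exp_avg_sq, step=1)
```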
@jeromeku Fantastic work as always!! Very very cool on fusing Adam and GaLore!! Love this!
Oh on Mixtral - https://github.com/shawntan/scattermoe/tree/main/scattermoe :) Was reading up on this as well :)
On BnB dequant - I'll have a look at it first :) But you're more than welcome to do it if you want :)
@danielhanchen
- GaLore into unsloth? Planning on working on an Adam8bit version.
- bnb dequant
- Really excited about optimized kernels for inference! Worth looking at https://github.com/zeux/calm - where the forward pass is implemented as a single CUDA kernel. Uses fp8 rather than int4/8 quantization.
@danielhanchen
In the unsloth Gemma intro blogpost, you mention a VRAM increase due to the larger MLP size in Gemma compared to Llama and Mistral, and show a graph demonstrating decreased memory usage when running unsloth vs. HF and FA2.
For the HF vs FA2 vs unsloth graph -- is it inference or training? Curious what optimizations are leading to the memory decrease -- quantization, autograd efficiency, etc.