unslothai / unsloth

Finetune Llama 3.2, Mistral, Phi, Qwen & Gemma LLMs 2-5x faster with 80% less memory
https://unsloth.ai
Apache License 2.0

Faster Inference & Training Roadmap #226

Open jeromeku opened 8 months ago

jeromeku commented 8 months ago

@danielhanchen

In the Unsloth Gemma intro blog post, you mention a VRAM increase due to Gemma's larger MLP size compared to Llama and Mistral, and show a graph demonstrating decreased memory usage when running Unsloth vs. HF and FA2.

Curious which optimizations are leading to the memory decrease -- quantization, autograd efficiency, etc.

danielhanchen commented 8 months ago

@jeromeku I will get to reviewing GPTQ - sorry on the delay!!

jeromeku commented 8 months ago

@danielhanchen

Thanks -- it would be helpful to have a step-by-step breakdown of where the memory savings are coming from, i.e., an ablation study.

Is there interest in faster inference kernels, or is the focus primarily on the training side?

danielhanchen commented 8 months ago

@jeromeku For Mistral itself, see the breakdown at https://unsloth.ai/blog/mistral-benchmark

Gemma's VRAM reduction should be similar to our breakdown for Mistral.

For Gemma inference - I did make it 2x faster, but that's mainly from cobbling together ideas from vLLM and other packages, so I only spent 1 week on it :) The goal is to merge GPT-Fast and other ideas like EAGLE to make inference faster :)

jeromeku commented 8 months ago

@danielhanchen

I'd be interested in contributing on the inference front -- shall we create a priority list of ideas to implement?

danielhanchen commented 8 months ago

@jeromeku That'll be cool!! :) We can collab either via GitHub or async on our Discord - whatever suits you :)

jeromeku commented 8 months ago

@danielhanchen

Looking forward to it!

What's top of mind currently? Perhaps we can draw up a roadmap (if one doesn't already exist).

danielhanchen commented 8 months ago

@jeromeku Oh yes, a roadmap would be nice - we don't actually have one for inference specifically :)

jeromeku commented 8 months ago

@danielhanchen

You mentioned integrating ideas from GPT-Fast and EAGLE - what others did you have in mind?

What's on the roadmap for fine-tuning / training -- architectures, algorithms, etc.? Asking so I know what literature / code to review.

danielhanchen commented 8 months ago

@jeromeku In terms of inference specifically:

  1. GPT Fast
  2. Speculative Decoding (use a small model to generate draft tokens, then run the large model in 1 forward pass and check whether the argmax of its logits matches) - a minimal sketch follows below this list
  3. EAGLE (Speculative Decoding but only Word2Vec style ie lm_head -> embeddings)
  4. All quant methods - HQQ, AWQ, Exllama etc
  5. vLLM's Paged Attention
  6. Full single-kernel Triton fusion - ie can we write the whole forward pass as 1 humongous Triton kernel? Very hard, since there are synchronizations which have to be done
  7. Using float8 like Fire Attention. cuDNN has float8 flash attention I think as well.
  8. Rewriting matrix-vector multiplication in Triton exactly (like what you were trying to do with GPTQ, but matvec instead of matmul)
  9. Torch export

I might have more, but those are from the top of my head.
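
On point 2, here is a minimal sketch of greedy speculative decoding with draft/target verification. It assumes HF-style models that return `.logits` and is illustrative only - not Unsloth's implementation:

```python
import torch

@torch.no_grad()
def speculative_step(draft_model, target_model, input_ids, k=4):
    # 1) The small draft model autoregressively proposes k tokens (greedy).
    draft_ids = input_ids
    for _ in range(k):
        logits = draft_model(draft_ids).logits[:, -1, :]
        next_tok = logits.argmax(dim=-1, keepdim=True)
        draft_ids = torch.cat([draft_ids, next_tok], dim=-1)

    # 2) The large target model scores the prompt + all drafted tokens in ONE forward pass.
    target_logits = target_model(draft_ids).logits
    prompt_len = input_ids.shape[1]
    preds = target_logits[:, prompt_len - 1 : -1, :].argmax(dim=-1)  # target's predictions
    drafted = draft_ids[:, prompt_len:]                              # draft's proposals

    # 3) Accept the longest prefix where the argmaxes agree.
    match = (preds == drafted).long().cumprod(dim=-1)
    n_accept = int(match.sum(dim=-1).min())

    # 4) Always append the target's own next token after the accepted prefix,
    #    so at least one token is produced per step.
    bonus = target_logits[:, prompt_len - 1 + n_accept, :].argmax(dim=-1, keepdim=True)
    return torch.cat([input_ids, drafted[:, :n_accept], bonus], dim=-1)
```

When the draft model agrees with the target, several tokens are accepted per target forward pass, which is where the speedup comes from.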

For training / finetuning:

  1. Fast MoE matmul kernel https://github.com/vllm-project/vllm/pull/2453 but for training - much more complex than inference with batch size 1. Mixtral selects the top 2 experts, which can easily be done in Triton. However, when you have bsz>1, we have issues. One has to do dynamic compressed packing and then call torch.bmm. The backward pass is even more problematic, since it requires reversing the packing, calling torch.bmm, then unpacking again. A nightmare.
  2. Galore - extremely fascinating: it projects gradients onto a small low-rank matrix, then uses SVD to update the projectors. It's not just Galore that I was fascinated by, but rather Lomo, which does gradient updates dynamically and can save 20GB of VRAM during pretraining. A rough sketch of the projection idea follows this list.
  3. 1.58bit - I recently wrote on HN about how 1.58bit lets one avoid multiplications entirely, since multiplying by (-1, 0, 1) becomes a simple sign flip followed by additions. Compared to 8bit floats, 1.58bit uses 2x less space than float8, which makes it possible to cram in 2x as much. Writing it in Triton can be more complex. A toy illustration of the no-multiplication trick is at the end of this comment.
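
For point 2, a rough sketch of the low-rank gradient projection idea - a simplified assumption of how Galore works, not its actual API; the class and parameter names here are made up:

```python
import torch

class LowRankGradProjector:
    """Hypothetical, simplified Galore-style projector (not the real library)."""

    def __init__(self, rank=128, update_every=200):
        self.rank, self.update_every = rank, update_every
        self.step, self.P = 0, None

    def project(self, grad):                       # grad: (m, n) full gradient
        if self.P is None or self.step % self.update_every == 0:
            # Refresh the projector from the top-r left singular vectors of the gradient.
            U, _, _ = torch.linalg.svd(grad.float(), full_matrices=False)
            self.P = U[:, : self.rank].to(grad.dtype)   # (m, r)
        self.step += 1
        return self.P.t() @ grad                   # (r, n) - optimizer state lives here

    def project_back(self, low_rank_update):
        return self.P @ low_rank_update            # back to (m, n) to apply to the weight
```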

Just a brain dump!
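
And the promised toy illustration for point 3 - a hypothetical ternary matvec in plain PyTorch where every "multiply" becomes a masked add or subtract:

```python
import torch

def ternary_matvec(W_ternary: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # W_ternary: (M, N) with entries in {-1, 0, 1}; x: (N,)
    zero = torch.zeros_like(x)
    # Take x where the weight is +1, subtract it where the weight is -1, skip zeros:
    # no multiplications, only sign handling and additions.
    plus  = torch.where(W_ternary ==  1, x, zero).sum(dim=-1)
    minus = torch.where(W_ternary == -1, x, zero).sum(dim=-1)
    return plus - minus

W = torch.randint(-1, 2, (8, 16)).float()
x = torch.randn(16)
assert torch.allclose(ternary_matvec(W, x), W @ x)
```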

jeromeku commented 8 months ago

@danielhanchen

Love it.

Inference:

Training:

danielhanchen commented 8 months ago

jeromeku commented 8 months ago

Let me know what I should prioritize.

Also, can you expand more on the Triton GEMV? What kind of horizontal / vertical fusions should we target?

danielhanchen commented 8 months ago

Oh, so GEMV itself is generally OK I guess - the issue is merging the dequant step in (ie what you were doing with GPTQ, except it's not a matrix-matrix mult but a matrix-vector mult). This allows different optimizations - ie is blocked mm better, or is column-wise or row-wise mv better? It depends on the cache footprint.

But the goal is: can we somehow merge X @ Wq, X @ Wk, X @ Wv together with RoPE and attention and everything into 1 large kernel?
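
For concreteness, a bare-bones Triton GEMV sketch (block sizes and names are illustrative, not Unsloth's kernels); a fused-dequant version would load quantized weight blocks and dequantize them inside the column loop before the multiply-accumulate:

```python
import torch
import triton
import triton.language as tl

@triton.jit
def gemv_kernel(W, X, Y, M, N, stride_wm, stride_wn,
                BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    pid = tl.program_id(0)
    rm = pid * BLOCK_M + tl.arange(0, BLOCK_M)        # output rows handled by this program
    acc = tl.zeros((BLOCK_M,), dtype=tl.float32)
    for n0 in range(0, N, BLOCK_N):                   # march across the columns
        rn = n0 + tl.arange(0, BLOCK_N)
        w = tl.load(W + rm[:, None] * stride_wm + rn[None, :] * stride_wn,
                    mask=(rm[:, None] < M) & (rn[None, :] < N), other=0.0)
        x = tl.load(X + rn, mask=rn < N, other=0.0)
        # A fused-dequant kernel would dequantize `w` right here, before the accumulate.
        acc += tl.sum(w.to(tl.float32) * x.to(tl.float32)[None, :], axis=1)
    tl.store(Y + rm, acc, mask=rm < M)

def gemv(W, x, BLOCK_M=64, BLOCK_N=128):
    M, N = W.shape
    y = torch.empty(M, device=W.device, dtype=torch.float32)
    gemv_kernel[(triton.cdiv(M, BLOCK_M),)](W, x, y, M, N, W.stride(0), W.stride(1),
                                            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N)
    return y
```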

jeromeku commented 8 months ago

If I understand correctly:

Can you point me to the current GEMV implementation? I need a minimal implementation / testbed for benchmarking purposes.

danielhanchen commented 8 months ago

Oh, for inference your method of fusing the dequant step inside the kernel is actually ideal! For training it's not, since cuBLAS is relatively smart about data movement.

An ideal kernel for GEMV, ie a vector * matrix kernel, is normally done as follows: [diagram omitted]

However, a more optimal procedure is to split the reduction into, say, 4 blocks using atomic_add. In fact it can be, say, reduction columns of 4 but blocks of 24, cycling through them using the modulus function. [diagram omitted]

A final reduction will need to be made at the end.
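
A sketch of that split reduction (again illustrative, not the real kernel): the columns get their own grid axis and each program atomically adds its partial dot product into a zero-initialized output, so the final reduction happens in global memory.

```python
import torch
import triton
import triton.language as tl

@triton.jit
def gemv_split_kernel(W, X, Y, M, N, stride_wm, stride_wn,
                      BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    pid_m = tl.program_id(0)          # which block of output rows
    pid_n = tl.program_id(1)          # which slice of the reduction (columns)
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    w = tl.load(W + rm[:, None] * stride_wm + rn[None, :] * stride_wn,
                mask=(rm[:, None] < M) & (rn[None, :] < N), other=0.0)
    x = tl.load(X + rn, mask=rn < N, other=0.0)
    partial = tl.sum(w.to(tl.float32) * x.to(tl.float32)[None, :], axis=1)
    # Several column blocks write to the same rows, hence the atomic add.
    tl.atomic_add(Y + rm, partial, mask=rm < M)

def gemv_split(W, x, BLOCK_M=64, BLOCK_N=256):
    M, N = W.shape
    y = torch.zeros(M, device=W.device, dtype=torch.float32)   # must start at zero
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
    gemv_split_kernel[grid](W, x, y, M, N, W.stride(0), W.stride(1),
                            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N)
    return y
```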

The current GEMV implementation will probably be the one in GPT-Fast, although I haven't inspected it myself yet.

The hardest part is folding in the Bitsandbytes int4 format, since the blocksize is lopsided, ie not a whole integer multiple, which is a nightmare for cache optimality.

danielhanchen commented 8 months ago

Another approach people use is row-wise [diagram omitted], which again can be done in parallel with a split reduction as I described above.

jeromeku commented 8 months ago

@danielhanchen Ok - so I'm clear on objectives:

nivibilla commented 8 months ago

For training / finetuning:

@danielhanchen Obligatory request for Multi GPU XD

danielhanchen commented 7 months ago

@jeromeku Extremely sorry for the delay - yep, sounds right! :) @nivibilla Yep!

jeromeku commented 7 months ago

@danielhanchen

Is the issue with the existing bitsandbytes gemv the fact that it's CUDA only?

danielhanchen commented 7 months ago

@jeromeku Yes, that can be one of the main issues - the other is folding it inside other kernels, ie a single fused kernel can become too complex to write.

The main issue I still see with 1 kernel (maybe I'm overthinking it) is that every new op requires synchronization, so maybe we should instead rely on torch.compile with CUDAGraphs to reduce the CPU overhead in between.
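
For reference, the low-effort version of that path: torch.compile's reduce-overhead mode captures CUDA graphs around the compiled region, cutting the per-kernel CPU launch overhead. The module below is just a placeholder, not Unsloth code:

```python
import torch
import torch.nn as nn

# Placeholder module standing in for a decoder block.
model = nn.Sequential(nn.Linear(4096, 4096), nn.GELU(), nn.Linear(4096, 4096)).cuda().half()

# mode="reduce-overhead" tells torch.compile to use CUDA graphs between ops.
compiled = torch.compile(model, mode="reduce-overhead")

x = torch.randn(1, 4096, device="cuda", dtype=torch.float16)
with torch.inference_mode():
    y = compiled(x)   # early calls warm up / capture; later calls replay the graph
```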

jeromeku commented 7 months ago

I'd imagine there is an optimization spectrum:

I'll make a quick pass at implementing the bnb dequant GEMV in Triton to see how performance compares.

CUTLASS also enables some flexibility with bespoke GEMMs and fusions, but is again CUDA only. Let me know if this is of interest.

danielhanchen commented 7 months ago

@jeromeku Oh yes, let's try to be device agnostic :)) compile is OK, but I guess handwriting kernels is best :) We can then use CUDAGraphs manually.

jeromeku commented 7 months ago

@danielhanchen

A few updates:

danielhanchen commented 7 months ago

@jeromeku Fantastic work as always!! Very, very cool on fusing Adam and Galore!! Love this!

Oh on Mixtral - https://github.com/shawntan/scattermoe/tree/main/scattermoe :) Was reading up on this as well :)

On BnB dequant - I'll have a first look at it :) But you're more than welcome to take it if you want :)

jeromeku commented 7 months ago

@danielhanchen

pHaeusler commented 7 months ago

Really excited about optimized kernels for inference!

Worth looking at https://github.com/zeux/calm, where the forward pass is implemented as a single CUDA kernel. It uses fp8 rather than int4/int8 quantization.