pytorch / ao


Quantized Training #554

Open msaroufim opened 2 months ago

msaroufim commented 2 months ago

Inspired by a recent back-and-forth with @gau-nernst, we should add some quantized training recipes to AO for small models (~600M parameter range).

Character.ai recently shared that they're working on quantized training (https://research.character.ai/optimizing-inference/), where, per @stephenroller, they train models from scratch in int8: https://x.com/stephenroller/status/1816636257717436779

Historically we've invested more in QAT, which @andrewor14 has led; QAT is more of a technique to reduce the perplexity degradation from an eventual post-training quantization.

Quantized training, on the other hand, actually quantizes the model at training time, so memory savings are observed both for training and for inference.

So when discussing quantized training there are a few aspects (a minimal sketch of the basic int8 quantize/dequantize building block follows this list):

  1. Weights: can be in fp16, fp8, int8, int4 and below
  2. Activations: most likely limited to fp8 or fp16
  3. Optimizer state: can be in fp32, fp16, bf16, fp8, int8 and below
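For reference, a minimal sketch of the int8 weight option, assuming symmetric per-tensor quantization (the helper names are illustrative, not an AO API):

```python
import torch

def quantize_int8_symmetric(w: torch.Tensor):
    # Per-tensor symmetric quantization: map [-absmax, absmax] onto [-127, 127].
    scale = w.abs().amax().clamp(min=1e-12) / 127
    w_int8 = torch.clamp(torch.round(w / scale), -127, 127).to(torch.int8)
    return w_int8, scale

def dequantize_int8(w_int8: torch.Tensor, scale: torch.Tensor, dtype=torch.bfloat16):
    # Dequantize back to a compute dtype (e.g. BF16) for the matmul.
    return (w_int8.float() * scale).to(dtype)

# Storing the weight as int8 plus one fp32 scale is what gives the memory saving.
w = torch.randn(4096, 4096)
w_int8, scale = quantize_int8_symmetric(w)
w_bf16 = dequantize_int8(w_int8, scale)
```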

And if one were to ship this work, a bad combination can be identified at small scale (~600M parameter range), but a good idea needs to be continuously validated at larger scales (8B to 405B range), so each of these combinations will need loss curves.

When choosing a starting point, we could either pretrain a model with quantized training or just finetune one; as long as the loss curves match the fp16 baselines, we are good. We'd also of course need to validate that the memory savings are there and measure the speedups/slowdowns (see the sketch below).
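As a sketch of that validation, peak memory and step time can be recorded around the training step (the `train_step` callable here is a placeholder, not an AO or torchtitan API):

```python
import torch

def measure_step(model, train_step, n_warmup=5, n_iters=20):
    """Peak CUDA memory and average step time for one training configuration."""
    for _ in range(n_warmup):
        train_step(model)
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(n_iters):
        train_step(model)
    end.record()
    torch.cuda.synchronize()
    peak_gib = torch.cuda.max_memory_allocated() / 2**30
    ms_per_step = start.elapsed_time(end) / n_iters
    return peak_gib, ms_per_step
```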

And while we can merge a lot of the dtype conversion code into AO and keep some toy training loop in AO, what I'm more optimistic about is having an end-to-end training recipe in https://github.com/pytorch/torchtitan @awgu and an end-to-end finetuning recipe in https://github.com/pytorch/torchtune @ebsmothers @joecummings

gau-nernst commented 2 months ago

Just want to add: there are also the activation/computation dtype and the gradient dtype. In my exploration, I keep activation/computation and gradients in BF16 to match weight-only quant inference. Activation/computation can also be in a lower-precision dtype, such as INT8 act - INT8 weight to match dynamic quant inference, or FP8 act - FP8 weight to match the current FP8 training recipe.

Lower-precision gradients might not be possible? Will need to check existing work on this.
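To make the "INT8 weight, BF16 activation/compute/gradient" setup concrete, here is a rough sketch as a custom autograd function. This is not the AO implementation, and how the INT8 weight itself gets updated (master copy in higher precision, stochastic rounding, etc.) is deliberately left out:

```python
import torch

class Int8WeightLinear(torch.autograd.Function):
    """Weight stored in INT8; activations, compute and gradients stay in BF16."""

    @staticmethod
    def forward(ctx, x, w_int8, w_scale):
        # Dequantize the INT8 weight to BF16 just for the matmul.
        w_bf16 = (w_int8.float() * w_scale).to(torch.bfloat16)
        ctx.save_for_backward(x, w_int8, w_scale)
        return x @ w_bf16.t()

    @staticmethod
    def backward(ctx, grad_out):
        x, w_int8, w_scale = ctx.saved_tensors
        w_bf16 = (w_int8.float() * w_scale).to(torch.bfloat16)
        grad_x = grad_out @ w_bf16    # input grad stays BF16
        grad_w = grad_out.t() @ x     # grad w.r.t. the dequantized weight, BF16
        # An optimizer would have to apply grad_w back onto the INT8 weight
        # (e.g. via a master copy or stochastic rounding) - not shown here.
        return grad_x, grad_w, None

# x: (batch, in) in BF16, w_int8: (out, in), w_scale: 0-dim fp32 scale tensor
# y = Int8WeightLinear.apply(x, w_int8, w_scale)
```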

gau-nernst commented 2 months ago

Some extra info for future reference

For evaluating the effectiveness of quantized training

Digging into AQT INT8 (update as I read more). Many things can be customized, but the basic config is:

gau-nernst commented 2 months ago

Found this interesting paper - Jetfire. ICML 2024 poster spotlight. With code release and custom CUDA kernels :open_mouth:

INT8 for everything, including activations and gradients. Tile-wise quantization, also using 127 for the scaling. Master weights are kept in FP32.
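A rough sketch of tile-wise INT8 quantization with 127-based scaling in that spirit (the 32x32 tile size and function name are assumptions; Jetfire's actual kernels are fused CUDA, this is just the math):

```python
import torch

def quantize_int8_tilewise(x: torch.Tensor, tile: int = 32):
    """Quantize a 2D tensor tile-by-tile: one scale per (tile x tile) block."""
    M, N = x.shape
    assert M % tile == 0 and N % tile == 0
    # (M, N) -> (M//tile, N//tile, tile, tile): group elements by tile
    xb = x.reshape(M // tile, tile, N // tile, tile).permute(0, 2, 1, 3)
    # One scale per tile, using 127 as in the paper description above
    scale = xb.abs().amax(dim=(-2, -1), keepdim=True).clamp(min=1e-12) / 127
    xq = torch.clamp(torch.round(xb / scale), -127, 127).to(torch.int8)
    # Undo the blocking to return an (M, N) int8 tensor plus per-tile scales
    xq = xq.permute(0, 2, 1, 3).reshape(M, N)
    return xq, scale.reshape(M // tile, N // tile)
```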

Which also led me to an earlier paper - SwitchBack. Tim Dettmers is one of the authors :laughing:.

Dynamic quantization for everything (the weight is still stored in high precision). Row-wise quant (i.e. along the batch dim - per token) for the activation (forward) and the grad output (backward); tensor-wise quant for the weight. INT8 matmul for the forward (Y = X @ W.T) and the input-grad backward (X_grad = Y_grad @ W), while the weight grad is an FP16 matmul (W_grad = Y_grad.T @ X).
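Here is a sketch of that mixed-precision scheme as a custom autograd function. The INT8 matmuls are emulated by multiplying the dequantized integer operands in floating point (a real implementation would dispatch to an actual INT8 kernel); all names are illustrative:

```python
import torch

def quant_rowwise(x):
    # One scale per row (batch/token dim), symmetric, 127-based
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / 127
    return torch.clamp(torch.round(x / scale), -127, 127).to(torch.int8), scale

def quant_tensorwise(x):
    scale = x.abs().amax().clamp(min=1e-12) / 127
    return torch.clamp(torch.round(x / scale), -127, 127).to(torch.int8), scale

def int8_matmul(a_int8, a_scale, b_int8, b_scale):
    # Stand-in for a real INT8 kernel: multiply in float and rescale
    return (a_int8.float() @ b_int8.float()) * (a_scale * b_scale)

class SwitchBackLinearFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(x, w)
        x_int8, x_scale = quant_rowwise(x)       # per-token activation quant
        w_int8, w_scale = quant_tensorwise(w)    # tensor-wise weight quant
        # Y = X @ W.T as an "INT8" matmul
        return int8_matmul(x_int8, x_scale, w_int8.t(), w_scale).to(x.dtype)

    @staticmethod
    def backward(ctx, grad_y):
        x, w = ctx.saved_tensors
        g_int8, g_scale = quant_rowwise(grad_y)  # row-wise grad-output quant
        w_int8, w_scale = quant_tensorwise(w)
        # X_grad = Y_grad @ W as an "INT8" matmul
        grad_x = int8_matmul(g_int8, g_scale, w_int8, w_scale).to(x.dtype)
        # W_grad = Y_grad.T @ X stays a high-precision (FP16/BF16) matmul
        grad_w = grad_y.t() @ x
        return grad_x, grad_w
```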

jerryzh168 commented 2 months ago

@gau-nernst thanks for the pointers. Feels like these are good motivations to enable training with AffineQuantizedTensor, since it will be general enough to support all kinds of quantization granularities (per block, row-wise, per token) and both dynamic quant and weight-only quant. cc @andrewor14
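To illustrate the generality argument: per-tensor, per-row/per-token and per-block quantization all fall out of a single block-size parameter. A plain-PyTorch sketch of that idea, assuming symmetric int8 quantization (this is not the AffineQuantizedTensor API):

```python
import torch

def quantize_int8_blockwise(x: torch.Tensor, block_size):
    """One scale per block; block_size controls the quantization granularity.

    For x of shape (M, N):
      block_size = (M, N)   -> per-tensor
      block_size = (1, N)   -> per row / per token
      block_size = (32, 32) -> per tile
    """
    shape = []
    for dim, blk in zip(x.shape, block_size):
        shape += [dim // blk, blk]
    xb = x.reshape(shape)
    reduce_dims = tuple(range(1, len(shape), 2))  # the intra-block dims
    scale = xb.abs().amax(dim=reduce_dims, keepdim=True).clamp(min=1e-12) / 127
    xq = torch.clamp(torch.round(xb / scale), -127, 127).to(torch.int8)
    return xq.reshape(x.shape), scale
```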