pytorch / ao

PyTorch native quantization and sparsity for training and inference

[QAT] Low-bit FSDP all-gather for QAT #1224

Open gau-nernst opened 1 day ago

gau-nernst commented 1 day ago

Had this idea and discussed briefly with @andrewor14.

Conceptually, the current QAT + FSDP flow is: all-gather the BF16 weight shards, then fake-quantize the full weight on each rank before the matmul.

However, we can do a low-bit all-gather instead, since the weight can be (really) quantized before the all-gather and dequantized afterwards.

In terms of perf, we are basically comparing the following (ignoring potential fusion around this):

  1. BF16 all-gather + fake quantize of the full weight
  2. (Real) quantize of the local shard (1/NGPU of the weight) + low-bit all-gather + dequantize

This might be a small perf win, especially when distributed communication is the bottleneck. Might be useful for QAT recipes in torchtune.
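To make the comparison above concrete, here is a rough single-process sketch. There is no actual distributed comm: the all-gather is stood in by a `repeat`, and `quantize_int8` / `dequantize_int8` / `fake_quantize_int8` are hypothetical per-tensor int8 helpers written for this example, not torchao APIs.

```python
import torch
import torch.nn.functional as F

# Hypothetical helpers (per-tensor symmetric int8) just to make the comparison concrete.
def quantize_int8(w: torch.Tensor):
    scale = w.abs().amax().clamp(min=1e-12) / 127.0
    return torch.clamp(torch.round(w / scale), -128, 127).to(torch.int8), scale

def dequantize_int8(wq: torch.Tensor, scale: torch.Tensor, dtype=torch.float32):
    return wq.to(dtype) * scale

def fake_quantize_int8(w: torch.Tensor):
    wq, scale = quantize_int8(w)
    return dequantize_int8(wq, scale, w.dtype)

NGPU = 4
x = torch.randn(16, 1024)
w_shard = torch.randn(256, 1024)  # this rank's 1/NGPU slice of the weight

# Option 1: BF16 all-gather, then fake-quantize the full weight on every rank.
w_full = w_shard.repeat(NGPU, 1)  # stand-in for all_gather (all shards identical here)
out1 = F.linear(x, fake_quantize_int8(w_full))

# Option 2: (real) quantize the local shard, all-gather int8, then dequantize.
wq_shard, scale = quantize_int8(w_shard)  # quantization cost is 1/NGPU of option 1
wq_full = wq_shard.repeat(NGPU, 1)        # the all-gathered payload is several times smaller
out2 = F.linear(x, dequantize_int8(wq_full, scale))

# With a shared scale the two paths match numerically; in a real run the scale
# (or amax) would have to be synced across ranks before quantizing.
torch.testing.assert_close(out1, out2)
```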

This is probably low priority, so I'm just leaving it here in case anyone is interested in implementing it. We would need to quantify the speedup, if any.

In terms of implementation, we can follow float8 design (https://github.com/pytorch/ao/blob/000a49026459dd1dadf5ca34322d98e7b1680250/torchao/float8/fsdp_utils.py)
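For reference, a rough sketch of what that could look like with the FSDP2 tensor-subclass hooks the float8 code uses (`fsdp_pre_all_gather` / `fsdp_post_all_gather`). `Int8QATShardedWeight` and the int8 scheme are hypothetical, and the subclass plumbing the real float8 code has (`__torch_dispatch__`, flatten/unflatten, the pre-allocated `out=` path) is omitted:

```python
import torch

class Int8QATShardedWeight(torch.Tensor):
    """Hypothetical sketch: quantize the local shard to int8 before the FSDP2
    all-gather and dequantize after, mirroring torchao/float8/fsdp_utils.py."""

    @staticmethod
    def __new__(cls, weight: torch.Tensor):
        return torch.Tensor._make_wrapper_subclass(
            cls, weight.shape, dtype=weight.dtype, device=weight.device,
            requires_grad=weight.requires_grad,
        )

    def __init__(self, weight: torch.Tensor):
        self._weight = weight

    # Called by FSDP2 on the local 1/NGPU shard before the all-gather.
    def fsdp_pre_all_gather(self, mesh):
        # Per-tensor symmetric int8 as a placeholder scheme. The amax is reduced
        # across the mesh so every rank quantizes/dequantizes with the same scale.
        amax = self._weight.abs().amax()
        torch.distributed.all_reduce(
            amax, op=torch.distributed.ReduceOp.MAX, group=mesh.get_group()
        )
        scale = amax.clamp(min=1e-12) / 127.0
        int_data = torch.clamp(torch.round(self._weight / scale), -128, 127).to(torch.int8)
        # The int8 payload is what gets all-gathered; the scale rides along as metadata.
        return (int_data,), (scale,)

    # Called by FSDP2 with the gathered outputs on every rank.
    def fsdp_post_all_gather(self, all_gather_outputs, metadata, param_dtype, *, out=None):
        (int_data,) = all_gather_outputs
        (scale,) = metadata
        if out is not None:
            # A real implementation must handle the pre-allocated-output path;
            # skipped in this sketch.
            raise NotImplementedError
        # Dequantize to the compute dtype. A QAT linear that consumes int8 weights
        # directly could skip this step and keep the low-bit data.
        weight = int_data.to(param_dtype) * scale.to(param_dtype)
        return weight, (int_data,)
```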

vkuzo commented 1 day ago

This would chain nicely with also doing the matrix multiply in low precision.