gau-nernst opened 1 day ago
Had this idea and discussed briefly with @andrewor14.
Conceptually, the current QAT + FSDP flow looks like this:
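A minimal single-process sketch of that ordering, with a hypothetical `fake_quantize_int8` helper standing in for torchao's fake-quant op (the per-tensor int8 scheme is just for illustration):

```python
import torch

def fake_quantize_int8(w: torch.Tensor) -> torch.Tensor:
    # per-tensor symmetric int8 fake-quant: quantize then immediately
    # dequantize, so values snap to the int8 grid but stay in high precision
    scale = w.abs().amax().clamp(min=1e-12) / 127.0
    return (w / scale).round().clamp(-128, 127) * scale

# current ordering: all-gather first, fake-quantize after
shards = [torch.randn(4, 8) for _ in range(2)]  # one high-precision shard per rank
w_full = torch.cat(shards)          # FSDP all-gather moves 2 bytes/elt at bf16
w_fq = fake_quantize_int8(w_full)   # QAT fake-quant on the gathered weight
y = torch.randn(3, 8) @ w_fq.t()    # matmul still in high precision
```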
However, we can do a low-bit all-gather instead, since the weight can be quantized before the all-gather:
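Reordering the same sketch (reusing `shards` from above; `quantize_int8` is again a hypothetical helper):

```python
def quantize_int8(w: torch.Tensor):
    # real int8 quantization: returns int8 data plus its scale
    scale = w.abs().amax().clamp(min=1e-12) / 127.0
    return (w / scale).round().clamp(-128, 127).to(torch.int8), scale

# proposed ordering: quantize each local shard first, all-gather in int8
q_shards = [quantize_int8(s) for s in shards]
int8_full = torch.cat([q for q, _ in q_shards])  # collective moves 1 byte/elt
# dequantize after the gather, shard by shard with each shard's own scale
w_deq = torch.cat([q.float() * s for q, s in q_shards])
y = torch.randn(3, 8) @ w_deq.t()
```

One subtlety: quantizing per shard uses shard-local statistics, so a per-tensor scheme won't bit-match the current flow; a per-row (channel-wise) scheme with row-sharded weights should.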
In terms of perf, we are basically comparing (ignoring potential fusion surrounding this):
- current: all-gather in bf16, then fake-quantize the full weight on every rank
- proposed: quantize the local shard, all-gather in low-bit (1 byte/elt for int8, 0.5 for packed int4, vs 2 for bf16), then dequantize after the gather
This might be a small perf win, especially when distributed communication is the bottleneck. It might also be useful for the QAT recipes in torchtune.
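To make the bandwidth saving concrete, a rough back-of-envelope (the 7B size is just an example; scales add negligible extra bytes):

```python
# bytes moved per full weight all-gather, for a hypothetical 7B-param model
params = 7e9
print(f"bf16:          {params * 2.0 / 1e9:.1f} GB")  # 14.0 GB
print(f"int8:          {params * 1.0 / 1e9:.1f} GB")  #  7.0 GB
print(f"int4 (packed): {params * 0.5 / 1e9:.1f} GB")  #  3.5 GB
```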
This is probably low priority, so I'm just leaving it here in case anyone is interested in implementing it. The speedup, if any, still needs to be quantified.
In terms of implementation, we can follow the float8 design (https://github.com/pytorch/ao/blob/000a49026459dd1dadf5ca34322d98e7b1680250/torchao/float8/fsdp_utils.py).
This would chain nicely with also doing the matrix multiply in low precision.
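A minimal sketch of that direction, loosely following `WeightWithDynamicFloat8CastTensor` from the linked fsdp_utils.py: a weight subclass implementing FSDP2's `fsdp_pre_all_gather` / `fsdp_post_all_gather` extension hooks. The class name and the per-tensor int8 scheme are assumptions, and the hook signatures have varied across PyTorch versions:

```python
import torch

class Int8AllGatherWeight(torch.Tensor):
    """Hypothetical weight subclass: quantize to int8 before FSDP2's
    all-gather, dequantize after, mirroring the float8 fsdp_utils design."""

    @staticmethod
    def __new__(cls, weight: torch.Tensor):
        # plain subclass sharing storage with `weight` (the float8 version
        # uses a wrapper subclass; this is simplified for the sketch)
        return torch.Tensor._make_subclass(cls, weight, weight.requires_grad)

    # FSDP2 extension hook: called on each rank's local shard before the
    # collective, so the all-gather moves int8 instead of bf16
    def fsdp_pre_all_gather(self, mesh):
        scale = self.detach().abs().amax().clamp(min=1e-12) / 127.0
        int8_shard = (self.detach() / scale).round().clamp(-128, 127).to(torch.int8)
        return (int8_shard,), (scale,)

    # FSDP2 extension hook: called with the gathered int8 data; dequantizing
    # lands the weight on the quantized grid, which is what QAT's fake-quant
    # produces anyway
    def fsdp_post_all_gather(self, all_gather_outputs, metadata, param_dtype, *, out=None):
        (int8_data,) = all_gather_outputs
        (scale,) = metadata
        if out is not None:
            out.copy_(int8_data.to(param_dtype) * scale)
            return
        weight = int8_data.to(param_dtype) * scale
        # also return the gathered inner tensors so FSDP can free them
        return weight, (int8_data,)
```

Following the float8 path further, `fsdp_post_all_gather` could instead return a quantized tensor subclass directly (the float8 code returns a `Float8Tensor`), which is where the low-precision matmul chaining above would come in.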