Open gau-nernst opened 1 month ago
cc: @weifengpy
@gau-nernst nice work! I took a look at the original torchao PR
@weifengpy
I am mostly curious about gaps for checkpointing, if any
What do you mean by this?
Yeah, this ("row-wise becomes column-wise in the backward") is the main problem preventing me from implementing INT8 all-gather.
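To spell out why the axis flips, here is a minimal sketch with made-up shapes (illustrative only, not torchao code): for a per-channel scale to factor out of an INT8 matmul, it must be constant along the reduction dimension, and the forward and backward matmuls reduce over different dimensions of the weight.

```python
import torch

x  = torch.randn(32, 128)   # activations: [batch, in]
w  = torch.randn(256, 128)  # weight:      [out, in]
gy = torch.randn(32, 256)   # grad_output: [batch, out]

# Forward y = x @ w.T reduces over "in" -> one scale per ROW of w (per output channel).
row_scale = w.abs().amax(dim=1, keepdim=True) / 127.0   # [out, 1]
w_row = (w / row_scale).round().clamp(-127, 127)         # INT8 values (kept float here so the demo matmul runs)
y = (x @ w_row.T) * row_scale.T                           # scale factors out after the matmul

# Backward grad_x = gy @ w reduces over "out" -> one scale per COLUMN of w (per input channel).
col_scale = w.abs().amax(dim=0, keepdim=True) / 127.0     # [1, in]
w_col = (w / col_scale).round().clamp(-127, 127)
grad_x = (gy @ w_col) * col_scale
```

So a row-wise INT8 weight all-gathered for the forward cannot be fed directly to the backward matmul; it needs scales along the other axis.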
Some extra thoughts.
thanks for explaining everything in detail
What do you mean by this?
I thought model.parameters() might be some INT8-related tensor subclass
Right now all-gather is still in BF16, so quantization will be done after all-gather
yeah. INT8 all-gather might be the main justification to land this in torchtitan, since this repo is used to demonstrate composability with distributed APIs
for row-wise, if the backward is too hard, are you comfortable with supporting INT8 all-gather with fully_shard(reshard_after_forward=False)? In that case, we do not have an all-gather in the backward
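For reference, a minimal sketch of that setting, assuming PyTorch's FSDP2 fully_shard API, launched with torchrun with one CUDA device per rank (the toy model is illustrative):

```python
import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard

# With reshard_after_forward=False the unsharded (all-gathered) weights stay
# resident from the end of forward until backward, so there is no all-gather
# in the backward and hence no column-wise re-quantization step.
torch.distributed.init_process_group("nccl")
torch.cuda.set_device(torch.distributed.get_rank() % torch.cuda.device_count())

model = nn.Sequential(*[nn.Linear(1024, 1024) for _ in range(4)]).cuda()
for block in model:
    fully_shard(block, reshard_after_forward=False)
fully_shard(model, reshard_after_forward=False)
```

The trade-off is extra memory, since the unsharded weights are kept around between forward and backward.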
For pre-training, it might be possible to do INT8 tensor-wise scaling. If the numerics do not become too bad with tensor-wise scaling, it's a great demonstration for INT8 all-gather
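A minimal sketch of what tensor-wise scaling means here (illustrative, not the torchao API): one scale for the whole weight, so the same INT8 tensor plus a single scalar can be all-gathered once and used for both the forward and backward matmuls without re-quantization.

```python
import torch

def tensorwise_quant(w: torch.Tensor):
    # a single scale for the whole tensor is valid along any reduction dimension
    scale = w.abs().amax() / 127.0
    q = (w / scale).round().clamp(-127, 127).to(torch.int8)
    return q, scale
```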
I thought model.parameters() might be some INT8 related tensor subclass
Oh yeah, right now I don't have any special logic for it. So the state_dict will be a tensor subclass wrapper holding the original high-precision weight (NOT INT8). For INT8 mixed-precision training, I only inject custom matmul logic; the weights stay the same (same as FP8 training).
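Roughly the shape of that idea, as a minimal sketch (not the actual torchao implementation; autograd, scaling details, and FSDP hooks are omitted): the subclass keeps the original high-precision weight and only swaps in dynamic INT8 quantization when a linear is hit.

```python
import torch
import torch.nn.functional as F
from torch.utils._pytree import tree_map

class Int8MixedPrecisionWeight(torch.Tensor):
    @staticmethod
    def __new__(cls, data):
        return torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=data.dtype, device=data.device
        )

    def __init__(self, data):
        self._data = data  # original BF16/FP32 weight, never INT8

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is F.linear:
            x, w, *rest = args
            bias = rest[0] if rest else None
            # dynamic row-wise quantization of the weight at compute time
            scale = w._data.abs().amax(dim=1, keepdim=True) / 127.0
            w_q = (w._data / scale).round().clamp(-127, 127)
            out = (x @ w_q.T) * scale.T  # a real kernel would use INT8 operands
            return out if bias is None else out + bias
        # everything else (state_dict, optimizer step, ...) sees the plain weight
        unwrap = lambda t: t._data if isinstance(t, cls) else t
        with torch._C.DisableTorchFunctionSubclass():
            return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
```

For example, `F.linear(torch.randn(4, 128), Int8MixedPrecisionWeight(torch.randn(256, 128)))` routes through the INT8 path, while any other op sees the plain high-precision weight.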
supporting INT8 all-gather with fully_shard(reshard_after_forward=False)? In that case, we do not have an all-gather in the backward
Does that mean the INT8 post-all-gather weights remain in memory from the forward until the backward? If that's the case, we can just do what I suggested earlier:
use row-wise quantized weight (from forward) for column-wise scaling in backward (i.e. dequant and re-quant)
More concretely:
| Pass | Original | Suggested change |
|---|---|---|
| Forward | FP32 weight -> all-gather -> row-wise quantize to INT8 | FP32 weight -> row-wise quantize to INT8 -> all-gather |
| Backward | FP32 weight -> all-gather -> column-wise quantize to INT8 | FP32 weight -> row-wise quantize to INT8 -> all-gather -> dequant -> column-wise quantize to INT8 |
In other words, the difference is which version of the weight is used for column-wise quantization in the backward: the original high-precision weight, or the row-wise quantized weight already used in the forward.
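A small sketch of the suggested backward path (illustrative shapes and helpers, not torchao code): reuse the row-wise INT8 weight that was already all-gathered for the forward, dequantize it, and re-quantize it column-wise.

```python
import torch

def rowwise_quant(w):
    scale = w.abs().amax(dim=1, keepdim=True) / 127.0    # one scale per row
    return (w / scale).round().clamp(-127, 127).to(torch.int8), scale

def colwise_requant(q_row, row_scale):
    w_approx = q_row.float() * row_scale                  # dequant (already lossy from the forward)
    col_scale = w_approx.abs().amax(dim=0, keepdim=True) / 127.0
    q_col = (w_approx / col_scale).round().clamp(-127, 127).to(torch.int8)
    return q_col, col_scale

w = torch.randn(256, 128)
q_row, s_row = rowwise_quant(w)               # quantize before all-gather (forward)
q_col, s_col = colwise_requant(q_row, s_row)  # backward: dequant + column-wise re-quant
```

The only extra quantization error versus the original scheme is the one introduced by going through the row-wise INT8 representation first.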
Otherwise, to just demonstrate INT8 all-gather, I think it is easier (and saves effort) to do INT8 tensor-wise scaling 🤣.
Otherwise, to just demonstrate INT8 all-gather, I think it is easier (and saves effort) to do INT8 tensor-wise scaling
agree, having tensor-wise scaling is already a good thing. I will bring this topic up for discussion with the team
I think long term it would be great to unify training APIs in torchao, to enable torchtitan to work with float8/int8/mx/etc. training in the same way. I'm working on this, no ETA yet.
Short term if someone wants to add int8 training to torchtitan as an experimental feature - SGTM personally, but I'll also defer to torchtitan folks on that.
I will bring this topic up for discussion with the team
we would love to have this feature after discussion. we can start with tensor-wise scaling. it's also consistent with our float8 offering
@mori360
Recently I worked on INT8 mixed-precision training in torchao. The relevant PR is here https://github.com/pytorch/ao/pull/748
Preliminary results show that with torchtitan, it improves speed by 20% on 8x A100 with no noticeable difference in loss curve. See the PR for more details.
Would you be open to adding an experimental flag for this in torchtitan, similar to Float8 training? This can also help profile and improve INT8 training performance directly in torchtitan for future perf optimization.
cc @msaroufim