pytorch / torchtitan

A native PyTorch Library for large model training

Support INT8 mixed-precision training from torchao? #578

Open gau-nernst opened 1 month ago

gau-nernst commented 1 month ago

Recently I worked on INT8 mixed-precision training in torchao. The relevant PR is here https://github.com/pytorch/ao/pull/748

Preliminary results show that with torchtitan, it improves training speed by 20% on 8x A100 with no noticeable difference in the loss curve. See the PR for more details.

Would you be open to adding an experimental flag for this in torchtitan, similar to Float8 training? It would also help to profile and improve INT8 training performance directly in torchtitan for future perf optimization.
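
For context, a rough sketch of how such a flag could hook into model setup, assuming the `int8_mixed_precision_training` config that the torchao PR adds under `torchao.prototype.quantized_training` (the flag name and helper below are hypothetical, not existing torchtitan options):

```python
import torch.nn as nn

from torchao.quantization import quantize_
# Prototype API from pytorch/ao#748; module path taken from the PR, exact config name is an assumption.
from torchao.prototype.quantized_training import int8_mixed_precision_training


def maybe_apply_int8_mixed_precision(model: nn.Module, enabled: bool) -> nn.Module:
    # Swap the nn.Linear matmuls for INT8 mixed-precision ones. Weights keep their
    # original dtype; only the matmul path changes, mirroring how the Float8
    # conversion is applied to the model before parallelizing/compiling it.
    if enabled:
        quantize_(model, int8_mixed_precision_training())
    return model
```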

cc @msaroufim

tianyu-l commented 3 weeks ago

cc: @weifengpy

weifengpy commented 3 weeks ago

@gau-nernst nice work! I took a look at the original torchao PR

gau-nernst commented 3 weeks ago

@weifengpy

  1. There is an ongoing PR https://github.com/pytorch/torchtune/pull/1552

I am mostly curious about gaps for checkpointing if any

What do you mean by this?

  2. Right now all-gather is still in BF16, so quantization will be done after all-gather. https://github.com/pytorch/ao/blob/ae3e7c68eae7085e13241cb3d6b39481868dd162/torchao/prototype/quantized_training/int8_mixed_precision.py#L117-L126.

Yeah, this ("rowwise became column wise in the backward") is the main problem preventing me from implementing INT8 all-gather (see the sketch below).

  3. Yes, speed is the main reason. The original PR description includes the revamped README for INT8 training (which also covers INT8 weight-only quantized training, i.e. keeping only the INT8 weight), hence memory is mentioned there.
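
To make the row-wise vs column-wise issue in point 2 concrete, here is a toy illustration with simple symmetric INT8 quantization (not the actual torchao kernels):

```python
import torch


def quantize_rowwise(w: torch.Tensor):
    # One scale per row, i.e. per output feature of the nn.Linear weight.
    scale = w.abs().amax(dim=1, keepdim=True) / 127
    return torch.clamp(torch.round(w / scale), -128, 127).to(torch.int8), scale


def quantize_colwise(w: torch.Tensor):
    # One scale per column, which is what the backward matmul needs because
    # grad_input = grad_output @ W contracts over the other dimension of W.
    scale = w.abs().amax(dim=0, keepdim=True) / 127
    return torch.clamp(torch.round(w / scale), -128, 127).to(torch.int8), scale


w = torch.randn(4096, 1024)             # nn.Linear weight: (out_features, in_features)
w_i8_fwd, s_fwd = quantize_rowwise(w)   # forward:  y = x @ w.T
w_i8_bwd, s_bwd = quantize_colwise(w)   # backward: grad_x = grad_y @ w
# The two INT8 copies differ, so an INT8 all-gather of the forward (row-wise)
# copy cannot simply be reused in the backward.
```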

Some extra thoughts.

weifengpy commented 2 weeks ago

thanks for explaining everything in detail

What do you mean by this?

I thought model.parameters() might be some INT8-related tensor subclass

Right now all-gather is still in BF16, so quantization will be done after all-gather

Yeah. INT8 all-gather might be the main justification for landing this in torchtitan, since this repo is used to demonstrate composability with the distributed APIs.

For row-wise, if the backward is too hard, are you comfortable with supporting INT8 all-gather with fully_shard(reshard_after_forward=False)? In that case, we do not have an all-gather in the backward (see the sketch at the end of this comment).

For pre-training, it might be possible to do INT8 tensor-wise scaling

If the numerics do not become too bad with tensor-wise scaling, it's a great demonstration of INT8 all-gather.
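
A minimal sketch of the fully_shard(reshard_after_forward=False) setup mentioned above, assuming the FSDP2 per-parameter API currently under torch.distributed._composable.fsdp (the import path may change across PyTorch versions):

```python
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard  # FSDP2; path is an assumption

# Assumes torch.distributed (and the default device mesh) has already been initialized.
model = nn.Sequential(nn.Linear(1024, 4096), nn.Linear(4096, 1024))

# With reshard_after_forward=False, the unsharded weights produced by the forward
# all-gather stay alive until backward, so backward issues no second all-gather and
# the row-wise vs column-wise quantization mismatch does not come up there.
fully_shard(model, reshard_after_forward=False)
```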

gau-nernst commented 2 weeks ago

I thought model.parameters() might be some INT8 related tensor subclass

Oh yeah, right now I don't have any special logic for it. So the state_dict will be a tensor subclass wrapper holding the original high-precision weight (NOT INT8). For INT8 mixed-precision training, I only inject custom matmul logic; the weights stay the same (same as FP8 training).

supporting INT8 all-gather with fully_shard(reshard_after_forward=False) ? In that case, we do not have all-gather in the backward

Does that mean the post-all-gather INT8 weights remain in memory from forward until backward? If that's the case, can we just do what I suggested earlier?

use row-wise quantized weight (from forward) for column-wise scaling in backward (i.e. dequant and re-quant)

More concretely:

| Pass | Original | Suggested change |
| --- | --- | --- |
| Forward | FP32 weight -> all-gather -> row-wise quantize to INT8 | FP32 weight -> row-wise quantize to INT8 -> all-gather |
| Backward | FP32 weight -> all-gather -> column-wise quantize to INT8 | FP32 weight -> row-wise quantize to INT8 -> all-gather -> dequant -> column-wise quantize to INT8 |

In other words, the difference is which version of the weight is used for column-wise quantization in the backward: the original weight, or the row-wise quantized weight from the forward.

Otherwise, to just demonstrate INT8 all-gather, I think it is easier (and saves effort) to do INT8 tensor-wise scaling 🤣.
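
A toy sketch of the "Suggested change" column for the backward, using the same simple symmetric scaling as before (the actual all-gather of the sharded INT8 weight is elided; only the dequant/re-quant step is shown):

```python
import torch

w = torch.randn(4096, 1024)  # nn.Linear weight: (out_features, in_features)

# Forward path of the suggested change: row-wise quantize, then (INT8) all-gather.
row_scale = w.abs().amax(dim=1, keepdim=True) / 127
w_i8 = torch.clamp(torch.round(w / row_scale), -128, 127).to(torch.int8)
# ... INT8 all-gather of (w_i8, row_scale) would happen here ...

# Backward path: dequantize the row-wise INT8 weight, then re-quantize column-wise.
w_deq = w_i8.to(torch.float32) * row_scale
col_scale = w_deq.abs().amax(dim=0, keepdim=True) / 127
w_i8_col = torch.clamp(torch.round(w_deq / col_scale), -128, 127).to(torch.int8)
# grad_x = grad_y @ W now uses (w_i8_col, col_scale); the extra rounding error comes
# from re-quantizing an already-quantized weight instead of the original FP32 weight.
```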

weifengpy commented 2 weeks ago

Otherwise, to just demonstrate INT8 all-gather, I think it is easier (and saves effort) to do INT8 tensor-wise scaling

Agree, having tensor-wise scaling is already a good thing. I will bring this topic up for discussion with the team.

vkuzo commented 2 weeks ago

I think long term it would be great to unify the training APIs in torchao, so that torchtitan can work with float8/int8/mx/etc. training in the same way. I'm working on this, no ETA yet.

Short term, if someone wants to add INT8 training to torchtitan as an experimental feature - SGTM personally, but I'll also defer to the torchtitan folks on that.

weifengpy commented 2 weeks ago

I will bring this topic up for discussion with the team

After discussion, we would love to have this feature. We can start with tensor-wise scaling; it's also consistent with our float8 offering.
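
For reference, a toy sketch of why tensor-wise scaling is all-gather friendly (a single scalar scale factors out of any matmul, so the same INT8 copy serves both forward and backward):

```python
import torch

w = torch.randn(4096, 1024)
scale = w.abs().amax() / 127  # one scale for the whole tensor
w_i8 = torch.clamp(torch.round(w / scale), -128, 127).to(torch.int8)
# The same (w_i8, scale) pair works for both y = x @ w.T in forward and
# grad_x = grad_y @ w in backward, so the INT8 shards can be all-gathered once
# and reused, consistent with the tensor-wise float8 recipe mentioned above.
```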

@mori360