pytorch / ao

PyTorch native quantization and sparsity for training and inference

we should ensure activation checkpointing with Float8Linear behaves optimally #893

Open · vkuzo opened this issue 5 days ago

vkuzo commented 5 days ago

When activation checkpointing (AC) is on for Float8Linear, what I would expect is:

  1. the forward gemm is recomputed in the backward (it is not being recomputed now)
  2. max(abs(activation)) and max(abs(weight)) are NOT recomputed; since they are tiny scalar values, it's much better to always save and reuse them (it seems one of these is being recomputed now)

Let's figure out why this isn't happening now and what we should do about it. Note: reproductions below require https://github.com/pytorch/ao/pull/892
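
For context, a minimal sketch of the setup under discussion: a Float8Linear (produced by torchao's convert_to_float8_training) whose forward is wrapped in plain activation checkpointing. The sizes and wiring here are illustrative only; the actual benchmark script drives this differently.

```python
# Sketch (not the benchmark code): wrap a Float8Linear forward in
# activation checkpointing so intermediates are recomputed in the backward.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from torchao.float8 import convert_to_float8_training

model = nn.Sequential(nn.Linear(4096, 4096, bias=False)).cuda().bfloat16()
convert_to_float8_training(model)  # swaps nn.Linear -> Float8Linear in place

x = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)

# with use_reentrant=False, activations inside `model` are discarded in the
# forward and recomputed during the backward pass
y = checkpoint(model, x, use_reentrant=False)
y.sum().backward()
```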

bfloat16 linear fwd/bwd with activation checkpointing on

repro command

python benchmarks/float8/profile_linear_float8.py ~/local/tmp/20240916_act_chk_on --dtype_filter bfloat16 --enable_activation_checkpointing True

trace snippet

[profiler trace screenshot, 2024-09-16 2:50 PM]

we see 1 gemm in the forward and 3 in the backward (the recomputed forward gemm plus the grad_input and grad_weight gemms), as expected

Float8Linear fwd/bwd with activation checkpointing on

repro command

python benchmarks/float8/profile_linear_float8.py ~/local/tmp/20240916_act_chk_on --dtype_filter float8 --enable_activation_checkpointing True

trace snippet

[profiler trace screenshot, 2024-09-16 3:05 PM]

issue 1: there are only two gemms in the backward instead of three

issue 2: there are some extra kernels in the backward which are recomputing max(abs(activation)) and max(abs(weight))
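
For reference, the quantity being recomputed in issue 2 is roughly the per-tensor amax/scale below. This is a simplified sketch of dynamic float8 scaling, not the exact torchao code; the point is that its output is a single scalar per tensor, so saving it across AC should be essentially free compared to re-running a full-tensor reduction.

```python
import torch

def fp8_scale(t: torch.Tensor, float8_dtype=torch.float8_e4m3fn, eps=1e-12):
    # full-tensor reduction: this is the max(abs(...)) work showing up as
    # extra kernels in the backward trace when it gets recomputed under AC
    amax = torch.max(torch.abs(t))
    # scale so the largest value maps near the float8 dtype's max
    return torch.finfo(float8_dtype).max / torch.clamp(amax.float(), min=eps)
```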

vkuzo commented 4 days ago

The torch._scaled_mm behavior seems fine. The max(abs(tensor)) behavior seems suboptimal, and we can do better with custom AC settings. I wrote up https://github.com/pytorch/torchtitan/pull/580 with initial findings; I will follow up after the conferences this week with more.
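
One possible shape for those custom AC settings is selective checkpointing via torch.utils.checkpoint, in the spirit of the torchtitan PR above. The op list and policy below are illustrative assumptions, not the final settings: the idea is to keep recomputing the gemms while saving the tiny max(abs(...)) reduction results.

```python
# Sketch: selective activation checkpointing that prefers to recompute
# everything (including gemms) but always saves scalar max reductions.
from functools import partial

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    create_selective_checkpoint_contexts,
)

# ops whose forward outputs we want to keep (amax results are scalars)
_save_list = {
    torch.ops.aten.max.default,
}

def _policy(ctx, op, *args, **kwargs):
    return (
        CheckpointPolicy.MUST_SAVE
        if op in _save_list
        else CheckpointPolicy.PREFER_RECOMPUTE
    )

context_fn = partial(create_selective_checkpoint_contexts, _policy)

# usage:
# out = torch.utils.checkpoint.checkpoint(
#     float8_module, x, use_reentrant=False, context_fn=context_fn
# )
```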