vkuzo opened this issue 5 days ago
Summary of current findings:

- the torch._scaled_mm behavior seems fine
- the max(abs(tensor)) behavior seems suboptimal, and we can do better with custom AC settings. I wrote up https://github.com/pytorch/torchtitan/pull/580 with initial findings and will follow up after the conferences this week with more.
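For context, both bullets refer to the extra work Float8Linear does around each gemm: a max(abs(tensor)) (amax) reduction per input to compute a scale, and a torch._scaled_mm call that consumes the float8 data plus scales. Below is a minimal sketch of that pattern with illustrative names; it is not the torchao implementation, and it assumes a recent PyTorch (2.4+) and a GPU with float8 support (e.g. H100).

```python
import torch

def to_float8(t: torch.Tensor, float8_dtype=torch.float8_e4m3fn):
    # the max(abs(tensor)) reduction discussed above; with AC this is the
    # work we would prefer not to redo in the backward
    amax = t.abs().max()
    scale = torch.finfo(float8_dtype).max / torch.clamp(amax.float(), min=1e-12)
    t_fp8 = (t.float() * scale).clamp(
        min=torch.finfo(float8_dtype).min, max=torch.finfo(float8_dtype).max
    ).to(float8_dtype)
    return t_fp8, scale

x = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)  # activation
w = torch.randn(512, 256, device="cuda", dtype=torch.bfloat16)  # weight

x_fp8, x_scale = to_float8(x)
w_fp8, w_scale = to_float8(w)

# the float8 gemm; mat2 must be column-major, hence .t() on a row-major tensor
out = torch._scaled_mm(
    x_fp8,
    w_fp8.t(),
    scale_a=x_scale.reciprocal(),
    scale_b=w_scale.reciprocal(),
    out_dtype=torch.bfloat16,
)
```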
When AC is on for Float8Linear, what I would expect is:

- the same gemm pattern as the bfloat16 baseline below: 1 gemm in the forward and 3 in the backward
- no extra kernels in the backward recomputing max(abs(activation)) or max(abs(weight))

Let's figure out why this isn't what is happening now and what we should do about it. Note: the reproductions below require https://github.com/pytorch/ao/pull/892.
bfloat16 linear fwd/bwd with activation checkpointing on
repro command
trace snippet
we see 1 gemm in the forward and 3 in the backward, as expected
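For reference, the same gemm count can be observed standalone with a single bf16 linear wrapped in non-reentrant AC and a profiler run; this is an illustrative sketch, not the torchtitan repro command above.

```python
import torch
from torch.utils.checkpoint import checkpoint

# toy stand-in for the repro: one bf16 linear, full activation checkpointing
lin = torch.nn.Linear(4096, 4096, bias=False, device="cuda", dtype=torch.bfloat16)
x = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)

with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ]
) as prof:
    # non-reentrant AC recomputes the checkpointed forward inside the backward
    out = checkpoint(lin, x, use_reentrant=False)
    out.sum().backward()

# aten::mm should appear 4 times total: 1 in the forward region and
# 3 in the backward region (recomputed fwd gemm + dgrad + wgrad)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
prof.export_chrome_trace("bf16_linear_ac.json")
```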
Float8Linear fwd/bwd with activation checkpointing on
repro command
trace snippet
issue 1: there are only two gemms in the backward instead of three
issue 2: there are some extra kernels in the backward which are recomputing max(abs(activation)) and max(abs(weight))
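One possible direction for issue 2, in the spirit of the custom AC settings mentioned at the top (and what https://github.com/pytorch/torchtitan/pull/580 explores), is per-op selective activation checkpointing that forces the amax-related ops to be saved instead of recomputed. The sketch below assumes the selective-checkpoint API in torch.utils.checkpoint from recent PyTorch; the op list is illustrative, not the exact torchtitan policy.

```python
import torch
from functools import partial
from torch.utils.checkpoint import (
    checkpoint,
    CheckpointPolicy,
    create_selective_checkpoint_contexts,
)

# ops whose outputs we ask AC to save rather than recompute;
# illustrative list, not the torchtitan policy verbatim
_save_list = {
    torch.ops.aten.mm.default,
    torch.ops.aten._scaled_mm.default,
    # saving abs/max results avoids recomputing max(abs(activation))
    # and max(abs(weight)) during the backward
    torch.ops.aten.abs.default,
    torch.ops.aten.max.default,
}

def _policy(ctx, op, *args, **kwargs):
    if op in _save_list:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE

context_fn = partial(create_selective_checkpoint_contexts, _policy)

# usage: wrap the transformer block (bf16 or Float8Linear-based) with
# per-op selective AC instead of full AC
def run_block(block, x):
    return checkpoint(block, x, use_reentrant=False, context_fn=context_fn)
```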