pytorch / ao

PyTorch native quantization and sparsity for training and inference
BSD 3-Clause "New" or "Revised" License
1.6k stars, 179 forks

float8 with delayed scaling: fix autocast handling #1306

Closed: vkuzo closed this 6 days ago

vkuzo commented 1 week ago

Summary:

Fixes a bug with delayed scaling + autocast.

Before this fix, when running under autocast, the last input dtype was recorded from the high-precision input x_hp, before autocast was applied:

x_hp -> {query_dtype_here} -> to_autocast -> torch._scaled_mm

This is incorrect because the dtype was recorded before autocast had a chance to change it. It happened to work when x_hp already had the autocast dtype, but it failed in other cases, such as the new test case added in this PR and real models like the repro from https://github.com/pytorch/ao/issues/1297. This went unnoticed for so long because we have been using FSDP's mixed precision rather than single-GPU autocast.

The fix here is to query the post-autocast dtype from the output of torch._scaled_mm. Since that dtype is derived from the dtype the input to torch._scaled_mm had before the float8 cast, it properly captures autocasting:

x_hp -> to_autocast -> x_autocast_dtype -> to_fp8 -> x_fp8 -> torch._scaled_mm -> {query_dtype_here}
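The difference between the two query points can be illustrated with a toy, dependency-free model. This is a sketch under stated assumptions, not torchao's code: dtypes are plain strings, and `Tensor`, `to_autocast`, `to_fp8`, `scaled_mm`, `forward_before_fix`, and `forward_after_fix` are hypothetical stand-ins. It assumes, as the description above implies, that the matmul output takes the dtype its input had just before the float8 cast.

```python
from dataclasses import dataclass
from typing import Optional

# Toy dtype names (hypothetical stand-ins for torch dtypes).
FP32, BF16, FP8 = "float32", "bfloat16", "float8_e4m3fn"


@dataclass
class Tensor:
    dtype: str
    # For float8 tensors: the dtype the tensor had before the float8 cast.
    orig_dtype: Optional[str] = None


def to_autocast(x: Tensor, autocast_dtype: Optional[str]) -> Tensor:
    # Hypothetical autocast: cast to the autocast dtype if autocast is active.
    return Tensor(autocast_dtype) if autocast_dtype is not None else x


def to_fp8(x: Tensor) -> Tensor:
    # Remember the pre-cast (post-autocast) dtype on the float8 tensor.
    return Tensor(FP8, orig_dtype=x.dtype)


def scaled_mm(x_fp8: Tensor) -> Tensor:
    # Assumption: output dtype follows the input's pre-float8 dtype.
    return Tensor(x_fp8.orig_dtype)


def forward_before_fix(x_hp: Tensor, autocast_dtype: Optional[str]):
    recorded = x_hp.dtype  # queried before autocast, so it can be stale
    out = scaled_mm(to_fp8(to_autocast(x_hp, autocast_dtype)))
    return out, recorded


def forward_after_fix(x_hp: Tensor, autocast_dtype: Optional[str]):
    out = scaled_mm(to_fp8(to_autocast(x_hp, autocast_dtype)))
    recorded = out.dtype  # queried after the matmul, so it reflects autocast
    return out, recorded


if __name__ == "__main__":
    x_hp = Tensor(FP32)
    print(forward_before_fix(x_hp, BF16)[1])  # float32 (stale: matmul ran in bfloat16)
    print(forward_after_fix(x_hp, BF16)[1])   # bfloat16
    # The bug is hidden when x_hp already has the autocast dtype:
    print(forward_before_fix(Tensor(BF16), BF16)[1])  # bfloat16
```

The last call mirrors why the old behavior "happened to work": when the input already matches the autocast dtype, both query points agree.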

Test Plan:

// first, test the updated test case - it passes

// second - test a modified version of the repro in
// https://github.com/pytorch/ao/issues/1297:
// code: https://gist.github.com/vkuzo/6c53a1deca19856238d38746b1e52ee7
// logs: https://gist.github.com/vkuzo/60846b1f6b2822f91d2dfa67cab10a10
// we now see a speedup with float8


pytorch-bot[bot] commented 1 week ago


No Failures

As of commit 4402195e28d7bb940dc8448d0ae8c791a47c25bd with merge base 6234116a65a4c24d21fb7f5afd501a786cc474a8: Looks good so far! There are no failures yet.
