Closed vkuzo closed 6 days ago
Summary:
Fixes a bug with delayed scaling + autocast.

Before this PR, the last input dtype seen under autocast was queried from the input to `torch._scaled_mm`. This is incorrect because that dtype was saved before the point where autocast could change it. It happened to work when `x_hp` was already of the correct dtype, but failed in cases such as the new test case added in this PR, as well as in real models such as the repro from https://github.com/pytorch/ao/issues/1297. The reason we haven't caught this for so long is that we've been using FSDP's mixed precision rather than single-GPU autocast.

The fix taken here is to query the original post-autocast dtype from the output of `torch._scaled_mm`. Since that output dtype is derived from the dtype of the inputs to `torch._scaled_mm`, it properly captures autocasting.

Test Plan:
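To illustrate the class of bug being fixed, here is a minimal sketch (hypothetical, not code from this PR) showing that a dtype saved before entering an autocast region does not match the dtype of tensors flowing through the matmul inside it; the output dtype of the op is what reflects autocast:

```python
import torch

# High-precision input and weight, created outside autocast
x_hp = torch.randn(16, 16, dtype=torch.float32)
w = torch.randn(16, 16, dtype=torch.float32)

# dtype captured before autocast can change it -- this is the stale value
saved_dtype = x_hp.dtype  # torch.float32

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    # autocast casts the inputs of mm to bfloat16, so the output dtype
    # reflects the actual post-autocast dtype, unlike saved_dtype
    y = torch.mm(x_hp, w)

print(saved_dtype)  # torch.float32 (stale)
print(y.dtype)      # torch.bfloat16 (post-autocast)
```

Querying the op's output dtype, as this PR does for `torch._scaled_mm`, sidesteps the stale saved value.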
Reviewers:
Subscribers:
Tasks:
Tags: