weifengpy closed this pull request 4 months ago.
@weifengpy has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
@weifengpy merged this pull request in pytorch-labs/float8_experimental@7f0d6bbb531d5d76d27d80c9ec3c7eca61de5dfa.
TorchTitan complains about FSDP2 + float8 + torch.compile(transformer block).
There is a mismatch in the float8 scale's shape, so a dynamo guards assertion failed:
torch._C._dynamo.guards.assert_size_stride(new_inputs[3], (), ())
In cast_to_float8_e4m3_dynamic (code), the scale is a scalar tensor, e.g. tensor(4674.8633).
In precompute_float8_dynamic_scale, however, the scale is NOT a scalar tensor, e.g. tensor([[4674.8633]]).
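A minimal sketch of the mismatch, assuming the private helper torch._C._dynamo.guards.assert_size_stride behaves as it does in inductor-generated wrapper code (raising AssertionError when a tensor's size/stride does not match the expected tuples):

import torch

# The guard from the error above, accessed the same way inductor-generated code does.
assert_size_stride = torch._C._dynamo.guards.assert_size_stride

scalar_scale = torch.tensor(4674.8633)          # shape (),    as produced by the cast path
non_scalar_scale = torch.tensor([[4674.8633]])  # shape (1, 1), as produced by precompute

assert_size_stride(scalar_scale, (), ())        # passes: 0-dim tensor has size () and stride ()

try:
    assert_size_stride(non_scalar_scale, (), ())
except AssertionError as e:
    print("guard failed, as in the TorchTitan run:", e)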
This PR calls .squeeze to make sure the precomputed scales are always scalar tensors, so the dynamo guards assertion always holds true. Also added a unit test so we can catch the issue at PR time.
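A minimal sketch of the fix, as a hypothetical reconstruction of the two code paths rather than the library's actual code (E4M3_MAX, per_tensor_scale, and precomputed_scales are illustrative names):

import torch

E4M3_MAX = 448.0  # max normal value of float8_e4m3fn; used here only for illustration

def per_tensor_scale(w: torch.Tensor) -> torch.Tensor:
    # cast_to_float8_e4m3_dynamic-style: amax over the whole tensor is 0-dim,
    # so the resulting scale is a scalar tensor, e.g. tensor(4674.8633).
    return E4M3_MAX / w.abs().max()

def precomputed_scales(params: list[torch.Tensor]) -> list[torch.Tensor]:
    # precompute_float8_dynamic_scale-style (hypothetical reconstruction): per-parameter
    # amaxes are stacked into one batched tensor, so each chunked-out scale is 2-D,
    # e.g. tensor([[4674.8633]]). The fix is the trailing .squeeze() on each scale.
    amaxes = torch.stack([p.abs().max().reshape(1, 1) for p in params])  # (N, 1, 1)
    scales = E4M3_MAX / amaxes                                           # (N, 1, 1)
    return [s.squeeze() for s in torch.unbind(scales)]  # each becomes shape ()

params = [torch.randn(16, 16), torch.randn(32, 8)]
for p, s in zip(params, precomputed_scales(params)):
    # Both paths now produce scalar tensors, so the dynamo guard sees size () either way.
    assert s.shape == per_tensor_scale(p).shape == torch.Size([])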
TODO: add fp8 + torch.compile to CI in torchtitan