pytorch-labs / float8_experimental

This repository contains the experimental PyTorch native float8 training UX

add unit tests for FSDP2 + torch.compile(transformer block) #321

Closed: weifengpy closed this pull request 4 months ago

weifengpy commented 4 months ago

TorchTitan complains about FSDP2 + float8 + torch.compile(transformer block).

There is a mismatch in the float8 scale, so a dynamo guard assertion fails: torch._C._dynamo.guards.assert_size_stride(new_inputs[3], (), ())
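For reference, that assert_size_stride guard checks that the compiled region's fourth input (the float8 scale) is a 0-dim tensor with empty size and stride. A minimal illustration of the shape it enforces (the tensors below are made up for illustration, not taken from the failing run):

```python
import torch

# assert_size_stride(new_inputs[3], (), ()) only passes for a 0-dim (scalar) tensor:
scalar_scale = torch.tensor(1.0)    # size () / stride ()    -> matches the guard
vector_scale = torch.tensor([1.0])  # size (1,) / stride (1,) -> guard assertion fails
print(scalar_scale.size(), scalar_scale.stride())  # torch.Size([]) ()
print(vector_scale.size(), vector_scale.stride())  # torch.Size([1]) (1,)
```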

Added a unit test so we can catch the issue at PR time.

TODO: add fp8 + torch.compile to CI in torchtitan
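As a rough sketch of the wrapping order being exercised (torch.compile on each transformer block, then FSDP2's fully_shard per block and on the root), something like the snippet below. The toy TransformerBlock, the single-rank process group, and the commented-out float8 swap step are assumptions for illustration, not the actual unit test added in this PR:

```python
# Sketch of the FSDP2 + torch.compile(transformer block) wrapping pattern.
# Assumes a single CUDA device and PyTorch with FSDP2 (fully_shard) available.
import os
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed._composable.fsdp import fully_shard


class TransformerBlock(nn.Module):
    """Toy transformer block, illustrative only."""

    def __init__(self, dim: int = 256):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads=4, batch_first=True)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x):
        h = self.norm1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]
        return x + self.mlp(self.norm2(x))


def main():
    # Single-rank process group so fully_shard can be exercised locally.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("nccl", rank=0, world_size=1)
    torch.cuda.set_device(0)

    blocks = nn.Sequential(*[TransformerBlock() for _ in range(2)]).cuda()

    # In the real test, nn.Linear layers would be swapped to float8 here
    # via float8_experimental's swap utility before compiling/sharding.

    for block in blocks:
        block.compile()     # torch.compile applied per transformer block (in place)
        fully_shard(block)  # FSDP2 sharding applied per transformer block
    fully_shard(blocks)     # root wrap

    x = torch.randn(4, 16, 256, device="cuda")
    # A guard mismatch like the one above would surface on this forward/backward.
    blocks(x).sum().backward()

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```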

facebook-github-bot commented 4 months ago

@weifengpy has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

facebook-github-bot commented 4 months ago

@weifengpy merged this pull request in pytorch-labs/float8_experimental@7f0d6bbb531d5d76d27d80c9ec3c7eca61de5dfa.