Closed: vkuzo closed this pull request 1 month ago.
@vkuzo has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
This pull request has been merged in pytorch-labs/float8_experimental@5d293a723d8da1fb3dbaa63522b091cd2e3f3146.
Stack from ghstack (oldest at bottom):
- #278
- #276
Summary:
FSDP already ensures that each rank receives the same weights, so the weight amaxes are identical on every rank and do not need to be synced across ranks.
I checked performance before/after on the multi-GPU benchmark and did not see a significant impact on the toy model, but less communication volume is better regardless.
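The idea can be sketched as follows. This is a minimal illustration in plain Python rather than the actual float8_experimental code, and it stands in for `torch.distributed` collectives with simple functions; all names here are hypothetical. Since FSDP replicates the weights, a MAX all-reduce over weight amaxes returns each rank's own local value, so that collective can be skipped, while activation amaxes (which differ per rank) still need it.

```python
# Hypothetical sketch, not the real float8_experimental API.

def amax(values):
    # Stand-in for tensor.abs().max() on a flat list of floats.
    return max(abs(v) for v in values)

def all_reduce_max(per_rank_amaxes):
    # Stand-in for torch.distributed.all_reduce with ReduceOp.MAX.
    return max(per_rank_amaxes)

# Weights are the same on every rank under FSDP.
weights = [0.5, -2.0, 1.25]
ranks_weight_amax = [amax(weights) for _ in range(4)]

# Activations differ per rank (different local batches),
# so their amaxes genuinely need the all-reduce.
ranks_act_amax = [amax(a) for a in ([0.1, 3.0], [0.2, -4.0], [1.0], [-0.5])]

# The reduced weight amax equals any single rank's local value,
# so the communication can be skipped for weights.
assert all_reduce_max(ranks_weight_amax) == ranks_weight_amax[0]

# For activations the reduction still changes the result on some ranks.
print(all_reduce_max(ranks_act_amax))  # 4.0
```

The same reasoning does not apply to activation or gradient amaxes, which depend on each rank's local data, so those reductions are kept.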
Test Plan:
./test_everything.sh passes
Reviewers:
Subscribers:
Tasks:
Tags:
Differential Revision: D58396925