pytorch-labs / float8_experimental

This repository contains the experimental PyTorch native float8 training UX
BSD 3-Clause "New" or "Revised" License

delayed scaling: stop syncing weight amax values across ranks #277

Closed: vkuzo closed this pull request 1 month ago

vkuzo commented 1 month ago

Stack from ghstack (oldest at bottom):

Summary:

FSDP already ensures that each rank receives the same weight, so the weight amax values are identical on every rank and syncing them across ranks is redundant.
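To illustrate the reasoning, here is a minimal pure-Python simulation. It does not use `torch.distributed`, and every name in it (`amax`, `all_reduce_max`, `amax_to_scale`) is hypothetical, not a float8_experimental API. It assumes each rank computes the weight amax on the full, all-gathered weight. Under that assumption, a MAX all-reduce over the per-rank weight amaxes returns them unchanged. Activation amaxes depend on each rank's data shard, so reducing them does change the local values.

```python
# Pure-Python simulation of delayed-scaling amax syncing across ranks.
# All names are hypothetical; this is not the float8_experimental API.
from typing import List

E4M3_MAX = 448.0  # largest finite value of the float8 e4m3fn format


def amax(values: List[float]) -> float:
    """Absolute maximum of a (flattened) tensor."""
    return max(abs(v) for v in values)


def all_reduce_max(per_rank: List[float]) -> List[float]:
    """Simulate a MAX all-reduce: every rank ends up with the global max."""
    global_max = max(per_rank)
    return [global_max] * len(per_rank)


def amax_to_scale(a: float, eps: float = 1e-12) -> float:
    """Scale that maps amax onto the largest representable fp8 value."""
    return E4M3_MAX / max(a, eps)


world_size = 4

# Assumption: after the FSDP all-gather, every rank holds the same full weight.
full_weight = [0.5, -3.0, 2.25, 1.0]
weight_amax_per_rank = [amax(full_weight) for _ in range(world_size)]

# Activations differ per rank because each rank sees a different data shard.
activations = [[0.1, -0.2], [4.0, 1.0], [-7.5, 0.3], [2.0, 2.0]]
act_amax_per_rank = [amax(x) for x in activations]

# Syncing the weight amax is a no-op: the reduced values equal the local ones,
# so every rank already derives the same weight scale without any comms.
assert all_reduce_max(weight_amax_per_rank) == weight_amax_per_rank
weight_scales = [amax_to_scale(a) for a in weight_amax_per_rank]
assert len(set(weight_scales)) == 1

# Syncing activation amaxes changes the local values, so it still matters.
synced_act = all_reduce_max(act_amax_per_rank)
assert synced_act != act_amax_per_rank
```

Dropping the weight reduction removes one collective per float8 linear per sync without changing any rank's weight scale.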

I checked performance before and after this change on the multi-GPU benchmark and didn't see a significant impact on the toy model, but doing less cross-rank communication is still better.

Test Plan:

./test_everything.sh passes

Differential Revision: D58396925

vkuzo commented 1 month ago

@vkuzo has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

facebook-github-bot commented 1 month ago

This pull request has been merged in pytorch-labs/float8_experimental@5d293a723d8da1fb3dbaa63522b091cd2e3f3146.