pytorch-labs / float8_experimental

This repository contains the experimental PyTorch native float8 training UX

delayed scaling: stop syncing weight amax values across ranks #272

Closed · vkuzo closed this 3 months ago

vkuzo commented 4 months ago

Stack from ghstack (oldest at bottom):

Summary:

FSDP already ensures that each rank holds the same weight values, so the weight amax is identical on every rank and syncing it across ranks is redundant.
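A minimal sketch of the idea (not the repository's actual implementation; the helper `sync_amax_history` and the `amax_buffers` mapping are hypothetical): with delayed scaling, locally observed amax values are normally all-reduced so every rank derives the same scale, but the weight amax can skip that collective because FSDP already guarantees identical weights on every rank.

```python
import torch
import torch.distributed as dist


def sync_amax_history(amax_buffers: dict) -> None:
    """All-reduce (max) only the amax buffers that can differ across ranks.

    `amax_buffers` is a hypothetical mapping from tensor role
    ("input", "weight", "grad_output") to that role's running amax tensor.
    """
    for role, amax in amax_buffers.items():
        if role == "weight":
            # FSDP guarantees identical weights on every rank, so the weight
            # amax is already the same everywhere; skip the collective.
            continue
        # Activations and gradients see different data shards per rank, so
        # take the max across ranks to keep scales consistent.
        dist.all_reduce(amax, op=dist.ReduceOp.MAX)
```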

I checked performance before/after on the multi-GPU benchmark and didn't see a significant impact on the toy model, but reducing communication volume is still a win.

Test Plan:

./test_everything.sh passes

Reviewers:

Subscribers:

Tasks:

Tags:

vkuzo commented 3 months ago

recreated in https://github.com/pytorch-labs/float8_experimental/pull/277 to get around ghstack weirdness