pytorch-labs / float8_experimental

This repository contains the experimental PyTorch native float8 training UX
BSD 3-Clause "New" or "Revised" License

[FSDP2] precompute scale after optimizer.step for dynamic scaling #266

Closed weifengpy closed 2 months ago

weifengpy commented 4 months ago

Goal: improve float8 all-gather perf in FSDP2 by precomputing scales for all float8 params with a single all-reduce

Updated the README with API usage: call precompute_float8_dynamic_scale_for_fsdp inside the training loop, after the optimizer step.

from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
# inside the training loop
model(input).sum().backward()
optim.step()
precompute_float8_dynamic_scale_for_fsdp(model)

Unit test: pytest -s test/test_fsdp2/test_fsdp2_eager.py -k test_transformer_parity_dynamic

FSDP pre-forward: shortened from 3 ms to 1.8 ms by doing 1 all-reduce instead of N small all-reduces
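
To illustrate why one all-reduce can stand in for N small ones, here is a minimal CPU-only sketch in plain Python. It simulates the MAX all-reduce with lists rather than calling torch.distributed, and it is not the code from this PR. The helper `all_reduce_max`, the sample values, and the `EPS` clamp are assumptions made for illustration; 448 is the largest finite value of float8 e4m3fn.

```python
# Hypothetical simulation: no torch.distributed, no real FSDP2 involved.
E4M3_MAX = 448.0  # largest finite value representable in float8 e4m3fn
EPS = 1e-12       # assumed clamp so an all-zero parameter does not divide by zero


def all_reduce_max(per_rank_vectors):
    """Simulate all-reduce(MAX): elementwise max across ranks."""
    return [max(values) for values in zip(*per_rank_vectors)]


# local_amax[rank][param] = max |w| over that rank's shard of each parameter
local_amax = [
    [0.5, 2.0, 0.0],   # rank 0
    [1.5, 0.25, 0.0],  # rank 1
]
num_params = len(local_amax[0])

# Baseline: one collective per parameter (N small all-reduces).
per_param = [
    all_reduce_max([[rank[i]] for rank in local_amax])[0]
    for i in range(num_params)
]

# Batched: pack every local amax into one vector and reduce once.
batched = all_reduce_max(local_amax)

# Both approaches yield the same global amax per parameter.
assert per_param == batched

# Dynamic scale per parameter, derived from the global amax.
scales = [E4M3_MAX / max(amax, EPS) for amax in batched]
```

Because MAX is applied elementwise, packing the per-parameter amax values into one vector changes only the number of collectives, not the result.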


Precomputing amax: shortened from 5 ms to 1.7 ms by switching from torch._foreach_abs + torch.max(a) to torch._foreach_norm(weights, ord=math.inf)
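
The switch described above relies on the infinity norm of a tensor being equal to its maximum absolute value. The sketch below checks that equivalence on a few small CPU tensors; the shapes and seed are arbitrary assumptions, and `torch._foreach_abs` and `torch._foreach_norm` are private PyTorch APIs whose behavior may vary between versions.

```python
import math

import torch

torch.manual_seed(0)
# Stand-ins for float8 weight shards; shapes are arbitrary.
weights = [torch.randn(4, 8), torch.randn(16), torch.randn(3, 5, 2)]

# Two-step approach: materialize |w| for every tensor, then reduce each one.
abs_weights = torch._foreach_abs(weights)
amax_two_step = torch.stack([torch.max(a) for a in abs_weights])

# Single-call approach: the infinity norm of each tensor is max |w|,
# without an intermediate absolute-value copy per tensor.
amax_fused = torch.stack(list(torch._foreach_norm(weights, ord=math.inf)))

assert torch.allclose(amax_two_step, amax_fused)
```

Either way the result is one amax per parameter. The second form avoids the temporary absolute-value tensors and the per-tensor Python loop.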

vkuzo commented 4 months ago

nice! Can we include the intended user API in the PR summary?

weifengpy commented 2 months ago

> nice! Can we include the intended user API in the PR summary?

added example API usage in PR summary

weifengpy commented 2 months ago

The linter errors come from trunk. I opened another PR to fix them: https://github.com/pytorch-labs/float8_experimental/pull/313

facebook-github-bot commented 2 months ago

@weifengpy has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

facebook-github-bot commented 2 months ago

@weifengpy merged this pull request in pytorch-labs/float8_experimental@6cba2aeade7f2500d7b32c8e38106847201d7feb.