weifengpy closed this pull request 2 months ago.
nice! Can we include the intended user API in the PR summary?
added example API usage in PR summary
the linter errors come from trunk. I opened another PR to fix them: https://github.com/pytorch-labs/float8_experimental/pull/313
@weifengpy has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
@weifengpy merged this pull request in pytorch-labs/float8_experimental@6cba2aeade7f2500d7b32c8e38106847201d7feb.
Goal: improve float8 all-gather perf in FSDP2 by precomputing scales for all float8 params with a single all-reduce
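To make the single-all-reduce idea concrete, here is a minimal single-process sketch of the batching step. The shard shapes and variable names are illustrative assumptions, not taken from the PR, and the distributed collective itself is elided so the sketch stays runnable without a process group.

```python
import torch

# Hypothetical per-parameter local shards (one per float8 weight).
shards = [torch.randn(16), torch.randn(32), torch.randn(8)]

# Per-param amax (max absolute value), computed locally.
local_amax = [t.abs().max() for t in shards]

# Pack all amaxes into one tensor so that a single
# dist.all_reduce(packed, op=dist.ReduceOp.MAX) can replace N
# per-parameter all-reduces; the collective is elided here to keep
# this sketch single-process.
packed = torch.stack(local_amax)

# After the (elided) all-reduce, unpack back to per-param scalars;
# each float8 scale is then derived from its global amax.
global_amax = packed.unbind()
```

The packing/unpacking is cheap relative to launching N separate collectives, which is where the pre-forward latency win comes from.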
updated README for API usage: call precompute_float8_scale_for_fsdp inside the training loop, after the optimizer step.

unit test: pytest -s test/test_fsdp2/test_fsdp2_eager.py -k test_transformer_parity_dynamic
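A minimal sketch of where the call sits in a training loop. The function body below is a no-op stand-in for illustration only; the real precompute_float8_scale_for_fsdp lives in float8_experimental, and the model/optimizer here are toy placeholders.

```python
import torch

def precompute_float8_scale_for_fsdp(model):
    # Stand-in stub: the real version computes amax for all float8
    # params with one torch._foreach_norm, issues one all-reduce,
    # and caches the resulting scales for the next forward.
    pass

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for _ in range(2):  # toy training loop
    loss = model(torch.randn(4, 8)).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # Call after optimizer.step(), so the precomputed scales
    # reflect the freshly updated weights.
    precompute_float8_scale_for_fsdp(model)
```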
FSDP pre-forward: shortened from 3ms to 1.8ms by issuing 1 all-reduce instead of N small all-reduces
Pre-computing amax: shortened from 5ms to 1.7ms by switching from torch._foreach_abs + torch.max(a) to torch._foreach_norm(weights, ord=math.inf)