pytorch-labs / float8_experimental

This repository contains the experimental PyTorch native float8 training UX

[wip] make all 3 gemms in float8 linear configurable #258

Closed vkuzo closed 1 day ago

vkuzo commented 2 months ago

Summary:

This PR ensures each of the 3 gemms in fw+bw of the float8 linear can have its own configuration.

The user interface is clean: the float8 module has a separate config for each of the gemms. Exposing this configuration through module utilities can be done in a future PR.
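As a rough illustration of that interface (a hypothetical sketch, not the actual API in this PR; the class, field, and attribute names below are made up for this example), the module could simply carry one config object per gemm:

```python
# Hypothetical sketch of per-gemm configs on a float8 linear module.
# Class names, config fields, and attribute names are illustrative only.
from dataclasses import dataclass

import torch.nn as nn


@dataclass
class Float8GemmConfig:
    # illustrative knob; the real config fields may differ
    use_fast_accum: bool = False


class MyFloat8Linear(nn.Linear):
    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__(in_features, out_features, bias)
        # one config per gemm in forward + backward:
        #   Y = X @ W_t, gradX = gradY @ W, gradW = X_t @ gradY
        self.config_Y = Float8GemmConfig(use_fast_accum=True)
        self.config_gradX = Float8GemmConfig()
        self.config_gradW = Float8GemmConfig()
```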

The implementation is a bit ugly. Given the 3 gemms in fw+bw

Y = X @ W_t
gradX = gradY @ W
gradW = X_t @ gradY

and the fact that a torch.mm call doesn't have access to any global state other than its arguments, we need to ensure that the matmul arguments carry the right state to map each call to the right config. We adopt a simple set of rules to do this:

  1. if only one of the arguments to mm has a config, use the defined config
  2. if both arguments of mm have a config, use the second argument's config
  3. in the float8 modules, do the following:
     a. set X's config for arg0 to config_Y
     b. set W's config for arg0 and arg1 to None
     c. set gradY's config for arg0 to config_gradX, and for arg1 to config_gradW

If rule 3 is applied correctly, rules 1 and 2 will then select the right config for each gemm (see the sketch below). It's ugly, but it works.
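A minimal sketch of how rules 1 and 2 could be applied at the mm call site, under the assumption that each tensor argument carries per-position config attributes as described in rule 3 (the function and attribute names here are hypothetical, not the PR's actual code):

```python
# Hypothetical config resolution for a single mm(arg0, arg1) call.
def resolve_gemm_config(config_arg0, config_arg1):
    # Rule 2: if both mm arguments carry a config, the second argument's config wins.
    if config_arg0 is not None and config_arg1 is not None:
        return config_arg1
    # Rule 1: if only one argument carries a config, use that one.
    if config_arg0 is not None:
        return config_arg0
    return config_arg1


# With rule 3 in place, this resolves each gemm to the intended config:
#   Y     = X @ W_t:      X carries config_Y, W carries None      -> config_Y     (rule 1)
#   gradX = gradY @ W:    gradY carries config_gradX, W is None   -> config_gradX (rule 1)
#   gradW = X_t @ gradY:  both carry configs, second wins         -> config_gradW (rule 2)
```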

Test Plan:

For now, this works on a single GPU:

pytest -s test/test_base.py
pytest -s test/test_compile.py

Making the distributed tests pass as well is just a question of engineering time.


vkuzo commented 1 day ago

closing in favor of https://github.com/pytorch-labs/float8_experimental/pull/315