This PR ensures each of the 3 gemms in the forward + backward (fw+bw) of the float8 linear can have its own configuration.
The user interface is clean: the float8 module has separate configs for each of the gemms. Configuring this can be exposed in module utilities in a future PR.
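For concreteness, a minimal sketch of what this per-gemm interface could look like; `Float8GemmConfig` and its field are placeholder assumptions, while the `config_Y` / `config_gradX` / `config_gradW` attribute names follow the naming used later in this description:

```python
# hypothetical sketch, not this PR's exact API
from dataclasses import dataclass

import torch

@dataclass(frozen=True)
class Float8GemmConfig:
    # placeholder knob; the real config's fields may differ
    emulate: bool = False

class Float8Linear(torch.nn.Linear):
    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__(in_features, out_features, bias)
        # one config per gemm in fw+bw: Y, gradX, gradW
        self.config_Y = Float8GemmConfig()
        self.config_gradX = Float8GemmConfig()
        self.config_gradW = Float8GemmConfig()
```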
The implementation is a bit ugly. The 3 gemms in fw+bw are:
Y = X @ W_t
gradX = gradY @ W
gradW = X_t @ gradY
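As a reminder, here is a minimal sketch (not this PR's code) of where these three gemms appear in a plain linear layer's autograd:

```python
import torch

class _LinearFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, X, W):
        ctx.save_for_backward(X, W)
        return X @ W.t()       # gemm 1: Y = X @ W_t

    @staticmethod
    def backward(ctx, gradY):
        X, W = ctx.saved_tensors
        gradX = gradY @ W      # gemm 2: gradX = gradY @ W
        gradW = gradY.t() @ X  # gemm 3: X_t @ gradY, transposed to match W's shape
        return gradX, gradW
```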
Since torch.mm doesn't have access to any global state other than its arguments, we need to ensure that the matmul arguments carry the right state to map to the right config. We adopt a simple set of rules to do this (sketched in code after the list):
1. if only one of the arguments to mm has a config, use that config
2. if both arguments of mm have a config, use the second argument's config
3. in the float8 modules, do the following:
3a. set X's config for arg0 to config_Y
3b. set W's config for arg0 and arg1 to None
3c. set gradY's config for arg0 to config_gradX, and for arg1 to config_gradW
If 3 is done correctly, following 1 and 2 will lead to the right config being used for each gemm. It's ugly, but it works.
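To make rules 1-3 concrete, here is a minimal, hypothetical sketch of the dispatch; the `_pick_config` helper and the `_config_arg0` / `_config_arg1` attributes are illustrative names, not the actual implementation:

```python
import torch

def _pick_config(a, b):
    # pick the config for torch.mm(a, b) following rules 1 and 2
    a_cfg = getattr(a, "_config_arg0", None)  # a's config when used as arg0
    b_cfg = getattr(b, "_config_arg1", None)  # b's config when used as arg1
    if a_cfg is not None and b_cfg is not None:
        return b_cfg  # rule 2: both defined -> second argument's config wins
    return a_cfg if a_cfg is not None else b_cfg  # rule 1

# rule 3, as the float8 module would wire it up:
#   X._config_arg0     = config_Y      # 3a
#   W._config_arg0     = None          # 3b
#   W._config_arg1     = None          # 3b
#   gradY._config_arg0 = config_gradX  # 3c
#   gradY._config_arg1 = config_gradW  # 3c
#
# resulting dispatch for the three gemms (assuming the configs propagate
# through transposes, as they would on a float8 tensor subclass):
#   Y     = mm(X, W_t):     only X has a config     -> config_Y
#   gradX = mm(gradY, W):   only gradY has a config -> config_gradX
#   gradW = mm(X_t, gradY): both have configs       -> config_gradW (arg1 wins)

# tiny demo with plain tensors standing in for float8 tensors
X, W_t = torch.randn(4, 8), torch.randn(8, 16)
X._config_arg0 = "config_Y"  # stand-in for a real config object
assert _pick_config(X, W_t) == "config_Y"  # Y gemm: only X has a config
```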
Test Plan:
for now, this works on a single GPU
making the distributed tests also pass is just a question of eng time
Reviewers:
Subscribers:
Tasks:
Tags: