vkuzo opened this issue 1 month ago
Update:
The MoE model sends tensors with varying dims to the individual experts, so the failure is expected in the absence of padding.
We can enable padding with the following update to the float8 config:
```python
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

config = Float8LinearConfig(pad_inner_dim=True)
convert_to_float8_training(m, module_filter_fn=module_filter_fn, config=config)
```
Full example: https://gist.github.com/vkuzo/f5cd488ab635ea3dfe1205aa68eca473
We also need a Triton nightly:

```bash
pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly
```
After installing the Triton nightly and enabling float8 padding, we can compile this example end to end.
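For reference, here is a minimal sketch of that flow (the module and shapes below are illustrative placeholders, not the contents of the gist above):

```python
import torch
import torch.nn as nn
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

# stand-in for a single FFN expert (the real repro uses MixtralBlockSparseTop2MLP);
# requires a GPU with float8 support
m = nn.Sequential(
    nn.Linear(1024, 4096, bias=False),
    nn.SiLU(),
    nn.Linear(4096, 1024, bias=False),
).to(torch.bfloat16).cuda()

def module_filter_fn(module, fqn):
    return isinstance(module, nn.Linear)

config = Float8LinearConfig(pad_inner_dim=True)
convert_to_float8_training(m, module_filter_fn=module_filter_fn, config=config)

m = torch.compile(m)

# MoE routing sends a varying number of tokens to each expert, so the leading
# dim is often not a multiple of 16; padding keeps the fwd+bwd gemms valid
x = torch.randn(17, 1024, dtype=torch.bfloat16, device="cuda")
m(x).sum().backward()
```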
Keeping this issue open to track documenting this in our README.md file.
We should also add performance benchmarks for float8 + compile on MoE, as the padding will add some overhead.
Thank you very much, @vkuzo! It would be awesome to have a benchmarking comparison on this. And just to confirm, is this for the FSDP1 or the FSDP2 version?
I wrote this issue to be independent of the choice of FSDP version, and the localized repro does not use FSDP at all; I really hope it's orthogonal :) If there are any issues when composing this with FSDP1 or FSDP2, happy to help figure those out.
An additional thing: in my original repro for this issue, things worked in eager mode but not under compile, for the forward pass only. For fwd+bwd, both eager and compile were broken. I suspect that torch._scaled_mm's meta function might have shape constraints which are too stringent compared to the actual implementation of the scaled gemm, which would explain eager working but compile not working for the forward. Noting it here so I don't forget, but we should ensure eager matches compile here. TODO: file an issue / fix this in pytorch/pytorch.
oh @vkuzo this errors in eager mode for me too actually, if I just uncomment the lines that run the backward too:
```python
res = final_hidden_states.sum() + router_logits.sum()
res.backward()
```
What's going on here is that:
(1) in eager mode, you won't actually see the shape error until you execute your backward
(2) under compile, we generate a backward graph ahead of time when we see that your forward outputs require gradients
(3) that means that instead of getting a delayed error at the time you run the backward, you'll get the error earlier (at compile time), when we trace out the backward graph

If you really want to just run the forward and not the backward under compile, you can run under no_grad (which will be more efficient anyway, as we don't need to save activations in the generated compiled graph).
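For example, a minimal sketch of the forward-only path (the module and input below are placeholders standing in for the converted float8 expert):

```python
import torch
import torch.nn as nn

# placeholder module and input standing in for the converted float8 expert
m = torch.compile(nn.Linear(1024, 1024, device="cuda", dtype=torch.bfloat16))
x = torch.randn(17, 1024, device="cuda", dtype=torch.bfloat16)

# under no_grad, no backward graph is traced and no activations are saved,
# so only the forward pass needs valid shapes
with torch.no_grad():
    out = m(x)
```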
So it seems like padding might be required here for both eager and compile?
I see, thanks for the explanation @bdhirsh !
🐛 Describe the bug
Specifically, if we try to compile a float8 version of an FFN expert (MixtralBlockSparseTop2MLP), we see shape errors. Script (requires torchao and transformers):
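A rough sketch of the kind of repro described (the exact configs, shapes, and transformers import paths are assumptions here and may differ from the original gist):

```python
import torch
from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralBlockSparseTop2MLP
from torchao.float8 import convert_to_float8_training

# a single FFN expert; sizes are illustrative multiples of 16
cfg = MixtralConfig(hidden_size=1024, intermediate_size=3584)
expert = MixtralBlockSparseTop2MLP(cfg).to(torch.bfloat16).cuda()

def module_filter_fn(module, fqn):
    return isinstance(module, torch.nn.Linear)

# default config: inner dims are not padded
convert_to_float8_training(expert, module_filter_fn=module_filter_fn)
expert = torch.compile(expert)

# MoE routing gives each expert a varying token count, often not a multiple of 16
x = torch.randn(17, 1024, dtype=torch.bfloat16, device="cuda")
out = expert(x)
out.sum().backward()  # shape error without pad_inner_dim=True
```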
Output:
Full output: https://gist.github.com/vkuzo/b5136f21302cd2a259cbb37cda1aa717
Versions
cc @ezyang @chauhang @penguinwu