Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.

Mixtral 8x7B network support #194

Open · riccardofelluga opened 7 months ago

riccardofelluga commented 7 months ago

🚀 Feature

Mixtral 8x7B is a mixture-of-experts LLM that splits its parameters into 8 distinct expert groups, and I would like to do both training and inference with Thunder.

Work items

Additional context

Even though examine does not signal any problem with the ops, some testing revealed that Mixtral uses the single-argument torch.where(condition) overload of torch.where, which is not supported at the moment. The second issue I was able to identify stems from the indexing done in Mixtral's forward function: the _advanced_indexing clang operation currently does not accept None as a valid index together with other tensor indices.
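
For reference, a hypothetical minimal reproducer of both patterns (variable names are illustrative, loosely following Mixtral's forward, not taken verbatim from the model source):

```python
import torch

# 1) Single-argument overload: torch.where(condition) returns index tensors,
#    equivalent to torch.nonzero(condition, as_tuple=True).
mask = torch.tensor([[True, False], [False, True]])
idx, top_x = torch.where(mask)

# 2) Advanced indexing mixing None with a tensor index, which the
#    _advanced_indexing clang operation rejects at the moment.
hidden_states = torch.randn(4, 8)
current_state = hidden_states[None, top_x]  # None together with a tensor index
```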

t-vi commented 7 months ago

Note that unless you rearrange the mixing over what is commonly implemented, you will have data-dependent control flow.
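
For context, here is a sketch of the routing loop as it is commonly implemented (modeled loosely on the Hugging Face Mixtral block; names and shapes are illustrative, not the actual model code). The tokens routed to each expert are selected with torch.where, so the shapes flowing into every expert depend on runtime router outputs:

```python
import torch
import torch.nn.functional as F

def moe_forward(hidden_states, router_logits, experts, top_k=2):
    # hidden_states: (num_tokens, hidden_dim); router_logits: (num_tokens, num_experts)
    num_experts = len(experts)
    routing_weights = F.softmax(router_logits, dim=-1)
    routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
    # expert_mask: (num_experts, top_k, num_tokens)
    expert_mask = F.one_hot(selected_experts, num_experts).permute(2, 1, 0)
    out = torch.zeros_like(hidden_states)
    for expert_idx in range(num_experts):
        # idx and top_x depend on runtime data: which tokens the router picked.
        idx, top_x = torch.where(expert_mask[expert_idx])
        tokens = hidden_states[None, top_x].reshape(-1, hidden_states.shape[-1])
        out.index_add_(0, top_x, experts[expert_idx](tokens) * routing_weights[top_x, idx, None])
    return out
```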

riccardofelluga commented 7 months ago

you will have data-dependent control flow.

Exactly! What is our current stance on data-dependent control flows?

t-vi commented 7 months ago

I don't think it's on the roadmap any time soon.

riccardofelluga commented 6 months ago

Adding #303, which might be the key to getting the model supported in Thunder.

cc. @IvanYashchuk

riccardofelluga commented 1 week ago

Update on this issue: Mixtral 8x7B is now supported via the ThunderFX path. The issues listed above remain for the JIT code path.
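
For anyone landing here, a minimal sketch of the ThunderFX path, assuming the ThunderCompiler backend exposed in thunder.dynamo (the exact API may differ between versions):

```python
import torch
from thunder.dynamo import ThunderCompiler

# Placeholder module standing in for a Mixtral 8x7B instance (assumption:
# in practice you would load the real model, e.g. via transformers).
model = torch.nn.Linear(8, 8)

backend = ThunderCompiler()  # Thunder as a torch.compile backend
compiled_model = torch.compile(model, backend=backend)
output = compiled_model(torch.randn(2, 8))
```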