Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0
1.07k stars, 60 forks

Transforms as EarlyTransforms #640

Open t-vi opened 4 days ago

t-vi commented 4 days ago

Currently we have grad and noop as AdditionalTransforms. However, we might want to move them to early transforms so that they can work with epilogue-carrying functions.

Also, we have vjp, vmap and others that currently do not work well with jitted models and for which we should also do something (i.e. figure out what it means to change the input signature). What do we want to do when a module is involved? One option could be to just work on the arguments and ignore the model in the background (if we want to support that at all).

My idea would be to adapt these to work on the traces directly rather than the current indirection.
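To make "working on the traces directly" and "changing the input signature" a bit more concrete, here is a toy sketch in plain Python. It is not Thunder's trace API: `ToyTrace`, `run` and `vjp_transform` are hypothetical names made up for illustration, and the trace format (single-assignment ops, only `add` and `mul`, one output) is an assumption. The point is only that a vjp-style transform can take a trace and return a new trace whose signature has an extra cotangent argument and extra gradient outputs.

```python
from dataclasses import dataclass


@dataclass
class ToyTrace:
    # Hypothetical stand-in for a trace; NOT Thunder's real data structure.
    args: list      # names of the input variables
    ops: list       # (output name, op name, (input name, input name))
    outputs: list   # names of the returned variables


def run(trace, *values):
    """Interpret a ToyTrace on concrete numbers."""
    env = dict(zip(trace.args, values))
    for out, op, (a, b) in trace.ops:
        env[out] = env[a] + env[b] if op == "add" else env[a] * env[b]
    return tuple(env[name] for name in trace.outputs)


def vjp_transform(trace):
    """Trace-to-trace transform: (x, y, ...) -> out becomes
    (x, y, ..., ct) -> (out, grad_x, grad_y, ...).

    Assumes single-assignment variable names, one output, and that the
    names "ct" and "g<N>" are not already used by the input trace.
    """
    if len(trace.outputs) != 1:
        raise ValueError("toy transform handles a single output only")

    ops = list(trace.ops)                 # forward ops are kept unchanged
    grads = {trace.outputs[0]: "ct"}      # variable name -> gradient name
    counter = 0

    def emit(op, ins):
        # Append a new op to the backward part and return its output name.
        nonlocal counter
        counter += 1
        name = f"g{counter}"
        ops.append((name, op, ins))
        return name

    def accumulate(var, contribution):
        # Sum gradient contributions when a variable is used more than once.
        if var in grads:
            grads[var] = emit("add", (grads[var], contribution))
        else:
            grads[var] = contribution

    for out, op, (a, b) in reversed(trace.ops):
        g = grads.get(out)
        if g is None:                     # op does not influence the output
            continue
        if op == "add":
            accumulate(a, g)
            accumulate(b, g)
        elif op == "mul":
            accumulate(a, emit("mul", (g, b)))
            accumulate(b, emit("mul", (g, a)))
        else:
            raise ValueError(f"unsupported op: {op}")

    missing = [a for a in trace.args if a not in grads]
    if missing:
        raise ValueError(f"output does not depend on: {missing}")

    return ToyTrace(
        args=trace.args + ["ct"],
        ops=ops,
        outputs=trace.outputs + [grads[a] for a in trace.args],
    )


# f(x, y) = x * y + x
f_trace = ToyTrace(
    args=["x", "y"],
    ops=[("t", "mul", ("x", "y")), ("out", "add", ("t", "x"))],
    outputs=["out"],
)
vjp_trace = vjp_transform(f_trace)

print(vjp_trace.args)                     # ['x', 'y', 'ct']
print(run(f_trace, 3.0, 4.0))             # (15.0,)
print(run(vjp_trace, 3.0, 4.0, 1.0))      # (15.0, 5.0, 3.0)
```

In this sketch the transform never wraps or calls the original Python function; it only rewrites the op list and the argument/output names. Whatever a real design does with a module's parameters (treat them as extra trace arguments, or ignore them as suggested above) would show up as a choice about which names go into `args` and `outputs`.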

I'm filing this issue to discuss this, and then we can see how to move things forward.

mruberry commented 4 days ago

triage review: the design tag marks issues where we want people to review and propose a design (possibly with a linked PR to see a sample implementation), so look for those and let's try to call them out