Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

enable using python_callable without mapping symbols to their impls #1168

Open. t-vi opened this issue 1 week ago

t-vi commented 1 week ago

Currently, we can't use python_callable if we want to retrace its result, because it already performs half of the transform for execution (it maps symbols to their impls). We should definitely add a flag to python_callable for this, because we now need a workaround in Thunder quite often (in thunder/core/transforms.py, and now also in the autograd.Function handling in jit_ext). (The ["output"] indexing means this example is not entirely trivial, but still.)

https://github.com/Lightning-AI/lightning-thunder/blob/ea96657689ab3c60a41897fcd0be5d00e685a449/thunder/core/transforms.py#L1438-L1445

cc @apaz-cli

t-vi commented 6 days ago

This is another, more prototypical, example: the commented-out line does not work, but it should with a flag (or perhaps even by default, so that only the transform for execution behaves differently).

https://github.com/Lightning-AI/lightning-thunder/blob/9580411e1dcec34ce0226a5a688ed65c32422b8b/thunder/core/jit_ext.py#L690-L694