Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0
enable using python_callable without mapping symbols to their impls #1168
Currently, we can't use python_callable if we want to retrace, because it does half of the transform for execution: it maps symbols to their implementations.
We should definitely add a flag to python_callable for this, because we now need a workaround in Thunder quite often (in thunder/core/transforms.py, and also in the autograd.Function handling in jit_ext).
(The ["output"] may make this example less than trivial, but still.)
Here is another, more prototypical example: the commented-out line does not work, but it should with a flag (or perhaps even by default, so that only the transform for execution does it differently).
https://github.com/Lightning-AI/lightning-thunder/blob/ea96657689ab3c60a41897fcd0be5d00e685a449/thunder/core/transforms.py#L1438-L1445
cc @apaz-cli