We currently "fudge" autograd.Function by running through the forward as if it was the function and rely on the differentiation of that to work.
(This is not good when there is .detach() or some such in the forward.)
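As a concrete illustration (plain PyTorch, not Thunder internals; `ClampSTE` is a hypothetical example, not from the codebase): an `autograd.Function` whose forward contains `.detach()`. Differentiating the forward body would produce a zero gradient, while the author-defined backward passes the gradient straight through, so the "fudge" gets the gradient wrong.

```python
import torch


class ClampSTE(torch.autograd.Function):
    """Clamp to [-1, 1] with a straight-through gradient."""

    @staticmethod
    def forward(ctx, x):
        # .detach() in the forward: differentiating the forward body
        # would yield a zero gradient for x.
        return x.clamp(-1.0, 1.0).detach()

    @staticmethod
    def backward(ctx, grad_out):
        # The author-defined backward passes the gradient straight through.
        return grad_out


x = torch.randn(4, requires_grad=True)
ClampSTE.apply(x).sum().backward()
print(x.grad)  # all ones, from the custom backward; differentiating the
               # forward body instead would give zeros because of .detach()
```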
Given that

- we do know the call args (in the "meta" sense) of the backward once we have traced through the forward, but
- we do not yet know its return signature (because for each tensor input the backward may return `None` or a Tensor of the same shape as the matching input),
I wonder if we could, after tracing the forward:

- make a non-prim symbol with the forward trace as subsymbols,
- inspect the `ctx`,
- trace through the backward at that time and put it in a non-primitive symbol as well, and
- make a gradient rule mapping the two.
This will be a considerable amount of magic, unfortunately, but it might just work.
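A rough sketch of that flow in plain PyTorch (not actual Thunder symbols or traces): run the `forward` once with our own ctx-like object, inspect what it stashed, then run the `backward` with grads shaped like the forward outputs to learn the return signature. `RecordingCtx`, `trace_custom_function`, and `ScaleByConstant` are illustrative names I made up for this sketch, not Thunder or PyTorch APIs.

```python
import torch


class ScaleByConstant(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale):
        ctx.save_for_backward(scale)
        return x * scale

    @staticmethod
    def backward(ctx, grad_out):
        (scale,) = ctx.saved_tensors
        # One grad per forward input: a tensor for x, None for scale.
        return grad_out * scale, None


class RecordingCtx:
    """Minimal stand-in for the ctx object, just enough to record saved state."""

    def __init__(self):
        self._saved = ()

    def save_for_backward(self, *tensors):
        self._saved = tensors

    @property
    def saved_tensors(self):
        return self._saved


def trace_custom_function(fn_cls, *args):
    ctx = RecordingCtx()
    # 1) "Trace" the forward with our ctx (here simply executed eagerly);
    #    this is what would become a non-prim symbol wrapping the forward trace.
    out = fn_cls.forward(ctx, *args)
    outs = out if isinstance(out, tuple) else (out,)

    # 2) Inspect the ctx: the saved tensors are the state the backward needs.
    print("saved on ctx:", [tuple(t.shape) for t in ctx.saved_tensors])

    # 3) "Trace" the backward: its call args (one grad per forward output) are
    #    known from the forward outputs' metadata, but the return signature
    #    (None vs. tensor per input) only becomes known here.
    grad_outs = tuple(torch.ones_like(o) for o in outs)
    grads = fn_cls.backward(ctx, *grad_outs)
    grads = grads if isinstance(grads, tuple) else (grads,)
    print("grad per input:", [None if g is None else tuple(g.shape) for g in grads])
    return outs, grads


trace_custom_function(ScaleByConstant, torch.randn(3), torch.tensor(2.0))
```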
We currently "fudge" autograd.Function by running through the forward as if it was the function and rely on the differentiation of that to work. (This is not good when there is
.detach()
or some such in the forward.)Given that
None
or a Tensor of the same shape as the matching input for the tensor inputs).I wonder if we could, after tracing the forward
ctx
,This will be a considerable amount of magic, unfortunately, but it might just work.