IvanYashchuk opened this issue 1 week ago
Given that all the functions except `torch.utils.checkpoint.checkpoint` are marked as

> .. warning::
>     `torch.associative_scan` is a prototype feature in PyTorch. It currently
>     does not support autograd and you may run into miscompiles.
>     Read more about feature classification at:
>     https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype

we can probably move higher-order functions to a longer-term discussion as support for them matures in PyTorch.
`torch.utils.checkpoint.checkpoint` looks to me more like a case of "wrapping" than a general higher-order function.
Specifically, I'd like to understand the implications of having higher-order functions in traces from the point of view of transform authors, and how to ensure everything keeps working with everything else.
🚀 Feature
Add support for PyTorch Callable -> Thunder Callable translation in Thunder JIT.
Motivation
Several PyTorch operators accept a Python function with PyTorch operations inside as one of their arguments. In PyTorch, these operators are called "higher-order operators". Examples of these operators discussed in this issue:

- `torch.cond`
- `torch.associative_scan`
- `torch.utils.checkpoint.checkpoint`
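For instance, `torch.utils.checkpoint.checkpoint` takes a plain Python callable built from PyTorch operations. A minimal eager-mode illustration (the function and tensor names here are made up for the example):

```python
import torch
from torch.utils.checkpoint import checkpoint

def block(x):
    # A plain Python callable with PyTorch operations inside: exactly the
    # kind of argument a "higher-order operator" accepts.
    return torch.relu(x @ x)

x = torch.randn(8, 8, requires_grad=True)
# `checkpoint` re-runs `block` during backward instead of saving activations.
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()
```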
Pitch
Thunder should support all of the above operators. It's easy to support only Thunder functions as inputs (see the example for `checkpoint` in https://github.com/Lightning-AI/lightning-thunder/pull/1127), but the best user experience would be enabled by automatically translating user-provided PyTorch callables into Thunder ones while constructing the initial Thunder trace.
An example of `torch.cond` to support:
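A sketch of the kind of program Thunder would need to handle, assuming `torch.cond`'s `(pred, true_fn, false_fn, operands)` calling convention; the `thunder.jit` call is aspirational, since this translation is exactly the feature being requested:

```python
import torch
import thunder

def true_fn(x):
    return torch.sin(x)

def false_fn(x):
    return torch.cos(x)

def f(pred, x):
    # `true_fn` and `false_fn` are plain PyTorch callables that Thunder
    # would have to translate while constructing the initial trace.
    return torch.cond(pred, true_fn, false_fn, (x,))

x = torch.randn(4)
jf = thunder.jit(f)  # aspirational: the requested feature, not current behavior
out = jf(x.sum() > 0, x)
```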
Alternatives
@t-vi, please fill in this section with details about alternative solutions.
Implement checkpointing via a lookaside that attaches a `rematerialize_for_backward` (or similar) proxy tag to the proxies that are wrapped (or created and wrapped); a toy sketch of this idea follows below.

The other higher-order functions are currently prototypes; barring other pressing needs, I think this should inform our prioritization. It would be a formidable change to the nature of traces to have higher-order functions in them. Before looking at this use case, it would be good to figure out "call a jitted function / module from a jitted function" first; I guess this would be very useful for jitting training loops with optimizer steps.
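A self-contained toy sketch of the proxy-tagging alternative. `Proxy`, `LOOKASIDES`, `register_lookaside`, and the tag string are all illustrative names for this sketch, not Thunder's actual machinery:

```python
from dataclasses import dataclass, field

# Illustrative stand-ins; none of these names are Thunder's actual API.
@dataclass
class Proxy:
    name: str
    tags: set = field(default_factory=set)

LOOKASIDES = {}

def register_lookaside(op_name, handler):
    LOOKASIDES[op_name] = handler

def checkpoint_lookaside(fn, *args):
    # Trace the user callable inline, as if it had been called directly,
    # then tag every produced proxy so a (hypothetical) backward transform
    # knows to recompute it instead of saving it.
    outs = fn(*args)
    outs = outs if isinstance(outs, tuple) else (outs,)
    for p in outs:
        p.tags.add("rematerialize_for_backward")
    return outs[0] if len(outs) == 1 else outs

register_lookaside("torch.utils.checkpoint.checkpoint", checkpoint_lookaside)

# Usage: the lookaside runs the callable on proxies and tags its outputs.
out = LOOKASIDES["torch.utils.checkpoint.checkpoint"](
    lambda a: Proxy(f"relu({a.name})"), Proxy("x")
)
print(out.tags)  # {'rematerialize_for_backward'}
```

The appeal of this shape is that rematerialization becomes a property of proxies rather than a new kind of trace node, so transforms that walk the trace would not need to learn about nested callables.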
Additional context
An attempt at using jit inside lookasides currently fails: https://github.com/Lightning-AI/lightning-thunder/issues/1126.