Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source to source compiler for PyTorch. It enables using different hardware executors at once; across one or thousands of GPUs.
Apache License 2.0

Reentrant JIT for higher order operators #1134

Open IvanYashchuk opened 1 week ago

IvanYashchuk commented 1 week ago

🚀 Feature

Add support for PyTorch Callable -> Thunder Callable translation in Thunder JIT.

Motivation

Several PyTorch operators accept a Python function with PyTorch operations inside as one of their arguments. In PyTorch, these operators are called "higher order operators". Examples of these operators:

Pitch

Thunder should support all of the above operators. It's easy to support only Thunder functions as inputs (example for checkpoint https://github.com/Lightning-AI/lightning-thunder/pull/1127), but the best user experience would be enabled by the automatic translation of user-provided PyTorch callables into Thunder ones while constructing the initial Thunder trace.

An example of torch.cond to support:

import torch

def true_fn(x: torch.Tensor):
    return torch.cos(x)

def false_fn(x: torch.Tensor):
    return torch.sin(x)

# Ideally, adding the thunder.jit decorator should just work. This requires
# translating true_fn and false_fn into Thunder functions so that their bodies
# can be traced and understood by the rest of the system.
# @thunder.jit
def f(true_fn, false_fn, x):
    return torch.cond(x.shape[0] > 4, true_fn, false_fn, (x,))

x = torch.ones(5)
print(f(true_fn, false_fn, x))
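
To make the requested translation step concrete, here is a minimal toy sketch in plain Python. It is not Thunder's implementation and does not use Thunder or PyTorch: `ToyTrace`, `Proxy`, `translate`, and `cond_lookaside` are hypothetical names invented for illustration. The idea it shows is that a handler for a cond-like operator re-enters the tracer on each callable argument and stores the results as nested sub-traces.

```python
import math
from dataclasses import dataclass, field
from typing import Callable, List


@dataclass
class ToyTrace:
    """Hypothetical trace: recorded op names plus nested traces for callables."""
    name: str
    ops: List[str] = field(default_factory=list)
    subtraces: List["ToyTrace"] = field(default_factory=list)


class Proxy:
    """Stands in for a tensor while tracing; ops are recorded, not computed."""
    def __init__(self, trace: ToyTrace) -> None:
        self.trace = trace


def cos(x):
    if isinstance(x, Proxy):
        x.trace.ops.append("cos")
        return x
    return math.cos(x)


def sin(x):
    if isinstance(x, Proxy):
        x.trace.ops.append("sin")
        return x
    return math.sin(x)


def translate(fn: Callable, name: str) -> ToyTrace:
    """Re-enter the tracer on a single-argument user callable."""
    sub = ToyTrace(name)
    fn(Proxy(sub))
    return sub


def cond_lookaside(outer: ToyTrace, pred: bool, true_fn, false_fn, operands):
    """Hypothetical handler for a cond-like higher order operator."""
    outer.ops.append("cond")
    # Both branches are translated, regardless of which one runs.
    outer.subtraces.append(translate(true_fn, "true_fn"))
    outer.subtraces.append(translate(false_fn, "false_fn"))
    # For this toy, the selected branch is then simply run eagerly.
    branch = true_fn if pred else false_fn
    return branch(*operands)


def true_fn(x):
    return cos(x)


def false_fn(x):
    return sin(x)


outer = ToyTrace("f")
result = cond_lookaside(outer, True, true_fn, false_fn, (0.0,))
print(outer.ops, [(t.name, t.ops) for t in outer.subtraces], result)
```

The outer trace holds a single `cond` entry whose two callables have each been traced into their own sub-trace. That nesting is the part that would be new to the structure of traces.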

Alternatives

@t-vi, please fill in this section with details about alternative solutions.

Implement checkpointing via a lookaside that

The other higher order functions are currently prototypes; barring other pressing needs, I think this should inform our prioritization. It would be a formidable change to the nature of traces to have higher order functions in them. Before looking at this use case, it would be good to figure out "call jitted function / module from jitted function" first. I guess this would be very useful for jitting training loops with optimizer steps.

Additional context

An attempt at using jit inside lookasides currently fails: https://github.com/Lightning-AI/lightning-thunder/issues/1126.

lantiga commented 1 week ago

Given that all the functions except torch.utils.checkpoint.checkpoint are marked as

    .. warning::
        `torch.associative_scan` is a prototype feature in PyTorch. It currently
        does not support autograd and you may run into miscompiles.
        Read more about feature classification at:
        https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype

we can probably move higher-order functions to a longer-term discussion as support for them matures in PyTorch.

torch.utils.checkpoint.checkpoint looks to me more like a case of "wrapping" rather than one of a general higher-order function.
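
One possible reading of the "wrapping" distinction, as a toy sketch in plain Python: in the forward pass, a checkpoint-style wrapper always calls its single callable and returns its result, so a tracer could trace straight through it and keep a flat trace. `toy_checkpoint` and `trace_through_wrapper` are hypothetical names, not Thunder or PyTorch APIs, and the begin/end markers are an assumption about how such a region might be tagged.

```python
from typing import Callable, List


def toy_checkpoint(fn: Callable, *args):
    """Forward semantics of a checkpoint-style wrapper: just call fn."""
    return fn(*args)


def trace_through_wrapper(ops: List[str], fn: Callable, *args):
    """Hypothetical handler: mark the region, then trace fn inline."""
    ops.append("checkpoint_begin")
    out = fn(*args)
    ops.append("checkpoint_end")
    return out


ops: List[str] = []


def double(x):
    # Recording into `ops` stands in for an op being added to a trace.
    ops.append("mul")
    return 2 * x


traced_result = trace_through_wrapper(ops, double, 3)
print(ops, traced_result)
```

No nested trace is needed here, unlike a cond-like operator, where several callables and a runtime predicate have to be represented.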

lantiga commented 1 week ago

Specifically, I'd like to understand the implications of having higher order functions in traces from the point of view of transform authors, and to make sure everything keeps working with everything else.