Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

Implement gelu (and other elementwise fusions) recomputation during backward #1012

Open IvanYashchuk opened 2 months ago

IvanYashchuk commented 2 months ago

🚀 Feature

Motivation

import torch
import thunder

def f(x):
    x = x @ x
    x = torch.nn.functional.gelu(x, approximate="none")
    x = x @ x
    return x

x = torch.randn(1000, 1000, device="cuda", requires_grad=True)
jf = thunder.jit(f)
y = jf(x)

print([t.shape for t in y.grad_fn.saved_tensors])
# This will print:
# [torch.Size([1000, 1000]), torch.Size([1000, 1000]), torch.Size([1000, 1000])]

print(thunder.core.vjp_utils.get_saved_for_backward_tensors(thunder.last_traces(jf)[-1]))
# This will print:
# (<TensorProxy(name="t0", dtype=thunder.dtypes.float32, shape=(1000, 1000))>,
#  <TensorProxy(name="t5", dtype=thunder.dtypes.float32, shape=(1000, 1000))>,
#  <TensorProxy(name="x", dtype=thunder.dtypes.float32, shape=(1000, 1000))>)

In the above snippet, t5 is the output of the gelu function and is saved for backward alongside its input t0. The request is to implement a pass that forces recomputation of gelu in the backward pass from t0, instead of saving this intermediate tensor.
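To make the requested behavior concrete, here is a minimal sketch in plain PyTorch, not Thunder's implementation (Thunder would do this as a pass over its traces). It assumes the same computation as the snippet above. The class name GeluThenMatmulRecompute and the functions f_reference and f_recompute are hypothetical and exist only for illustration. The custom autograd function saves only t0 (the gelu input) and recomputes t5 = gelu(t0) in backward, so the gelu output is never kept between forward and backward.

```python
import torch
import torch.nn.functional as F


class GeluThenMatmulRecompute(torch.autograd.Function):
    """Hypothetical illustration: computes gelu(t0) @ gelu(t0) and
    saves only t0, recomputing the gelu output in backward."""

    @staticmethod
    def forward(ctx, t0):
        ctx.save_for_backward(t0)  # t5 is intentionally NOT saved
        t5 = F.gelu(t0, approximate="none")
        return t5 @ t5

    @staticmethod
    def backward(ctx, grad_out):
        (t0,) = ctx.saved_tensors
        # Recompute gelu with autograd enabled so we can differentiate it.
        with torch.enable_grad():
            t0_ = t0.detach().requires_grad_(True)
            t5 = F.gelu(t0_, approximate="none")
        t5d = t5.detach()
        # Gradient of (t5 @ t5) w.r.t. t5: both operands are the same tensor,
        # so the two matmul gradient terms are summed.
        grad_t5 = grad_out @ t5d.T + t5d.T @ grad_out
        # Backpropagate through the recomputed gelu.
        (grad_t0,) = torch.autograd.grad(t5, t0_, grad_t5)
        return grad_t0


def f_reference(x):
    # Same computation as f in the snippet above, using regular autograd.
    x = x @ x
    x = F.gelu(x, approximate="none")
    return x @ x


def f_recompute(x):
    # First matmul uses regular autograd; gelu + second matmul recompute gelu.
    t0 = x @ x
    return GeluThenMatmulRecompute.apply(t0)
```

Calling f_recompute(x).sum().backward() should give the same x.grad as f_reference, at the cost of one extra gelu evaluation in backward. The trade-off is the usual one for rematerialization: an elementwise op is cheap to recompute relative to the memory needed to hold a full-size activation.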

Ongoing PR: https://github.com/Lightning-AI/lightning-thunder/pull/1003.

Implementing gelu recomputation would resolve the OOM error seen in https://github.com/Lightning-AI/lightning-thunder/issues/246.

riccardofelluga commented 2 months ago

Sounds reasonable! More context on this here

riccardofelluga commented 2 months ago

Another potential fix for this might be to use Liger kernels. I think that's worth trying, but it requires creating a quick-and-dirty executor. Let's proceed with the remat for now, and I'll come back to try this out later.