Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

[TransformerEngine] Support `backward(retain_graph=True)` #701

Open kshitij12345 opened 6 days ago

kshitij12345 commented 6 days ago
from thunder.benchmarks.targets import LitGPTConfig, LitGPTBenchmark, backward_only
import torch
import thunder
from thunder.executors.transformer_engineex import transformer_engine_ex

# With bigger model
# cfg: LitGPTConfig = LitGPTConfig.from_name("Llama-2-7b-hf")
# cfg.n_layer = 3
# b = LitGPTBenchmark(cfg, batchdims=(2,), device="cuda:0", dtype=torch.bfloat16, requires_grad=True)
# args, kwargs = b.make_batch()
# fn = thunder.jit(b.fn(), executors=[transformer_engine_ex,])

# With smaller model
def foo(x, w):
    return torch.nn.functional.linear(x, w)
x = torch.randn(16, 16, requires_grad=True, device='cuda')
w = torch.randn(16, 16, requires_grad=True, device='cuda')
args = (x, w)
kwargs = {}
fn = thunder.jit(foo, executors=[transformer_engine_ex,])

# backward_only creates a graph and calls `torch.autograd.backward` with `retain_graph=True`.
backward_fn, backward_setup = backward_only(fn, *args, **kwargs)
backward_args = backward_setup()
backward_fn(*backward_args)

# Second usage of `backward_fn` fails with 
# File "/home/kkalambarkar/lightning-thunder/thunder/executors/transformer_engineex.py", line 352, in _te_functional_linear_backward_impl
#     with enable_grad(ctx.saved_tensors[2]):
# IndexError: tuple index out of range
backward_fn(*backward_args)

cc: @IvanYashchuk

kshitij12345 commented 6 days ago

One reason for the issue could be that we have our own version of Context, which may behave differently when backward is called with retain_graph=True.

t-vi commented 6 days ago

I think this may be the famous clearing of collections...

IvanYashchuk commented 6 days ago

PyTorch requires all tensor data to be saved with ctx.save_for_backward(*tensors); when that is not done, objects can be deleted prematurely.
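
For reference, here is a minimal sketch of the pattern described above, using a plain torch.autograd.Function on CPU. The Square function is a hypothetical example and not part of Thunder or TransformerEngine. It only shows that tensors stored with ctx.save_for_backward stay available for a second backward call when the first one uses retain_graph=True.

```python
import torch


class Square(torch.autograd.Function):
    """Hypothetical example op: y = x * x."""

    @staticmethod
    def forward(ctx, x):
        # Register the tensor with autograd instead of stashing it elsewhere.
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out


x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = Square.apply(x).sum()

# First backward keeps the graph (and the saved tensors) alive.
y.backward(retain_graph=True)
first = x.grad.clone()

# Reset the accumulated gradient and run backward a second time.
x.grad = None
y.backward()
second = x.grad.clone()
```

Both calls should produce the same gradient, because autograd owns the saved tensor and only frees it once backward runs without retain_graph=True.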

kshitij12345 commented 6 days ago

Even if we fix this, TransformerEngine itself has a problem with running backward multiple times with retain_graph=True.

import torch
from transformer_engine.pytorch import Linear as TELinear, fp8_autocast

m = TELinear(16, 16)
x = torch.randn(16, 16, device='cuda')

with fp8_autocast(True):
    o = m(x).sum()

o.backward(retain_graph=True)

# this fails with
# AssertionError: FP8 execution requires 2D input matrices with height divisible by 8 and width divisible by 16, but got tensor with dims=[0]
# looks like TELinear.backward mutates the context object such that it is not reusable.
o.backward()

IvanYashchuk commented 6 days ago

Let's do the work anyway so that we hit the same error as in pure TE usage.

IvanYashchuk commented 5 days ago

Currently, we pass the data saved for backward from the forward pass through a mock context object backed by a dictionary: https://github.com/Lightning-AI/lightning-thunder/blob/ab514fcb7836683c7b57f9bd7fe338a77d516bd2/thunder/executors/transformer_engineex.py#L125-L140 A dictionary was used instead of a Context object directly so that clear_mutable_collections can clear the dict and subsequently release the memory held by the tensors. It was also considered too much work to faithfully model the properties of the saved tensors (their number, types, and shapes). However, to ensure correct behavior from PyTorch, we need to pass tensor objects through PyTorch's save_for_backward(*tensors) function. That means we need to modify TransformerEngine's meta to return TensorProxies explicitly rather than hiding them behind a generic 'CollectionProxy'.
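
To illustrate the failure mode, here is a rough pure-Python sketch, assuming a dict-backed context whose storage is cleared after the first backward. MockContext, backward, and the string placeholders are hypothetical names for illustration only; this is not Thunder's actual implementation.

```python
class MockContext:
    """Hypothetical dict-backed stand-in for an autograd context."""

    def __init__(self, storage):
        self._storage = storage

    @property
    def saved_tensors(self):
        # Returns an empty tuple once the backing dict has been cleared.
        return tuple(self._storage.get("saved_tensors", ()))


def backward(storage):
    ctx = MockContext(storage)
    try:
        # Mirrors the `ctx.saved_tensors[2]` access from the traceback.
        return ctx.saved_tensors[2]
    finally:
        # Assumed behavior: the collection is cleared to release memory.
        storage.clear()


# Strings stand in for the tensors saved during the forward pass.
storage = {"saved_tensors": ["inp", "weight", "fp8_meta"]}

first = backward(storage)

try:
    backward(storage)
    second_error = None
except IndexError as exc:
    second_error = str(exc)
```

Under these assumptions the first call succeeds and the second raises the same kind of IndexError as in the report, since nothing is left in the dictionary to index.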