kshitij12345 opened 6 days ago
One of the reasons for the issue could be that we have our own version of Context, which probably behaves differently when backward is called with retain_graph.
I think this may be the famous clearing of collections...
PyTorch requires all tensor data needed in backward to be saved with ctx.save_for_backward(*tensors); when that's not done, the objects can be deleted preemptively.
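For reference, the convention being described is the standard torch.autograd.Function pattern, where tensors needed in backward go through ctx.save_for_backward rather than being stashed as plain attributes on ctx:

```python
import torch

class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Tracked by autograd's bookkeeping; safe across repeated
        # backward calls with retain_graph=True.
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out

x = torch.randn(4, requires_grad=True)
Square.apply(x).sum().backward()
```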
Even if we fix this, TransformerEngine itself has a problem with running backward multiple times with retain_graph=True:
```python
import torch
from transformer_engine.pytorch import Linear as TELinear, fp8_autocast

m = TELinear(16, 16)
x = torch.randn(16, 16, device='cuda')

with fp8_autocast(True):
    o = m(x).sum()

o.backward(retain_graph=True)

# the second call fails with
# AssertionError: FP8 execution requires 2D input matrices with height divisible by 8 and width divisible by 16, but got tensor with dims=[0]
# looks like TELinear.backward mutates the context object such that it is not reusable.
o.backward()
```
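For illustration, here is a minimal pure-PyTorch sketch (not TE's actual code) of that failure mode: a backward that deletes its saved state succeeds once but breaks on the second call.

```python
import torch

class SquareOnce(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.stash = x  # plain attribute, outside save_for_backward
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        x = ctx.stash
        del ctx.stash  # mutates the context: it is no longer reusable
        return 2 * x * grad_out

y = SquareOnce.apply(torch.randn(4, requires_grad=True)).sum()
y.backward(retain_graph=True)  # first call succeeds
y.backward()  # fails: 'stash' was deleted by the first backward
```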
Let's do the work anyway so that we hit the same error as in pure TE usage.
Currently, we send the data saved for backward from the forward pass through a mock context object backed by a dictionary:
https://github.com/Lightning-AI/lightning-thunder/blob/ab514fcb7836683c7b57f9bd7fe338a77d516bd2/thunder/executors/transformer_engineex.py#L125-L140

A dictionary was used instead of a Context object directly so that clear_mutable_collections can clear the dict and subsequently release the memory held by the tensors. It was also considered too much work to faithfully model the properties of the saved tensors (their number, types, and shapes). However, to get correct behavior from PyTorch we need to pass the tensor objects through PyTorch's save_for_backward(*tensors) function. That means modifying TransformerEngine's meta functions to return TensorProxies explicitly rather than hiding them behind a generic CollectionProxy.
cc: @IvanYashchuk