NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.
Apache License 2.0

Calling backward(retain_graph=True) multiple times with TE Layer does not work #990

Open kshitij12345 opened 4 days ago

kshitij12345 commented 4 days ago
import torch
from transformer_engine.pytorch import Linear as TELinear, fp8_autocast

# m = torch.nn.Linear(16, 16).to("cuda")  # This works
m = TELinear(16, 16)
x = torch.randn(16, 16, device='cuda')

with fp8_autocast(True):
    o = m(x).sum()

o.backward(retain_graph=True)  # first call succeeds
o.backward(retain_graph=True)  # second call fails with:
# AssertionError: FP8 execution requires 2D input matrices with height divisible by 8 and width divisible by 16, but got tensor with dims=[0]
# Looks like TELinear.backward mutates the context object such that it is not reusable.

This would be useful to support benchmarking just the backward pass.

timmoon10 commented 3 days ago

In the module-based API, we deliberately clear the intermediate tensors during the backward pass to minimize memory usage (see https://github.com/NVIDIA/TransformerEngine/pull/509). For example, see _Linear.backward: https://github.com/NVIDIA/TransformerEngine/blob/56e0b351d0b7db6e622d3aa3eac6c6a1bf1ce4ab/transformer_engine/pytorch/module/linear.py#L587-L588. Because those tensors are gone after the first backward call, a second call on the same graph has nothing valid to work with. We should mark these autograd backward functions with once_differentiable for clarity.
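To illustrate the failure mode in isolation, here is a minimal sketch in plain PyTorch, not TE code. `SquareFreeingInput` and its `stashed` attribute are hypothetical names made up for this example. It only assumes that the op keeps a private copy of its input on the context and drops it inside backward. Note that `once_differentiable` only declares that backward itself cannot be differentiated; it does not by itself make a second backward call work.

```python
import torch
from torch.autograd.function import once_differentiable


class SquareFreeingInput(torch.autograd.Function):
    """Hypothetical op (not TE code): y = x**2, frees its stashed input in backward."""

    @staticmethod
    def forward(ctx, x):
        # Keep a private copy on ctx so freeing it later never touches the caller's tensor.
        ctx.stashed = x.detach().clone()
        return x * x

    @staticmethod
    @once_differentiable  # declares that backward itself cannot be differentiated
    def backward(ctx, grad_out):
        if ctx.stashed is None:
            raise RuntimeError("stashed input was already freed by an earlier backward")
        grad_in = 2 * ctx.stashed * grad_out
        ctx.stashed = None  # free eagerly; this ctx cannot serve another backward
        return grad_in


x = torch.ones(4, requires_grad=True)
out = SquareFreeingInput.apply(x).sum()

out.backward(retain_graph=True)  # first call works
first_grad = x.grad.clone()

try:
    out.backward(retain_graph=True)  # second call finds the stash gone
    second_call_failed = False
except RuntimeError:
    second_call_failed = True
```

Even with `retain_graph=True`, the second call fails here because the function discarded its own state, which autograd's graph retention does not cover.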

Note that the functional API in https://github.com/NVIDIA/TransformerEngine/pull/707 does not manually deallocate intermediate tensors (the operators themselves do deallocate to achieve the same memory savings as the module-based API). Once that is merged, it should be better for this use-case.
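In the meantime, one possible workaround for benchmarking only the backward pass is to rebuild the graph with an untimed forward before every timed backward call. The sketch below is an assumption about how that could look, not an official recipe: `time_backward` and `make_loss` are hypothetical names, and `torch.nn.Linear` on CPU stands in for a TE layer. For TE, `make_loss` would presumably wrap the forward in `fp8_autocast` and `use_cuda=True` would be passed so the timings include GPU synchronization.

```python
import time
import torch


def time_backward(make_loss, iters=10, use_cuda=False):
    """Return the mean backward-only time in seconds.

    make_loss is a zero-argument callable that runs a fresh forward pass
    and returns a scalar loss, so every backward gets an unused context.
    """
    total = 0.0
    for _ in range(iters):
        loss = make_loss()  # untimed forward
        if use_cuda:
            torch.cuda.synchronize()  # make sure the forward has finished
        start = time.perf_counter()
        loss.backward()
        if use_cuda:
            torch.cuda.synchronize()  # wait for backward kernels to finish
        total += time.perf_counter() - start
    return total / iters


layer = torch.nn.Linear(16, 16)  # stand-in for a TE layer
x = torch.randn(16, 16)
mean_seconds = time_backward(lambda: layer(x).sum(), iters=5)
```

Gradients accumulate across iterations in this sketch; that does not affect the timing, but call `layer.zero_grad()` first if the values matter.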