NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including support for 8-bit floating point (FP8) precision on Hopper and Ada GPUs, providing better performance with lower memory utilization in both training and inference.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Apache License 2.0

Calling backward(retain_graph=True) multiple times with TE Layer does not work #990

Open kshitij12345 opened 4 days ago

kshitij12345 commented 4 days ago
import torch
from transformer_engine.pytorch import Linear as TELinear, fp8_autocast

# m = torch.nn.Linear(16, 16).to("cuda")  # This works
m = TELinear(16, 16)
x = torch.randn(16, 16, device='cuda')

with fp8_autocast(True):
    o = m(x).sum()

o.backward(retain_graph=True)

# This fails with:
# AssertionError: FP8 execution requires 2D input matrices with height divisible by 8 and width divisible by 16, but got tensor with dims=[0]
# It looks like TELinear's backward mutates the context object so that it cannot be reused.
o.backward()

Supporting this would be useful for benchmarking just the backward pass.
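
One possible workaround for benchmarking only the backward pass (a minimal sketch, not part of the original report): re-run the forward pass inside fp8_autocast on every iteration, so each backward consumes a fresh graph and retain_graph=True is never needed.

import torch
from transformer_engine.pytorch import Linear as TELinear, fp8_autocast

m = TELinear(16, 16)
x = torch.randn(16, 16, device="cuda")

for _ in range(10):  # arbitrary number of benchmark iterations
    with fp8_autocast(enabled=True):
        o = m(x).sum()  # fresh forward, so each graph is used exactly once
    o.backward()        # time only this call; no retain_graph needed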

timmoon10 commented 3 days ago

In the module-based API, we deliberately clear the intermediate tensors to minimize memory usage (see https://github.com/NVIDIA/TransformerEngine/pull/509); see, for example, _Linear.backward: https://github.com/NVIDIA/TransformerEngine/blob/56e0b351d0b7db6e622d3aa3eac6c6a1bf1ce4ab/transformer_engine/pytorch/module/linear.py#L587-L588. We should mark these autograd backward functions with once_differentiable for clarity.
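
For reference, once_differentiable (from torch.autograd.function) runs a custom Function's backward under no_grad and makes any attempt to differentiate through that backward raise an explicit error. A minimal sketch with a toy Function (not TE's actual code) looks like this:

import torch
from torch.autograd.function import once_differentiable

class _ToyLinear(torch.autograd.Function):  # hypothetical example, not TE's _Linear
    @staticmethod
    def forward(ctx, inp, weight):
        ctx.save_for_backward(inp, weight)
        return inp @ weight.t()

    @staticmethod
    @once_differentiable  # attempting double backward through this op now errors out clearly
    def backward(ctx, grad_out):
        inp, weight = ctx.saved_tensors
        # TE's _Linear.backward additionally frees the saved tensors here,
        # which is why a second backward over the same graph fails.
        return grad_out @ weight, grad_out.t() @ inp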

Note that the functional API in https://github.com/NVIDIA/TransformerEngine/pull/707 does not manually deallocate intermediate tensors (the operators themselves deallocate them to achieve the same memory savings as the module-based API). Once that PR is merged, it should handle this use case better.