Note that users must also initialize grads before calling torch.cuda.make_graphed_callables, or else they'll run into similar correctness issues. Consider the following example:
```python
import torch
torch.set_default_device('cuda')

# Construct linear module
model = torch.nn.Linear(1, 1, bias=False)
with torch.no_grad():
    model.weight.fill_(1)
# model.weight.grad = torch.empty_like(model.weight)  # Uncomment to fix bug

# Capture CUDA graph
x = torch.ones((1, 1), requires_grad=True)
model = torch.cuda.make_graphed_callables(model, (x,))

# Training steps
for step in range(3):
    if model.weight.grad is not None:
        model.weight.grad.zero_()
    x = torch.ones((1, 1), requires_grad=True)
    y = model(x)
    y.backward(torch.ones((1, 1)))
    print(f"{step=}, {model.weight.grad.item()=}")
```
I expect the weight gradient to always be 1. However:

```
step=0, model.weight.grad.item()=1.0
step=1, model.weight.grad.item()=2.0
step=2, model.weight.grad.item()=2.0
```
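The commented-out line above is the workaround: pre-allocating the grad buffer before capture ensures the backward graph accumulates into the same tensor the optimizer later reads. Here is a minimal sketch of that initialization, generalized to all parameters (it uses torch.zeros_like instead of the torch.empty_like from the example so the buffer starts at a defined value; either works for keeping the buffer aliased):

```python
# Workaround sketch: allocate param grads before graph capture so the
# buffer baked into the backward graph is the same one used afterwards.
for param in model.parameters():
    if param.grad is None:
        param.grad = torch.zeros_like(param)  # pre-allocate the accumulation buffer
model = torch.cuda.make_graphed_callables(model, (x,))
```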
Description
Our CUDA graph infrastructure (https://github.com/NVIDIA/TransformerEngine/pull/575) supports FP8 weight caching when training with multiple gradient accumulation steps. While adding tests for this functionality (see https://github.com/NVIDIA/TransformerEngine/pull/820#issuecomment-2127600173), I ran into subtle correctness issues: te.make_graphed_callables resets grad buffers after capturing graphs, so the gradient buffer filled in the backward pass is different from the gradient buffer used by the optimizer. We didn't detect this before because Megatron-LM and NeMo explicitly manage gradient buffers, e.g. with a distributed optimizer.

This PR modifies the CUDA graph tests to initialize grads before make_graphed_callables, and it avoids resetting grads within make_graphed_callables.
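To see why resetting grads after capture is unsafe: the captured backward graph holds a reference to whatever grad tensor exists at capture time, and setting .grad = None only detaches the parameter from that storage. A minimal sketch of the aliasing problem (illustrative only, not TE internals; assumes a CUDA device):

```python
import torch

w = torch.ones(1, device='cuda', requires_grad=True)
w.grad = torch.zeros_like(w)   # buffer A: what a captured graph would hold
buffer_a = w.grad              # stand-in for the graph's captured reference
w.grad = None                  # "resetting" grads detaches w from buffer A
(2 * w).sum().backward()       # autograd allocates a fresh buffer B for w.grad
buffer_a.fill_(123.0)          # a graph replay would write into buffer A instead
print(w.grad.item())           # 2.0 -- the grad the optimizer actually reads
print(buffer_a.item())         # 123.0 -- where the replay's writes would land
```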
Type of change
Changes
Please list the changes introduced in this PR:
- Initialize grads before make_graphed_callables in the CUDA graph tests
- Avoid resetting grads within make_graphed_callables
Checklist: