NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Apache License 2.0

[PyTorch] Add CUDA graph tests with FP8 weight caching #869

Closed · timmoon10 closed 1 month ago

timmoon10 commented 1 month ago

Description

Our CUDA graph infrastructure (https://github.com/NVIDIA/TransformerEngine/pull/575) supports FP8 weight caching when training with multiple gradient accumulation steps. While adding tests for this functionality (see https://github.com/NVIDIA/TransformerEngine/pull/820#issuecomment-2127600173), I ran into subtle correctness issues: te.make_graphed_callables resets grad buffers after capturing graphs, so the gradient buffer filled in the backward pass is not the gradient buffer used by the optimizer. We didn't detect this before because Megatron-LM and NeMo explicitly manage gradient buffers, e.g. with a distributed optimizer.

This PR modifies the CUDA graph tests to initialize grads before calling make_graphed_callables, and it avoids resetting grads within make_graphed_callables itself.
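
For reference, a minimal sketch of that pattern (te.Linear, the tensor shapes, and the zeros_like initialization below are illustrative choices, not taken from the actual tests; only te.make_graphed_callables comes from the infrastructure referenced above):

import torch
import transformer_engine.pytorch as te

torch.set_default_device('cuda')

# Construct a Transformer Engine module
model = te.Linear(32, 32)

# Initialize grad buffers *before* graph capture so the buffer filled in the
# backward pass is the same buffer the optimizer later reads
with torch.no_grad():
    for param in model.parameters():
        param.grad = torch.zeros_like(param)

# Capture CUDA graph
x = torch.randn((8, 32), requires_grad=True)
model = te.make_graphed_callables(model, (x,))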

timmoon10 commented 1 month ago

Note that users must also initialize grads before calling torch.cuda.make_graphed_callables, or else they'll run into similar correctness issues. Consider the following example:

import torch
torch.set_default_device('cuda')

# Construct linear module
model = torch.nn.Linear(1, 1, bias=False)
with torch.no_grad():
    model.weight.fill_(1)
    # model.weight.grad = torch.empty_like(model.weight)  # Uncomment to fix bug

# Capture CUDA graph
x = torch.ones((1, 1), requires_grad=True)
model = torch.cuda.make_graphed_callables(model, (x,))

# Training steps
for step in range(3):
    if model.weight.grad is not None:
        model.weight.grad.zero_()
    x = torch.ones((1, 1), requires_grad=True)
    y = model(x)
    y.backward(torch.ones((1, 1)))
    print(f"{step=}, {model.weight.grad.item()=}")

I expect the weight gradient to always be 1. However, this prints:

step=0, model.weight.grad.item()=1.0
step=1, model.weight.grad.item()=2.0
step=2, model.weight.grad.item()=2.0
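
As the commented-out line in the script indicates, allocating the gradient buffer before graph capture avoids the problem; with it in place, every step is expected to print a weight gradient of 1.0. A minimal sketch of that workaround (using torch.zeros_like instead of torch.empty_like is just a cosmetic choice):

# Allocate the grad buffer before torch.cuda.make_graphed_callables captures
# the graph, so the buffer filled in the backward pass is the same buffer the
# training loop zeroes and reads
with torch.no_grad():
    model.weight.grad = torch.zeros_like(model.weight)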

timmoon10 commented 1 month ago

/te-ci pytorch