NVIDIA / TransformerEngine

A library for accelerating Transformer models on NVIDIA GPUs, including using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with lower memory utilization in both training and inference.
https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html
Apache License 2.0

checkpoint saving and resume for a long time and a huge amount of device space #231

Open lx1374327576 opened 1 year ago

lx1374327576 commented 1 year ago

When training with the FP8 Linear layer, saving a checkpoint takes a long time and a huge amount of device space because of the extra state. How can I solve this problem?

lx1374327576 commented 1 year ago

The size of the extra state seems to grow gradually with the number of training steps.

ptrendx commented 1 year ago

Hi @lx1374327576, this sounds like a bug. Could you provide a simple repro script that shows this behavior?

lx1374327576 commented 1 year ago

import argparse
import torch
from transformer_engine.pytorch import Linear as teLinear
from transformer_engine.common import recipe
import transformer_engine.pytorch as te

parser = argparse.ArgumentParser()
parser.add_argument("--steps", default=1000, type=int)
args = parser.parse_args()
input = torch.randn((64, 768)).cuda()
model = teLinear(768, 768, bias=False).cuda()
optimizer = torch.optim.SGD(model.parameters(), momentum=0.1, lr=0.00001)
loss_fn = torch.nn.MSELoss()

fp8_recipe = recipe.DelayedScaling(margin=0, interval=1, fp8_format=recipe.Format.E4M3)
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    for i in range(args.steps):
        output = model(input)
        loss = torch.sum(loss_fn(output, input))
        loss.backward()
        optimizer.step()
torch.save(model.state_dict(), 'model.pt')
for k, v in torch.load('model.pt').items():
    print(k, v.shape)

lx1374327576 commented 1 year ago

@ptrendx Here are the commands and their output:

# python3 test_ckpt.py --steps 1000
weight torch.Size([768, 768])
_extra_state torch.Size([285645])

# python3 test_ckpt.py --steps 10000
weight torch.Size([768, 768])
_extra_state torch.Size([2842014])

# python3 test_ckpt.py --steps 100000
weight torch.Size([768, 768])
_extra_state torch.Size([28405695])
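
As a quick sanity check on the numbers reported above, the growth of `_extra_state` can be estimated from the three runs. The sketch below is plain Python and only uses the sizes quoted in this thread; the helper name `growth_per_step` is hypothetical and not part of Transformer Engine.

```python
# Sizes of _extra_state reported in the thread, keyed by number of training steps.
reported = {1000: 285645, 10000: 2842014, 100000: 28405695}


def growth_per_step(sizes):
    """Return the average increase in elements per step between consecutive runs."""
    points = sorted(sizes.items())
    return [
        (size_b - size_a) / (steps_b - steps_a)
        for (steps_a, size_a), (steps_b, size_b) in zip(points, points[1:])
    ]


rates = growth_per_step(reported)
print([round(r, 2) for r in rates])  # -> [284.04, 284.04]
```

Both intervals give roughly 284 additional elements per step, which is consistent with the extra state growing linearly with the step count rather than staying bounded.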

ptrendx commented 1 year ago

The problem is in the main loop: fp8_autocast should be entered inside the step loop, not wrapped around it. The backward call should also sit outside the fp8_autocast region, since it inherits the FP8 execution from the forward pass:

for i in range(args.steps):
    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        output = model(input)
        loss = torch.sum(loss_fn(output, input))
    loss.backward()
    optimizer.step()

With that change and a recent TE version (I assume you are using an older one, since we have since added a check and an error message for this script issue), I get:

# python test.py --steps 1000
weight torch.Size([768, 768])
_extra_state torch.Size([22675])

# python test.py --steps 10000
weight torch.Size([768, 768])
_extra_state torch.Size([22662])

# python test.py --steps 100000
weight torch.Size([768, 768])
_extra_state torch.Size([22683])
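
To catch this kind of growth early, one option is to check the size of every `_extra_state` entry before writing a checkpoint. The following is a minimal sketch, not something Transformer Engine provides: the function names and the `limit` threshold are hypothetical, and it assumes the extra state is stored as an object exposing `numel()`, as in the tensor output shown above.

```python
def extra_state_sizes(state_dict):
    """Map each *_extra_state key to its element count.

    Works with any mapping whose values expose numel(), such as the
    dict returned by torch.load() on a saved state_dict.
    """
    return {
        name: value.numel()
        for name, value in state_dict.items()
        if name.endswith("_extra_state") and hasattr(value, "numel")
    }


def check_extra_state(state_dict, limit):
    """Raise if any extra-state entry exceeds `limit` elements (hypothetical threshold)."""
    oversized = {
        name: count
        for name, count in extra_state_sizes(state_dict).items()
        if count > limit
    }
    if oversized:
        raise ValueError(f"extra state larger than {limit} elements: {oversized}")
```

For example, calling `check_extra_state(model.state_dict(), limit=100_000)` right before `torch.save` would pass for the roughly 22.7k-element state shown above and fail for the 285645-element state from the original script.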