pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

gradient checkpoint causes bigger memory usage on GPU #3638

Open cicirori opened 2 years ago

cicirori commented 2 years ago

❓ Questions and Help

Recently I started testing gradient checkpointing (GC) performance on GPU with the master branches of PyTorch and torch_xla.

Unfortunately, and consistent with my previous conclusions (https://github.com/pytorch/xla/issues/3455#issuecomment-1101056839), GC in the current torch_xla still seems to have difficulty achieving the desired results: it fails to deliver memory savings even in some very simple cases.

Simple test program:

import argparse
import torch
import torch_xla.utils.checkpoint
import torch.utils.checkpoint
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

def run(grad_checkpoint, use_cuda=False):
    if use_cuda:
        device = 'cuda'
    else:
        device = xm.xla_device()

    model = torch.nn.ModuleList(
        [
            torch.nn.Sequential(
                torch.nn.Conv2d(1024, 1024, 1),
                torch.nn.ReLU(),
                torch.nn.Conv2d(1024, 1024, 1),
                torch.nn.ReLU(),
            )
            for _ in range(64)
        ]
    ).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.0)

    for step in range(200):
        dummy_data = torch.zeros(64, 1024, 14, 14, device=device)
        optimizer.zero_grad()
        x = dummy_data
        for n_l, layer in enumerate(model):
            if n_l > 0 and grad_checkpoint:
                if use_cuda:
                    x = torch.utils.checkpoint.checkpoint(layer, x)
                else:
                    x = torch_xla.utils.checkpoint.checkpoint(layer, x)
            else:
                x = layer(x)
        dummy_loss = x.sum()
        dummy_loss.backward()
        optimizer.step()
        if not use_cuda:
            xm.mark_step()
            mem_info = xm.get_memory_info(device)
            mem_info['kb_used'] = mem_info['kb_total'] - mem_info['kb_free']
            print(f"step {step}, memory = {mem_info['kb_used']}")
        else:
            print(f"step {step}, memory = {torch.cuda.memory_summary()}")
    if not use_cuda:
        print(met.metrics_report())

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--grad_checkpoint", type=int, required=True)
    parser.add_argument("--use_cuda", type=int, required=True)
    args = parser.parse_args()
    run(args.grad_checkpoint, args.use_cuda)

GPU memory usage across the four configurations:

              WITHOUT GC    WITH GC
torch xla     11561 MiB     19753 MiB
torch native   7680 MiB      4634 MiB
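As a quick sanity check on the table, the relative change from enabling GC can be computed from the four reported numbers. The values are copied from the table; nothing new is measured here.

```python
# GPU memory usage in MiB, copied from the table above.
usage = {
    "torch xla": {"without_gc": 11561, "with_gc": 19753},
    "torch native": {"without_gc": 7680, "with_gc": 4634},
}


def relative_change(before: float, after: float) -> float:
    """Percentage change going from `before` to `after`."""
    return (after - before) / before * 100.0


for backend, mib in usage.items():
    pct = relative_change(mib["without_gc"], mib["with_gc"])
    print(f"{backend}: {pct:+.1f}% with GC")
# torch xla: +70.9% with GC
# torch native: -39.7% with GC
```

So in this test GC cuts native CUDA memory by roughly 40%, while on torch_xla it raises usage by roughly 71%.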

Command used (with --grad_checkpoint and --use_cuda each set to 0 or 1):

TF_FORCE_GPU_ALLOW_GROWTH=true CUDA_VISIBLE_DEVICES=0 GPU_NUM_DEVICES=1 python3  ./test_gc.py --grad_checkpoint=0/1 --use_cuda=0/1
torch version: 1.12.0a0+git8abf37d
torch xla version: 5dae54bc53eb6c9a11eb4706fe01d1dfa557c14f
cicirori commented 2 years ago

simple_gc_test.zip

This zip contains the HLO dumps with GC enabled and with GC disabled.

In the run that generated this dump, the number of repeated model blocks was set to 4 instead of 64, but the conclusion is the same: enabling GC on the GPU increases memory usage in this case.

JackCaoG commented 2 years ago

Hmm, is there a way for you to verify the peak memory usage and check whether it is reduced with GC?

JackCaoG commented 2 years ago

I dumped the HLO for the with-checkpoint and without-checkpoint cases and will try to find someone to take a look.

with_gc.hlo.txt no_gc.hlo.txt

JackCaoG commented 2 years ago

I talked with Parker, the author of the optimization_barrier HLO, and I think my implementation of optimization_barrier has a flaw. I was doing:

x1 = layer0.fwd(x0)
(x1,x0) = opt_barrier(x1, x0)
x2 = layer1.fwd(x1)
(x2,x1) = opt_barrier(x2, x1)
x3 = layer2.fwd(x2)
(x3,x2) = opt_barrier(x3, x2)
...
grad2 = layer2.bwd(x2, grad3)
grad1 = layer1.bwd(x1, grad2)
grad0 = layer0.bwd(x0, grad1)

which does guarantee that the recomputation in the backward waits until the corresponding forward has finished. However, nothing prevents that recomputation from being moved to right after the corresponding forward.

For example, it could become:

x1 = layer0.fwd(x0)
(x1,x0) = opt_barrier(x1, x0)
layer0.repeated_computation_for_backward
....
grad1 = layer1.bwd(x1, grad2)
grad0 = layer0.remaining_bwd(x0, grad1)
....

which is not ideal. I should really do it like this:

x1 = layer0.fwd(x0)
x2 = layer1.fwd(x1)
x3 = layer2.fwd(x2)
...
x2, grad3 = opt_barrier(x2, grad3)
grad2 = layer2.bwd(x2, grad3)
x1, grad2 = opt_barrier(x1, grad2)
grad1 = layer1.bwd(x1, grad2)
x0, grad1 = opt_barrier(x0, grad1)
grad0 = layer0.bwd(x0, grad1)

which fully guarantees the execution order. I'm not sure how much this will affect the actual memory usage, but I will try to implement it soon. FYI @ronghanghu
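To illustrate why the hoisted order is a problem, here is a toy peak-memory model. It is a sketch, not a measurement: the unit sizes `X` and `A`, the `("alloc", name, size)` / `("free", name)` event format, and the helper names are all hypothetical, and gradients are ignored. It replays allocations for a chain of checkpointed layers under two orderings: recomputation placed just before each layer's backward, versus recomputation hoisted right after each forward.

```python
# Toy peak-memory model; all sizes are hypothetical units, not measurements.
X = 1  # one checkpointed tensor x_i kept between forward and backward
A = 2  # one layer's recomputed internal activations needed by its backward


def peak_memory(events):
    """Replay ("alloc", name, size) / ("free", name) events; return peak live units."""
    live, current, peak = {}, 0, 0
    for op, name, *size in events:
        if op == "alloc":
            live[name] = size[0]
            current += size[0]
        else:  # "free"
            current -= live.pop(name)
        peak = max(peak, current)
    return peak


def schedule(n_layers, hoist_recompute):
    """Events for a chain of checkpointed layers (gradients ignored for simplicity)."""
    events = [("alloc", "x0", X)]
    for i in range(n_layers):
        events.append(("alloc", f"x{i + 1}", X))  # forward of layer i; keep its output
        if hoist_recompute:
            # Recomputation moved right after layer i's forward, so these
            # activations stay live until layer i's backward much later.
            events.append(("alloc", f"act{i}", A))
    events.append(("free", f"x{n_layers}"))  # final output consumed by the loss
    for i in reversed(range(n_layers)):
        if not hoist_recompute:
            # Recompute just before this layer's backward.
            events.append(("alloc", f"act{i}", A))
        events.append(("free", f"act{i}"))  # layer i's backward consumes its activations
        events.append(("free", f"x{i}"))    # ...and its checkpointed input
    return events


for hoist in (False, True):
    label = "recompute hoisted after fwd" if hoist else "recompute inside bwd"
    print(f"{label}: peak = {peak_memory(schedule(64, hoist))} units")
# recompute inside bwd: peak = 66 units
# recompute hoisted after fwd: peak = 193 units
```

In this model the hoisted order keeps all n layers' recomputed activations live at once, for a peak of (n+1)*X + n*A, which is what keeping every activation without checkpointing would cost. Recomputing inside the backward keeps at most one layer's activations live, for a peak of max((n+1)*X, n*X + A). Real XLA buffer assignment is far more involved, so this only shows the direction of the effect.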

JackCaoG commented 2 years ago

In other words, instead of binding the input and output of the forward, I should bind the grad_input and the input right before performing the backward.
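The difference between the two bindings is purely a data-dependency question, so it can be checked on a toy graph. This is a hedged sketch, not the torch_xla implementation: the node names (`fwd{i}`, `bar{i}`, `recomp{i}`, `bwd{i}`, `loss`) are hypothetical labels for the ops in the pseudocode, and each layer's recomputation is modeled as a separate node that needs only `x_i`, as in the `repeated_computation_for_backward` example. The script reports whether each recomputation is forced to wait for the loss, i.e. for the end of the forward pass.

```python
# Toy dependency graphs: each node maps to the nodes it consumes.
# Node names are hypothetical labels for the ops in the pseudocode.


def old_binding(n):
    """Barrier binds each forward's output with its input: (x_{i+1}, x_i)."""
    g = {"x0": []}
    x = "x0"
    for i in range(n):
        g[f"fwd{i}"] = [x]                   # x_{i+1} = layer_i.fwd(x_i)
        g[f"bar{i}"] = [f"fwd{i}", x]        # (x_{i+1}, x_i) = opt_barrier(x_{i+1}, x_i)
        g[f"recomp{i}"] = [f"bar{i}"]        # recomputation needs only x_i from the barrier
        x = f"bar{i}"
    g["loss"] = [x]
    grad = "loss"
    for i in reversed(range(n)):
        g[f"bwd{i}"] = [f"recomp{i}", grad]  # grad_i = layer_i.bwd(x_i, grad_{i+1})
        grad = f"bwd{i}"
    return g


def new_binding(n):
    """Barrier binds each checkpointed input with the incoming gradient: (x_i, grad_{i+1})."""
    g = {"x0": []}
    xs = ["x0"]
    for i in range(n):
        g[f"fwd{i}"] = [xs[-1]]              # plain forward, no barrier
        xs.append(f"fwd{i}")
    g["loss"] = [xs[-1]]
    grad = "loss"
    for i in reversed(range(n)):
        g[f"bar{i}"] = [xs[i], grad]         # (x_i, grad_{i+1}) = opt_barrier(x_i, grad_{i+1})
        g[f"recomp{i}"] = [f"bar{i}"]
        g[f"bwd{i}"] = [f"recomp{i}", f"bar{i}"]
        grad = f"bwd{i}"
    return g


def depends_on(g, node, target):
    """True if `node` transitively depends on `target` in graph `g`."""
    stack, seen = list(g[node]), set()
    while stack:
        cur = stack.pop()
        if cur == target:
            return True
        if cur not in seen:
            seen.add(cur)
            stack.extend(g[cur])
    return False


n = 4
for name, graph in (("old binding", old_binding(n)), ("new binding", new_binding(n))):
    print(name, [depends_on(graph, f"recomp{i}", "loss") for i in range(n)])
# old binding [False, False, False, False]
# new binding [True, True, True, True]
```

With the old binding nothing ties a layer's recomputation to the backward pass, so a scheduler may legally run it during the forward. With the new binding each recomputation also transitively depends on `bwd{i+1}` (for i < n-1), which pins it to its place in the backward.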

JackCaoG commented 2 years ago

Another point Parker raised is that because the example is so simple, the XLA compiler is super adversarial here: it tries really hard to put those otherwise unused cores to work on something! Let me first try to implement the change I proposed above and see whether it fixes the issue here.

ronghanghu commented 2 years ago

@JackCaoG I see, thanks for the update on this!

JackCaoG commented 2 years ago

@ronghanghu @cicirori Can you guys give https://github.com/pytorch/xla/pull/3721 a try?

ronghanghu commented 2 years ago

Thanks @JackCaoG -- I'll try this out!