metaopt / torchopt

TorchOpt is an efficient library for differentiable optimization built upon PyTorch.
https://torchopt.readthedocs.io
Apache License 2.0

[Question] Memory keep increase in MetaAdam due to gradient link #218

Open ycsos opened 5 months ago

ycsos commented 5 months ago

Required prerequisites

What version of TorchOpt are you using?

0.7.3

System information

>>> print(sys.version, sys.platform)
3.11.9 | packaged by conda-forge | (main, Apr 19 2024, 18:36:13) [GCC 12.3.0] linux
>>> print(torchopt.__version__, torch.__version__, functorch.__version__)
0.7.3 2.3.0 2.3.0

Problem description

When using torchopt.MetaAdam and stepping several times, GPU memory usage increases continuously. It should not: when the next step executes, the tensors created in the former step are no longer needed and should be released. I found the reason: the meta-optimizer does not detach the gradient link inside the optimizer, so the former tensors are not released by torch due to the dependency.

You can run the test code below. In the first run, memory usage grows as the step count increases; in the second run (where I changed the code to detach the gradient link), memory stays stable as the steps increase.

Before:

[screenshot: GPU memory usage growing with each step]

After:

[screenshot: GPU memory usage stable across steps]
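The accumulation described above can be reproduced in a few lines of plain PyTorch (a minimal CPU sketch of the retention mechanism, not TorchOpt internals):

```python
import torch

# A differentiable update keeps each new parameter connected to the previous
# step's graph, so every intermediate tensor from earlier steps stays alive.
p = torch.randn(4, requires_grad=True)
param = p
for _ in range(3):
    loss = (param * 2.0).sum()
    (grad,) = torch.autograd.grad(loss, param, create_graph=True)
    param = param - 0.1 * grad       # differentiable update: the graph keeps growing

assert param.grad_fn is not None     # still linked to all previous steps
detached = param.detach().requires_grad_()
assert detached.grad_fn is None      # breaking the link lets the history be freed
```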

Reproducible example code

The Python snippets:

import torch
import torch.nn
import torchopt
from torch.profiler import profile, ProfilerActivity


class test_nn(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.a = torch.nn.Parameter(torch.randn(768, 512, device="cuda"))  # 768 * 512 * 4 / 1024 = 1536 KB
        self.b = torch.nn.Parameter(torch.randn(768, 768, device="cuda"))  # 768 * 768 * 4 / 1024 = 2304 KB
        self.test_time = 10

    def forward(self):
        with profile(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
            schedule=torch.profiler.schedule(wait=0, warmup=0, active=1, repeat=1),
            on_trace_ready=torch.profiler.tensorboard_trace_handler('./log_test/'),
            profile_memory=True,
            with_stack=True,
            record_shapes=True,
        ) as prof:
            def test_func1(a, b):
                with torch.enable_grad():
                    c = a * 2
                    d = torch.matmul(b, c)
                    return torch.sum(a + d)

            optimizer = torchopt.MetaAdam(self, lr=0.1)
            for _ in range(self.test_time):
                loss = test_func1(self.a, self.b)
                optimizer.step(loss)  # differentiable step: keeps the graph alive
                print(torch.cuda.max_memory_allocated())


def main():
    test_nn().forward()


if __name__ == "__main__":
    main()

Command lines:

python test.py

Traceback

Current:
62526464
90054144
106309632
122827264
138558464
155600384
171331584
187587072
204104704
219835904

Expected behavior

Should be:
57019392
60951552
61737984
63179776
63179776
63179776
63179776
63179776
63179776
63179776

Additional context

No response

XuehaiPan commented 5 months ago

I found the reason: the meta-optimizer does not detach the gradient link inside the optimizer, so the former tensors are not released by torch due to the dependency.

Hi @ycsos, this is intentional; it is the mechanism behind the explicit hyper-gradient.

When using torchopt.MetaAdam and stepping several times, GPU memory usage increases continuously.

You should not detach the computation graph in the inner loop. Instead, you need to detach the graph in the outer loop using torchopt.stop_gradient. Refer to our tutorial for more detailed examples.
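The intended pattern can be sketched in plain PyTorch (an illustration of the concept; the outer-loop detach here stands in for what torchopt.stop_gradient does, not TorchOpt's actual implementation):

```python
import torch

# Bilevel pattern: keep the inner loop differentiable so the hyper-gradient can
# flow back through it, then cut the graph once per outer iteration.
w = torch.randn(3, requires_grad=True)
for outer_step in range(2):
    p = w
    for inner_step in range(3):
        loss = (p ** 2).sum()
        # create_graph=True keeps the inner update differentiable (hyper-gradient)
        (g,) = torch.autograd.grad(loss, p, create_graph=True)
        p = p - 0.1 * g
    meta_loss = (p ** 2).sum()
    meta_loss.backward()              # backprop through all inner steps
    # outer-loop detach, analogous to torchopt.stop_gradient:
    w = p.detach().requires_grad_()

assert w.grad_fn is None              # graph was cut; old steps can be freed
```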

ycsos commented 5 months ago

I understand what the explicit hyper-gradient means. But what I want to do is only inference (no need to save any activations) with the model. I think detaching only in the outer loop is not enough: the tensors in the inner loop are also linked, so when the next step comes, those tensors cannot be released. In this picture, I only need the forward pass. Any good suggestion for detaching the gradient link in the inner loop?

ycsos commented 5 months ago

@XuehaiPan thank you for your reply. In my view, torchopt.stop_gradient only detaches the link for the input tensors, but the gradient links inside the inner loop are still connected, e.g. when the optimizer updates the parameters. That is totally right for training, but for inference we don't need to keep the gradient connections, and they prevent torch from releasing these tensors.

Benjamin-eecs commented 5 months ago

@XuehaiPan thank you for your reply. In my view, torchopt.stop_gradient only detaches the link for the input tensors, but the gradient links inside the inner loop are still connected, e.g. when the optimizer updates the parameters. That is totally right for training, but for inference we don't need to keep the gradient connections, and they prevent torch from releasing these tensors.

Hi, meta-optimizers are designed specifically for bilevel optimization algorithms. To meet your need, maybe you can directly use the functional API.
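For an inference-only loop, the functional style the comment suggests amounts to computing gradients without `create_graph` and applying the update under `no_grad`. A minimal plain-PyTorch sketch of that pattern (an assumption about the intended usage, not TorchOpt's own functional API):

```python
import torch

# Functional-style, non-differentiable stepping: the graph from each step is
# freed immediately, so memory does not grow across steps.
w = torch.randn(4, requires_grad=True)
for _ in range(5):
    loss = (w * 3.0).sum()
    (g,) = torch.autograd.grad(loss, w)   # no create_graph: graph freed after this call
    with torch.no_grad():
        w -= 0.1 * g                      # in-place update: no cross-step graph is built

assert w.grad_fn is None                  # w stays a leaf; nothing is retained
```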

ycsos commented 5 months ago

@XuehaiPan thank you for your reply. In my view, torchopt.stop_gradient only detaches the link for the input tensors, but the gradient links inside the inner loop are still connected, e.g. when the optimizer updates the parameters. That is totally right for training, but for inference we don't need to keep the gradient connections, and they prevent torch from releasing these tensors.

Hi, meta-optimizers are designed specifically for bilevel optimization algorithms. To meet your need, maybe you can directly use the functional API.

Thank you very much! Now I understand the design of TorchOpt.

XuehaiPan commented 5 months ago

That is totally right for training, but for inference we don't need to keep the gradient connections, and they prevent torch from releasing these tensors.

@ycsos I opened a PR to resolve this.

ycsos commented 5 months ago

That is totally right for training, but for inference we don't need to keep the gradient connections, and they prevent torch from releasing these tensors.

@ycsos I opened a PR to resolve this.

meta_optim = torchopt.MetaAdam(model, lr=0.1)

loss = compute_loss(model, batch)
with torch.no_grad():  # step without recording a cross-step graph
    meta_optim.step(loss)

Did you test the code? I have a question: you put the step function under torch.no_grad(), so in torchopt/optim/meta/base.py Line 84 the flat_new_params you get will not require grad, and writing them back into the container may make the next step wrong.

Also, doesn't torch.autograd.grad need to run under torch.enable_grad()? So maybe the step function only needs an explicit .requires_grad_() when updating the model parameters.
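Both concerns above can be checked in plain PyTorch: a tensor produced under `torch.no_grad()` does not require grad (so an explicit `.requires_grad_()` would be needed), while `torch.autograd.grad` itself still works under `no_grad()` as long as the graph was built beforehand:

```python
import torch

a = torch.randn(3, requires_grad=True)
loss = (a * 2.0).sum()                    # graph built outside no_grad
with torch.no_grad():
    (g,) = torch.autograd.grad(loss, a)   # backward through a pre-built graph: OK
    b = a - 0.1 * g                       # but the result does not require grad

assert b.requires_grad is False           # the update lost its grad tracking
b.requires_grad_()                        # explicit re-enable, as suggested above
assert b.requires_grad is True
```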