f-dangel / backpack

BackPACK - a backpropagation package built on top of PyTorch which efficiently computes quantities other than the gradient.
https://backpack.pt/

memory leakage problem when loss.backward() #258

Open SJShin-AI opened 2 years ago

SJShin-AI commented 2 years ago

Hi.

How can I solve a memory leak that occurs during loss.backward()? My code is a bit complex, which makes it hard to share in full.

The main memory growth comes from the snippet below, right after loss.backward() executes. Memory increases on every iteration, which eventually leads to an OOM error.

loss = bce_extended(logits, y).sum()
with backpack(BatchGrad()):
    if real_sample:
        loss.backward(inputs=list(model.parameters()))

I also tried with disable():, which prevents the memory leak. However, it cannot be combined with with backpack(BatchGrad()): when I want the per-sample gradients.
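
For reference, the workaround I tried looks roughly like this (a sketch, not my exact code):

from backpack import disable

# this avoids the leak, but BackPACK then stores no per-sample gradients
with disable():
    loss.backward()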

f-dangel commented 2 years ago

Hi,

thanks for reporting. I understand your full code might be complex to share, but could you provide a minimal example that reproduces the leak? This would be extremely helpful for debugging.

The above code looks fine to me, so it would be interesting to have more details on when exactly the issue occurs.

Best, Felix

SJShin-AI commented 2 years ago

I tried hard to construct a minimal example. Sorry that the code is still lengthy and that the learning framework is hard to follow.

My task is dataset condensation, which turns random noise into informative inputs by updating the noise directly via its gradient. I construct the learning objective for a given image by matching (a) the loss gradients between the training data and the synthetic inputs, and (b) the loss gradient variances between them, which requires access to per-sample gradients.

I also found that the memory leak arises at the loss.backward() call in the get_grads function. Please don't worry too much about the semantics of the code; some parts are pseudo-code. Thanks a lot for your help.
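
Roughly, writing g(x) for the mean loss gradient w.r.t. the model parameters and g_i(x) for the per-sample gradients, the objective over the synthetic input x_syn is something like

    min over x_syn of   ||g(x) - g(x_syn)||_2  +  sum_i ||g_i(x) - g_i(x_syn)||_2^2

(the first term computed per parameter tensor, the second via the stacked per-sample gradients), which is what get_grads and compute_distance_grads_var implement below.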

from torch.nn import CrossEntropyLoss, Flatten, Linear, Sequential
import torch
from backpack import backpack, extend
from backpack.extensions import BatchGrad
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

# define the download path
download_root = './MNIST_DATASET'

# Normalize data with mean=0.5, std=1.0
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (1.0,))
])
train_dataset = MNIST(download_root, transform=mnist_transform, train=True, download=True)
# define option values
batch_size = 512
train_loader = DataLoader(dataset=train_dataset,
                         batch_size=batch_size,
                         shuffle=True)
device = torch.device("cuda:0")

def compute_distance_grads_var(dict_grad_1, dict_grad_2):
    penalty = 0
    penalty += l2_between_lists(dict_grad_1, dict_grad_2)
    return penalty

def l2_between_lists(list_1, list_2):
    assert len(list_1) == len(list_2)
    return (
        torch.cat(tuple([t.view(-1) for t in list_1])) -
        torch.cat(tuple([t.view(-1) for t in list_2]))
    ).pow(2).sum()

def dist(x, y, method='mse'):
    """Distance objectives
    """
    if method == 'mse':
        dist_ = (x - y).pow(2).sum()
    elif method == 'l1':
        dist_ = (x - y).abs().sum()
    elif method == 'l1_mean':
        n_b = x.shape[0]
        dist_ = (x - y).abs().reshape(n_b, -1).mean(-1).sum()
    elif method == 'cos':
        x = x.reshape(x.shape[0], -1)
        y = y.reshape(y.shape[0], -1)
        dist_ = torch.sum(1 - torch.sum(x * y, dim=-1) /
                          (torch.norm(x, dim=-1) * torch.norm(y, dim=-1) + 1e-6))
    elif method == 'l2_mean':
        dist_ = torch.norm(x-y, 2)
    return dist_

def get_grads(logits, y, model, bce_extended, real_sample):
    loss = bce_extended(logits, y).sum()
    with backpack(BatchGrad()):
        if real_sample:
            loss.backward()
        else:
            loss.backward(create_graph=True)
    grads_mean = []
    dict_grads_batch = []

    for name, weights in model.named_parameters():
        if real_sample:
            grads_mean.append(weights.grad.detach().clone())
            dict_grads_batch.append(weights.grad_batch.detach().clone().view(weights.grad_batch.size(0), -1))
        else:
            grads_mean.append(weights.grad.clone())
            dict_grads_batch.append(weights.grad_batch.clone().view(weights.grad_batch.size(0), -1))

    return grads_mean, dict_grads_batch

for i in range(100):
    model = Sequential(Flatten(), Linear(784, 128), Linear(128, 10))  # I added an additional layer here
    lossfunc = CrossEntropyLoss()
    model = extend(model)
    lossfunc = extend(lossfunc)
    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    for batch_idx, (x, target) in enumerate(train_loader):

        x_syn = torch.rand(x.shape, requires_grad=True, device="cuda:0")
        y_syn = torch.ones_like(target)
        y_syn = y_syn.to(device)
        optimizer_alpha = torch.optim.Adam([x_syn], lr=1e-3)

        x = x.to(device)
        target = target.to(device)
        loss_model = lossfunc(model(x), target)

        grad, grad_batch = get_grads(model(x), target, model, lossfunc, real_sample=True)
        grad_syn, grad_batch_syn = get_grads(model(x_syn), y_syn, model, lossfunc, real_sample=False)
        loss = 0
        for j in range(len(grad)):  # use j here: `i` would shadow the outer loop variable
            loss += dist(grad[j], grad_syn[j], method='l2_mean')
        loss += compute_distance_grads_var(grad_batch, grad_batch_syn)
        optimizer.zero_grad()
        loss.backward()
        optimizer_alpha.step()

SJShin-AI commented 2 years ago

I think this implementation could provide a lot of utility for per-sample gradient computation. One famous application is https://arxiv.org/abs/2109.02934, which updates the model parameters based on gradient variance matching by leveraging BackPACK.

The main difference is that my code tries to learn via gradients on synthetic images, rather than gradients on the model parameters.

fKunstner commented 2 years ago

Adding a zero_grad on the input's gradients seems to fix the issue.

That is, changing the last few lines of the above script

        optimizer.zero_grad()
        loss.backward()
        optimizer_alpha.step()

to

        optimizer.zero_grad()
        loss.backward()
        optimizer_alpha.step()
        optimizer_alpha.zero_grad() # ---

This should not change the behavior, as optimizer_alpha is re-initialized at each iteration.

Not sure why it's not garbage collected, though.
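
One more thing that might be worth trying (an assumption on my side, not verified on your code): zero_grad(set_to_none=True) drops the .grad tensors entirely instead of zeroing them in place, so any graph they still reference can be garbage collected:

        optimizer.zero_grad(set_to_none=True)
        optimizer_alpha.zero_grad(set_to_none=True)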

SJShin-AI commented 2 years ago

Hi, I checked your solution and it reduces the memory leak somewhat. However, memory still grows in small increments as the iterations go on. It seems some objects are still not garbage collected!

f-dangel commented 2 years ago

Hi,

just wanted to bring this up because I saw there is a .backward(..., create_graph=True) in your code: there is a known memory leak when using full backward hooks with create_graph=True in PyTorch (#82528). You could try installing PyTorch with the fix (#82788) to see if that is what's causing the memory leak.
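
To check whether you are hitting that upstream issue independently of BackPACK, a minimal sketch along these lines (untested and purely illustrative) might already leak on affected versions:

import torch

# full backward hooks + create_graph=True is the combination reported in pytorch#82528
lin = torch.nn.Linear(10, 10)
lin.register_full_backward_hook(lambda module, grad_input, grad_output: None)

for _ in range(1000):
    loss = lin(torch.randn(8, 10)).sum()
    loss.backward(create_graph=True)
    # drop the gradients so only hook-related references can accumulate
    lin.weight.grad = None
    lin.bias.grad = None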

SJShin-AI commented 2 years ago

Hi,

first, thanks for your generous replies to my questions; I really appreciate the effort. As you suggested, the cause of the memory leak seems to be PyTorch itself.

So I upgraded PyTorch to the preview (nightly) version. (For reference, the _grad_input_padding currently used by BackPACK was removed in that version of PyTorch, so adjustments were needed to run the code.) Afterwards, I got a new type of error from the code below:

from torch.nn import CrossEntropyLoss, Flatten, Linear, Sequential
import torch

from backpack import backpack, extend
from backpack.extensions import BatchGrad
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

download_root = './MNIST_DATASET'

# Normalize data with mean=0.5, std=1.0
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (1.0,))
])
train_dataset = MNIST(download_root, transform=mnist_transform, train=True, download=True)
# define option values
batch_size = 512
train_loader = DataLoader(dataset=train_dataset,
                         batch_size=batch_size,
                         shuffle=True)
device = torch.device("cuda:0")

def compute_distance_grads_var(dict_grad_1, dict_grad_2):
    penalty = 0
    penalty += l2_between_lists(dict_grad_1, dict_grad_2)
    return penalty

def l2_between_lists(list_1, list_2):
    assert len(list_1) == len(list_2)
    return (
        torch.cat(tuple([t.view(-1) for t in list_1])) -
        torch.cat(tuple([t.view(-1) for t in list_2]))
    ).pow(2).sum()

def dist(x, y, method='mse'):
    """Distance objectives
    """
    if method == 'mse':
        dist_ = (x - y).pow(2).sum()
    elif method == 'l1':
        dist_ = (x - y).abs().sum()
    elif method == 'l1_mean':
        n_b = x.shape[0]
        dist_ = (x - y).abs().reshape(n_b, -1).mean(-1).sum()
    elif method == 'cos':
        x = x.reshape(x.shape[0], -1)
        y = y.reshape(y.shape[0], -1)
        dist_ = torch.sum(1 - torch.sum(x * y, dim=-1) /(torch.norm(x, dim=-1) * torch.norm(y, dim=-1) + 1e-6))
    elif method == 'l2_mean':
        dist_ = torch.norm(x-y, 2)
    return dist_

def get_grads(logits, y, model, bce_extended, real_sample):
    loss = bce_extended(logits, y).sum()
    with backpack(BatchGrad(), debug=True):
        if real_sample:
            loss.backward(inputs=list(model.parameters()))
        else:
            loss.backward(inputs=list(model.parameters()), create_graph=True)

    grads_mean = []
    dict_grads_batch = []

    for name, weights in model.named_parameters():
        if real_sample:
            grads_mean.append(weights.grad.detach().clone())
            dict_grads_batch.append(weights.grad_batch.detach().clone().view(weights.grad_batch.size(0), -1))
        else:
            grads_mean.append(weights.grad)
            dict_grads_batch.append(weights.grad_batch.view(weights.grad_batch.size(0), -1))

    return grads_mean, dict_grads_batch

for i in range(100):
    model = Sequential(Flatten(), Linear(784, 128), Linear(128, 10))  # I added an additional layer here
    lossfunc = CrossEntropyLoss()
    model = extend(model)
    lossfunc = extend(lossfunc)
    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    for batch_idx, (x, target) in enumerate(train_loader):
        x_syn = torch.rand(x.shape, requires_grad=True, device="cuda:0")
        y_syn = torch.ones_like(target)
        y_syn = y_syn.to(device)
        optimizer_alpha = torch.optim.Adam([x_syn], lr=1e-3)

        x = x.to(device)
        target = target.to(device)
        grad, grad_batch = get_grads(model(x), target, model, lossfunc, real_sample=True)
        grad_syn, grad_batch_syn = get_grads(model(x_syn), y_syn, model, lossfunc, real_sample=False)
        loss = 0
        for j in range(len(grad)):  # use j here: `i` would shadow the outer loop variable
            loss += dist(grad[j], grad_syn[j], method='l2_mean')
        loss += compute_distance_grads_var(grad_batch, grad_batch_syn)
        loss.backward()
        optimizer_alpha.step()
        optimizer_alpha.zero_grad()

Traceback (most recent call last):
  File "condense_example.py", line 125, in <module>
    loss.backward()
  File "/home/aailab/tmdwo0910/.local/lib/python3.8/site-packages/torch/_tensor.py", line 484, in backward
    torch.autograd.backward(
  File "/home/aailab/tmdwo0910/.local/lib/python3.8/site-packages/torch/autograd/__init__.py", line 191, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/aailab/tmdwo0910/.local/lib/python3.8/site-packages/torch/utils/hooks.py", line 100, in hook
    raise RuntimeError("Module backward hook for grad_input is called before "
RuntimeError: Module backward hook for grad_input is called before the grad_output one. This happens because the gradient in your nn.Module flows to the Module's input without passing through the Module's output. Make sure that the output depends on the input and that the loss is computed based on the output.

It seems the error arises because the model backpropagates twice through the graph, which is essential behavior for my task...

The link below shows a case where wandb, which also hooks into the backward pass, caused the same error: (https://discuss.pytorch.org/t/runtimeerror-module-backward-hook-for-grad-input-is-called-before-the-grad-output-one-this-happens-because-the-gradient-in-your-nn-m-odule-flows-to-the-modules-input-without-passing-through-the-modules-output/119763.)

Thanks.

f-dangel commented 2 years ago

Hi again, thanks for the report.

one more (rather miscellaneous) thing you might want to try: you should be able to use retain_graph=True rather than create_graph=True to get the gradients for the synthetic samples, right? Does that still lead to the above exception?
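
That is, something like:

loss.backward(inputs=list(model.parameters()), retain_graph=True)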

SJShin-AI commented 2 years ago

Hi, thanks for the kind suggestion. Actually, I already tried that. It leads to the following error:

File "condense_example.py", line 102, in <module>
    loss.backward()
  File "/usr/local/lib/python3.8/dist-packages/torch/_tensor.py", line 307, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/usr/local/lib/python3.8/dist-packages/torch/autograd/__init__.py", line 154, in backward
    Variable._execution_engine.run_backward(
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

I conjecture that create_graph=True is required to compute the second-order derivative, which is crucial for obtaining the gradients for the synthetic samples...

f-dangel commented 2 years ago

You're correct. I missed the second loss.backward() call, which differentiates through a loss that contains gradient terms.

SJShin-AI commented 2 years ago

I think BackPACK is currently the only realistic way to compute the gradient variance of each model parameter, via its fast per-sample gradient computation.

However, the problem in my code seems to be inherent to PyTorch rather than to BackPACK. I will keep tracking the memory leak to find a solution, but any further advice would be a great help.
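
For context, the variance I need comes straight from BackPACK's grad_batch attribute, roughly like this (a sketch):

# after loss.backward() inside `with backpack(BatchGrad()):`
for p in model.parameters():
    grad_var = p.grad_batch.var(dim=0)  # per-parameter variance over the batch dimension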

f-dangel commented 2 years ago

Keep me posted about the memory leak problem.

You might also want to try functorch (it should be possible to integrate it into your existing PyTorch code without too much effort) or JAX. They can also compute individual gradients that you can use to get the variance.
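
For reference, the per-sample gradient pattern in functorch looks roughly like this (a sketch based on the functorch docs, untested here; the names are illustrative):

import torch
from functorch import make_functional, vmap, grad

model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(784, 10))
fmodel, params = make_functional(model)

def compute_loss(params, sample, target):
    # add a batch dimension of size 1 so the model sees a batch
    logits = fmodel(params, sample.unsqueeze(0))
    return torch.nn.functional.cross_entropy(logits, target.unsqueeze(0))

x = torch.randn(32, 1, 28, 28)
y = torch.randint(0, 10, (32,))

# vmap over the batch dimension of (x, y); the parameters are shared (in_dims=None)
per_sample_grads = vmap(grad(compute_loss), in_dims=(None, 0, 0))(params, x, y)

The variance you need would then be per_sample_grads[k].var(dim=0) for each parameter k.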

SJShin-AI commented 2 years ago

Thanks, I appreciate the suggestion. Did you run any kind of runtime comparison between BackPACK and functorch?

I have heard that functorch is still expensive for computing per-sample gradients and takes too much time...

f-dangel commented 2 years ago

Sadly, I don't have any data yet on how BackPACK compares to functorch in terms of runtime, e.g., for computing individual gradients.

But I think you will be able to port your existing code to functorch with relatively few changes, to see whether it is fast enough for your needs.