f-dangel / cockpit

Cockpit: A Practical Debugging Tool for Training Deep Neural Networks

UserWarning #30

Open · f-pfeiffer opened this issue 6 months ago

f-pfeiffer commented 6 months ago

Hello, when I use Cockpit with the HessMaxEV quantity like this:

loss_fn = extend(nn.MSELoss(reduction='mean'))
individual_loss_fn = nn.MSELoss(reduction='none')
cockpit = Cockpit(self.q_net.parameters(), quantities=[quantities.HessMaxEV(schedules.linear(interval=3))])
...
with cockpit(
    train_run_grad_steps,
    info={
        'batch_size': self.batch_size,
        'individual_losses': losses,
        'loss': loss,
        'optimizer': self.optimizer,
    },
):
    loss.backward(create_graph=cockpit.create_graph(train_run_grad_steps))

I get the following warning:

UserWarning: Using backward() with create_graph=True will create a reference cycle between the parameter and its gradient which can cause a memory leak. We recommend using autograd.grad when creating the graph to avoid this. If you have to use this function, make sure to reset the .grad fields of your parameters to None after use to break the cycle and avoid the leak. (Triggered internally at ../torch/csrc/autograd/engine.cpp:1177.)

Do I have to reset the grad fields to None manually after each step, e.g. like this, or can I ignore this warning?

for param in params:
    param.grad = None
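For context, in my training loop the reset would presumably go right after the optimizer update, roughly like this (a sketch mirroring the snippet above; if I understand correctly, self.optimizer.zero_grad(set_to_none=True) at the start of the next step would have the same effect):

with cockpit(
    train_run_grad_steps,
    info={
        'batch_size': self.batch_size,
        'individual_losses': losses,
        'loss': loss,
        'optimizer': self.optimizer,
    },
):
    loss.backward(create_graph=cockpit.create_graph(train_run_grad_steps))

self.optimizer.step()

# break the parameter <-> gradient reference cycle created by create_graph=True
for param in self.q_net.parameters():
    param.grad = None
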
f-dangel commented 6 months ago

Hi,

thanks for reporting this. I have not encountered this warning before, so I don't have very informed advice to offer. Which version of PyTorch are you using, if I may ask? Maybe this potential source of reference cycles has been resolved in the latest version?

Best, Felix

f-pfeiffer commented 6 months ago

Hi Felix,

Here is an example script to reproduce the warning:

import torch
import torch.nn as nn
from cockpit import Cockpit, quantities
from cockpit.utils import schedules
from backpack import extend

if __name__ == "__main__":
    model = nn.Sequential(
        nn.Linear(4, 2),
        nn.ReLU(),
        nn.Linear(2, 1)
    )
    loss_fn = extend(nn.MSELoss(reduction="mean"))
    individual_loss = nn.MSELoss(reduction="none")
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    cockpit = Cockpit(model.parameters(), quantities=[quantities.HessMaxEV(schedules.linear(interval=1))])

    # create toy data with 5 samples
    X = torch.randn(5, 4)
    y = torch.randn(5, 1)

    for epoch in range(10):
        optimizer.zero_grad()
        y_pred = model(X)
        loss = loss_fn(y_pred, y)
        losses = individual_loss(y_pred, y)
        with cockpit(
            epoch,
            info={
                'batch_size': 5,
                'individual_losses': losses,
                'loss': loss,
                'optimizer': optimizer,
            },
        ):
            loss.backward(create_graph=cockpit.create_graph(epoch))
        optimizer.step()

This is the output:

/home/felix/anaconda3/envs/ma310/lib/python3.10/site-packages/torch/autograd/__init__.py:266: UserWarning: Using backward() with create_graph=True will create a reference cycle between the parameter and its gradient which can cause a memory leak. We recommend using autograd.grad when creating the graph to avoid this. If you have to use this function, make sure to reset the .grad fields of your parameters to None after use to break the cycle and avoid the leak. (Triggered internally at ../torch/csrc/autograd/engine.cpp:1177.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass

I use the following software versions:

- torch: 2.2.2
- cockpit: 1.0.2
- backpack: 1.6.0
- python: 3.10.13
- OS: Ubuntu 22.04.4 LTS

Best, Felix

fsschneider commented 6 months ago

Hi,

I am not sure it addresses the root of the problem, but could you try the above example again while also extending the model and the individual_loss_fn? This is done, for example, in our Basic Example in the documentation.
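Roughly, the setup lines of your reproduction script would then look like this (a sketch, using your variable names):

# extend the model and both loss functions with BackPACK, as in Cockpit's Basic Example
model = extend(nn.Sequential(
    nn.Linear(4, 2),
    nn.ReLU(),
    nn.Linear(2, 1)
))
loss_fn = extend(nn.MSELoss(reduction="mean"))
individual_loss = extend(nn.MSELoss(reduction="none"))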

f-pfeiffer commented 6 months ago

Thanks for your advice.

I still get the warning. I ran the code with and without cockpit and I get the exact same results, so the warning can probably just be ignored.
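To illustrate what I mean by "the exact same results", the comparison was along these lines (a sketch; run_training is a hypothetical helper that wraps the training loop above, toggles the cockpit context manager, and returns the final model parameters):

torch.manual_seed(0)
params_with_cockpit = run_training(use_cockpit=True)

torch.manual_seed(0)
params_without_cockpit = run_training(use_cockpit=False)

# the final parameters match exactly in both runs
print(all(torch.equal(p, q) for p, q in zip(params_with_cockpit, params_without_cockpit)))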