Closed (chnsh closed this issue 3 years ago)
Sorry for the late reply. Unfortunately, I haven't been able to reproduce this issue (PyTorch 1.6, Torchmeta 1.5.0). Here is an example that runs without error for me:
import torch
import torch.nn as nn
from torchmeta.modules import MetaSequential, MetaLinear
from torchmeta.utils import gradient_update_parameters

model = MetaSequential(
    MetaLinear(2, 3),
    nn.ReLU(),
    MetaLinear(3, 5)
)

# First inner-loop step, starting from the model's own parameters
train_inputs = torch.randn(7, 2)
train_outputs_1 = model(train_inputs)
inner_loss_1 = train_outputs_1.sum()  # Dummy loss
params = gradient_update_parameters(model, inner_loss_1, first_order=True)

# Second inner-loop step, starting from the updated parameters
train_outputs_2 = model(train_inputs, params=params)
inner_loss_2 = train_outputs_2.sum()  # Dummy loss
params = gradient_update_parameters(model, inner_loss_2, params=params, first_order=True)

# Outer loss, evaluated with the adapted parameters
test_inputs = torch.randn(7, 2)
test_outputs = model(test_inputs, params=params)
outer_loss = test_outputs.sum()
outer_loss.backward()
Can you give me a minimal example where you get this issue?
Hi,
Thank you for developing pytorch-meta. I have been trying to use the following method in the inner loop:
When I use this code and call .backward, step on the outer loss, I get an error as follows:
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
Have you encountered something like this?
PyTorch version: 1.5.1
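For reference, this error message is not specific to pytorch-meta: plain PyTorch raises it whenever backward is called a second time through a graph whose saved buffers were freed by the first call. The snippet below is a minimal sketch of that general behaviour only, not a reproduction of the code from this issue; the helper name `backward_twice` is made up for illustration.

```python
import torch


def backward_twice(retain_graph: bool) -> bool:
    """Return True if a second backward pass through the same graph succeeds.

    `backward_twice` is a hypothetical helper used only for this illustration.
    """
    w = torch.randn(3, requires_grad=True)
    # Squaring saves its input for the backward pass, so the graph holds
    # buffers that are freed after the first backward call by default.
    loss = (w ** 2).sum()
    loss.backward(retain_graph=retain_graph)
    try:
        loss.backward()  # second pass through the same graph
    except RuntimeError:
        return False
    return True


failed_without_retain = not backward_twice(retain_graph=False)
succeeded_with_retain = backward_twice(retain_graph=True)
print(failed_without_retain, succeeded_with_retain)
```

In a meta-learning loop, one situation where this can plausibly show up is when the same inner-loss graph ends up being differentiated more than once, for instance once during the inner update and again when the outer loss is backpropagated. Whether that is what happens here cannot be told without the missing inner-loop code.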