learnables / learn2learn

A PyTorch Library for Meta-learning Research
http://learn2learn.net
MIT License

Issue in the example #308

Closed · nsanghi closed this issue 2 years ago

nsanghi commented 2 years ago

https://github.com/learnables/learn2learn/blob/fab0e11661bf8b0f99d7d85a9881d890393c2811/learn2learn/utils/__init__.py#L174

There are two issues in this example:

a) Small typo: the first layer is missing the "nn." prefix. It should read nn.Linear(20, 10), not Linear(20, 10).

b) Line 176 should pass keep_requires_grad=True, so that the cloned module is detached from net while its parameters still have requires_grad=True. Keeping the default of False leaves the cloned network's parameters not requiring gradients, and in that case calling error.backward() fails with the exception:

element 0 of tensors does not require grad and does not have a grad_fn
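For reference, here is a minimal sketch that reproduces the failure with the default flag (the dummy loss and random inputs are illustrative assumptions on my part, not part of the library's example):

import torch
import torch.nn as nn
from learn2learn.utils import clone_module, detach_module

net = nn.Sequential(nn.Linear(20, 10), nn.ReLU(), nn.Linear(10, 2))
clone = clone_module(net)
detach_module(clone)  # default keep_requires_grad=False
X, y = torch.randn(8, 20), torch.randn(8, 2)  # dummy data (assumption)
error = nn.MSELoss()(clone(X), y)
error.backward()  # raises: element 0 of tensors does not require grad ...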

If you want to clone a module, detach it, and then run backprop only through the detached module, I had to set this flag to True. Complete working code:

import torch.nn as nn
from learn2learn.utils import clone_module, detach_module

net = nn.Sequential(nn.Linear(20, 10), nn.ReLU(), nn.Linear(10, 2))
clone = clone_module(net)
detach_module(clone, keep_requires_grad=True)
error = loss(clone(X), y)  # loss, X, y as in the docstring example
error.backward()  # Gradients are back-propagated on clone, not net.

In the above example, backward() in the last line populates gradients only for clone, not for net.
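A quick sanity check (illustrative, building on the working snippet above) makes this concrete:

print(all(p.grad is not None for p in clone.parameters()))  # True: clone received gradients
print(all(p.grad is None for p in net.parameters()))        # True: net's gradients stay None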

seba-1511 commented 2 years ago

Thanks for spotting this issue, @nsanghi. Do you want to open a PR for it?