learnables / learn2learn

A PyTorch Library for Meta-learning Research
http://learn2learn.net
MIT License

About the meaning of `allow_nograd` in `MAML.adapt()` #380

Closed · gwwo closed this 1 year ago

gwwo commented 1 year ago

Hi,

I found that the implementation of the `allow_nograd` argument seems to contradict its meaning.

https://github.com/learnables/learn2learn/blob/0b9d3a3d540646307ca5debf8ad9c79ffe975e1c/learn2learn/algorithms/maml.py#L109-L129

According to its name and docstring, I believe it's meant to let model parameters with `requires_grad=False` be fast-adapted as well, computing gradients w.r.t. them and later backpropagating through those gradients so they contribute to the meta-gradients w.r.t. the parameters with `requires_grad=True`.

However, in the implementation, the `if allow_nograd:` branch actually filters those no-grad parameters out, i.e., it sets `gradient = None` for them so that they are skipped in `maml_update`:

https://github.com/learnables/learn2learn/blob/0b9d3a3d540646307ca5debf8ad9c79ffe975e1c/learn2learn/algorithms/maml.py#L138-L166
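To make the behavior concrete, here is a minimal sketch of the logic I'm describing (a hypothetical standalone function, not the library's exact code): gradients are computed only for the differentiable parameters, and `None` is interleaved for the frozen ones, which the update step then skips.

```python
import torch
from torch.autograd import grad

def adapt_sketch(module, loss, second_order=True):
    # Differentiate only w.r.t. parameters that require gradients.
    diff_params = [p for p in module.parameters() if p.requires_grad]
    grad_params = grad(loss, diff_params,
                       retain_graph=second_order,
                       create_graph=second_order)
    # Re-align gradients with the full parameter list: frozen
    # parameters get None, so the update step leaves them untouched.
    gradients = []
    grad_counter = 0
    for param in module.parameters():
        if param.requires_grad:
            gradients.append(grad_params[grad_counter])
            grad_counter += 1
        else:
            gradients.append(None)
    # maml_update-style step: skip parameters whose gradient is None.
    return gradients
```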

seba-1511 commented 1 year ago

Hello @gwwo,

The `allow_nograd` argument supports fast-adapting models that have non-differentiable parameters. Those parameters are never adapted (otherwise you could just give them `requires_grad = True` and set their meta-learning rate to 0), but they need a little extra care. The flag provides that care, and it is `False` by default because we want to be explicit about when these parameters are expected in the graph and when they are not.
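A quick usage sketch of that scenario (toy model and data of my own, assuming the standard `MAML` wrapper): the first layer is frozen, so its parameters sit in the graph but are never adapted, and `allow_nograd=True` tells `adapt()` to differentiate only w.r.t. the trainable ones.

```python
import torch
import learn2learn as l2l

model = torch.nn.Sequential(
    torch.nn.Linear(4, 8),
    torch.nn.ReLU(),
    torch.nn.Linear(8, 1),
)
# Freeze the first layer: its parameters should never be adapted.
for p in model[0].parameters():
    p.requires_grad = False

maml = l2l.algorithms.MAML(model, lr=0.1)
learner = maml.clone()

x, y = torch.randn(16, 4), torch.randn(16, 1)
loss = torch.nn.functional.mse_loss(learner(x), y)

# Without allow_nograd=True, autograd would be asked to
# differentiate w.r.t. the frozen parameters as well and raise.
learner.adapt(loss, allow_nograd=True)
```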

Hope this helps!