Hello @JeremyFisher,
If I understand correctly, `my_learnable_param` is not updated during fast-adaptation, right? The simplest approach in your case is to subclass the `MetaSGD` wrapper, add an extra parameter for `my_learnable_param`, and explicitly prevent adapting it when `.adapt(loss)` is called.
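Something along these lines could work (a rough, untested sketch; the class name `MetaSGDWithScalar` and the `init_val` argument are placeholders, not part of the library):

```python
import torch as th
import torch.nn as nn
from learn2learn.algorithms import MetaSGD

class MetaSGDWithScalar(MetaSGD):
    """MetaSGD wrapper with an extra learnable scalar (hypothetical sketch)."""

    def __init__(self, model, lr=1.0, first_order=False, lrs=None, init_val=0.0):
        super().__init__(model, lr, first_order, lrs)
        # Registered on the wrapper itself, not on the wrapped module. As far as I
        # can tell, adapt() only differentiates w.r.t. the wrapped module's
        # parameters, so this scalar is left alone during fast-adaptation and is
        # updated by the outer-loop optimizer instead.
        self.my_learnable_param = nn.Parameter(th.tensor(init_val))
```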
Hi @seba-1511, thank you for your quick reply.
Yes, `my_learnable_param` is not updated during fast-adaptation, correct. Here is my current approach:
The learnable parameter `meta_noise` lives in a `My_d` module (let's say the module adds noise to a layer's output, and the learnable parameter controls the magnitude of that noise):
```python
import torch as th
import torch.nn as nn

class My_d(nn.Module):
    def __init__(self, init_val) -> None:
        super(My_d, self).__init__()
        # learnable scalar controlling the noise magnitude
        self.meta_noise = nn.Parameter(th.tensor(init_val))

    def forward(self, input: th.Tensor) -> th.Tensor:
        noise = th.randn(input.size(), device=input.device) * self.meta_noise
        return input * noise
```
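For reference, a quick standalone check (a sketch, outside the MetaSGD setup) shows that gradients do reach `meta_noise` in this Gaussian version:

```python
my_d = My_d(0.001)
x = th.randn(4, 8)
out = my_d(x)
out.sum().backward()
print(my_d.meta_noise.grad)  # generally non-zero: th.randn(...) * meta_noise is differentiable
```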
Now, the `My_d` module is added to a subclass of `MetaSGD`:
```python
from learn2learn.algorithms import MetaSGD

class MetaSGD_sub(MetaSGD):
    '''Subclass of l2l's MetaSGD which holds a learnable module/parameter.'''

    def __init__(self, model, lr=1.0, first_order=False, lrs=None):
        super(MetaSGD_sub, self).__init__(model, lr, first_order, lrs)
        self.my_d1 = My_d(0.001)
```
Finally, the forward pass looks like this:
```python
meta_model = MetaSGD_sub(model, lr=0.001)

task_model = meta_model.clone()
loss = task_model(data, meta_model.my_d1)  # pass the noise module from the meta-model
task_model.adapt(loss)                     # fast-adaptation: meta_noise is not touched here

loss = task_model(data, meta_model.my_d1)
loss.backward()
optimizer.step()                           # outer loop: this is where meta_noise gets updated
```
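The outer-loop optimizer itself isn't shown above; a sketch of one way to set it up (assuming the usual pattern of optimizing `meta_model.parameters()`; the learning rate is just an example):

```python
from torch import optim

# Because my_d1 is a submodule of MetaSGD_sub, meta_model.parameters() already
# includes meta_noise, so the outer optimizer updates it on optimizer.step().
optimizer = optim.Adam(meta_model.parameters(), lr=1e-3)
print(any(p is meta_model.my_d1.meta_noise
          for group in optimizer.param_groups
          for p in group['params']))  # expected: True
```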
Does this make sense? I saw that `meta_noise` is updated after `optimizer.step()` and does not change during `task_model.adapt()`.
Surprisingly, I didn't have to change anything in the `adapt` code.
Yes, at a high level this looks right. To be 100% sure, I recommend putting breakpoints before and after adaptation to make sure your new param is not adapted.
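For example, something along these lines (an illustrative check reusing the variables from your snippet):

```python
before = meta_model.my_d1.meta_noise.detach().clone()

task_model = meta_model.clone()
loss = task_model(data, meta_model.my_d1)
task_model.adapt(loss)

after = meta_model.my_d1.meta_noise.detach().clone()
assert th.equal(before, after), "meta_noise changed during fast-adaptation"
```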
Closing, feel free to reopen if there’s an issue.
Hey @seba-1511,
it seems there is a small issue. This approach was working for this particular noise-adding function, but if I make a simple modification:
```python
class My_d(nn.Module):
    def __init__(self, init_val) -> None:
        super(My_d, self).__init__()
        self.meta_noise = nn.Parameter(th.tensor(init_val))

    def forward(self, input: th.Tensor) -> th.Tensor:
        # noise = th.randn(input.size(), device=input.device) * self.meta_noise
        rates = th.rand(input.shape, device=input.device) * self.meta_noise
        noise = th.poisson(rates)  # sample Poisson noise with a learnable rate scale
        return input * noise
```
suddenly the gradients are 0 during the training stage, i.e. `self.meta_model.my_d1.meta_noise.grad` is always 0 before calling `optimizer.step()`.
In this particular case, where my learnable parameter `meta_noise` is part of the meta_model / MetaSGD subclass (it is not part of `task_model`), do I still need to worry about `torch.poisson` not being differentiable?
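A minimal standalone check (a sketch, independent of learn2learn) to see whether any gradient reaches the rate parameter through `th.poisson`:

```python
import torch as th

meta_noise = th.tensor(0.001, requires_grad=True)
x = th.randn(4, 8)

rates = th.rand(x.shape) * meta_noise
noise = th.poisson(rates)   # sampling step
loss = (x * noise).sum()
loss.backward()
print(meta_noise.grad)      # if this is 0, the vanishing gradient comes from the
                            # sampling op itself, not from the MetaSGD wrapper
```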
Hi,
thanks for a great library! I wanted to use the meta-sgd implementation and was wondering what would be the best way of adding a learnable scalar.
Looking at your meta-sgd example: https://github.com/learnables/learn2learn/blob/master/examples/vision/meta_mnist.py
let's say the forward pass has to change from:
to:
Where would you define `my_learnable_param`, and what would be the proper way of updating it?