learnables / learn2learn

A PyTorch Library for Meta-learning Research
http://learn2learn.net
MIT License

Adding a new parameter to meta-sgd #358

Closed JeremyFisher closed 1 year ago

JeremyFisher commented 1 year ago

Hi,

Thanks for a great library! I wanted to use the MetaSGD implementation and was wondering what the best way of adding a learnable scalar would be.

Looking at your meta-sgd example: https://github.com/learnables/learn2learn/blob/master/examples/vision/meta_mnist.py

Let's say the forward pass has to change from:

    def forward(self, x):
        x = first_part(x)
        return second_part(x)

to:

    def forward(self, x):
        x = first_part(x)
        x = some_function(x, my_learnable_param)
        return second_part(x)

Where would you define my_learnable_param and what would be the proper way of updating it?

seba-1511 commented 1 year ago

Hello @JeremyFisher,

If I understand correctly my_learnable_param is not updated during fast-adaptation, right? The simplest approach in your case is to subclass the MetaSGD wrapper, add an extra parameter for my_learnable_param, and explicitly prevent adapting it when .adapt(loss) is called.
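Roughly, the sketch below is what I mean (untested; the names MetaSGDWithExtra and my_learnable_param are just illustrative, not part of the library):

    import torch as th
    import torch.nn as nn
    import learn2learn as l2l

    class MetaSGDWithExtra(l2l.algorithms.MetaSGD):
        def __init__(self, model, lr=1.0, first_order=False, lrs=None):
            super(MetaSGDWithExtra, self).__init__(model, lr, first_order, lrs)
            # Extra meta-parameter registered on the wrapper rather than on the
            # wrapped model. The outer optimizer updates it through
            # wrapper.parameters(); you just need to make sure it is left out
            # of whatever .adapt(loss) updates.
            self.my_learnable_param = nn.Parameter(th.tensor(0.001))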

JeremyFisher commented 1 year ago

Hi @seba-1511 , thank you for your quick reply.

Yes, correct: my_learnable_param is not updated during fast-adaptation. Here is my current approach:

The learnable parameter meta_noise lives in a My_d module (let's say the module adds noise to a layer, and the learnable parameter controls the magnitude of the noise):

    import torch as th
    import torch.nn as nn

    class My_d(nn.Module):
        def __init__(self, init_val) -> None:
            super(My_d, self).__init__()
            # learnable scalar that controls the noise magnitude
            self.meta_noise = nn.Parameter(th.tensor(init_val))

        def forward(self, input: th.Tensor) -> th.Tensor:
            # Gaussian multiplicative noise, scaled by the learnable parameter
            noise = th.randn(input.size(), device=input.device) \
                    * self.meta_noise
            return input * noise

Now, the My_d module is added to a subclass of MetaSGD:

    from learn2learn.algorithms import MetaSGD

    class MetaSGD_sub(MetaSGD):
        ''' subclass of l2l MetaSGD which holds a learnable module/parameter '''
        def __init__(self, model, lr=1.0, first_order=False, lrs=None):
            super(MetaSGD_sub, self).__init__(model, lr, first_order, lrs)
            # the extra learnable module lives on the wrapper, not on the wrapped model
            self.my_d1 = My_d(0.001)

Finally, a training step looks like this:

    meta_model = MetaSGD_sub(model, lr=0.001)
    task_model = meta_model.clone()
    # inner loop: fast-adaptation on the task data
    loss = task_model(data, meta_model.my_d1)
    task_model.adapt(loss)
    # outer loop: evaluate the adapted model and update the meta-parameters
    loss = task_model(data, meta_model.my_d1)
    loss.backward()
    optimizer.step()

Does this make sense? I can see that meta_noise is updated after optimizer.step() and does not change during task_model.adapt().

Surprisingly, I didn't have to change anything in the adapt code.
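(For reference, since My_d is a registered submodule of the MetaSGD subclass, meta_noise shows up in meta_model.parameters(), so the outer optimizer picks it up without any extra wiring. The snippet below is just illustrative of my setup, not the exact code:)

    # meta_noise is listed among the wrapper's parameters ...
    for name, p in meta_model.named_parameters():
        print(name, tuple(p.shape))
    # ... so an ordinary outer optimizer updates it on optimizer.step()
    optimizer = th.optim.Adam(meta_model.parameters(), lr=1e-3)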

seba-1511 commented 1 year ago

Yes, at a high level this looks right. To be 100% sure, I recommend putting breakpoints before and after adaptation and checking that your new parameter is not adapted.
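Something along these lines (an untested sketch, reusing the variables from your snippet) also works as a programmatic check in place of breakpoints:

    before = meta_model.my_d1.meta_noise.detach().clone()
    task_model = meta_model.clone()
    loss = task_model(data, meta_model.my_d1)
    task_model.adapt(loss)
    # the extra parameter should be bit-for-bit identical after adaptation
    assert th.equal(meta_model.my_d1.meta_noise.detach(), before)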

Closing; feel free to reopen if there's an issue.

JeremyFisher commented 1 year ago

Hey @seba-1511,

it seems there is a small issue after all. This approach was working for the Gaussian noise function above, but if I make a small modification:

    class My_d(nn.Module):
        def __init__(self, init_val) -> None:
            super(My_d, self).__init__()
            self.meta_noise = nn.Parameter(th.tensor(init_val))

        def forward(self, input: th.Tensor) -> th.Tensor:
            # previously: noise = th.randn(input.size(), device=input.device) * self.meta_noise
            # now meta_noise only enters through the rate of a Poisson sample
            rates = th.rand(input.shape, device=input.device) * self.meta_noise
            noise = th.poisson(rates)
            return input * noise

the gradients are suddenly 0 during the training stage, i.e. self.meta_model.my_d1.meta_noise.grad is always 0 before calling optimizer.step().

In this particular case, where my learnable parameter meta_noise is part of the meta_model / MetaSGD subclass (and not part of task_model), do I still need to worry about torch.poisson not being differentiable?
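To make the question concrete, here is a minimal standalone snippet (names are purely illustrative) that isolates the two cases, comparing the gradient that reaches meta_noise through the Gaussian noise against the Poisson version:

    import torch as th

    meta_noise = th.tensor(0.001, requires_grad=True)
    x = th.ones(8)

    # Gaussian case: meta_noise scales the noise directly, so it stays in the graph
    out = (x * th.randn(x.shape) * meta_noise).sum()
    out.backward()
    print(meta_noise.grad)  # generally non-zero

    meta_noise.grad = None

    # Poisson case: meta_noise only enters through the rate passed to th.poisson,
    # which is where I suspect the gradient gets lost
    rates = th.rand(x.shape) * meta_noise
    out = (x * th.poisson(rates)).sum()
    out.backward()
    print(meta_noise.grad)  # zero, matching what I see in training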