facebookresearch / higher

higher is a PyTorch library allowing users to obtain higher-order gradients over losses spanning training loops rather than individual training steps.
Apache License 2.0

How does one implement a parametrized meta-learner (like the meta-LSTM optimizer) in higher? #62

Open renesax14 opened 4 years ago

renesax14 commented 4 years ago

I wanted to implement the meta-LSTM meta-learner from the paper OPTIMIZATION AS A MODEL FOR FEW-SHOT LEARNING using higher, but I ran into problems. I found that I cannot make it work without changing what seems to be a crucial line:

https://github.com/facebookresearch/higher/blob/8f0716fb1663218324c02dabdba26b639959cfb6/higher/optim.py#L101

to:

        #self.param_groups = _copy.deepcopy(other.param_groups)
        self.param_groups = other.param_groups

I provide an extremely simplified self-contained implementation of something similar here:

https://gist.github.com/renesax14/8499e0314351ea4199a17e494bff5c4d

but I will copy-paste it here to keep the discussion in one place:

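# based on the paper "OPTIMIZATION AS A MODEL FOR FEW-SHOT LEARNING": https://openreview.net/pdf?id=rJY0-Kcll
import torch
from torch.optim import Optimizer

import higher
from higher.optim import DifferentiableOptimizer

class EmptySimpleMetaLstm(Optimizer):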

    def __init__(self, params, trainable_opt_model, trainable_opt_state, *args, **kwargs):
        defaults = {
            'trainable_opt_model':trainable_opt_model, 
            'trainable_opt_state':trainable_opt_state, 
            'args':args, 
            'kwargs':kwargs
        }
        super().__init__(params, defaults)

class SimpleMetaLstm(DifferentiableOptimizer):

    def _update(self, grouped_grads, **kwargs):
        prev_lr = self.param_groups[0]['trainable_opt_state']['prev_lr']
        eta = self.param_groups[0]['trainable_opt_model']['eta']
        # start differentiable & trainable update
        zipped = zip(self.param_groups, grouped_grads)
        for group_idx, (group, grads) in enumerate(zipped):
            for p_idx, (p, g) in enumerate(zip(group['params'], grads)):
                if g is None:
                    continue
                # get gradient as "data"
                g = g.detach() # gradients of gradients are not used (no hessians)
                ## very simplified version of meta-lstm meta-learner
                input_metalstm = torch.stack([p, g, prev_lr.view(1,1)]).view(1,3) # [p, g, prev_lr] note it's missing loss, normalization etc. see original paper
                lr = eta(input_metalstm).view(1)
                fg = 1 - lr # learnable forget rate
                ## update suggested by meta-lstm meta-learner
                p_new = fg*p - lr*g
                group['params'][p_idx] = p_new
        # remember this step's lr so the next step can use it as prev_lr
        self.param_groups[0]['trainable_opt_state']['prev_lr'] = lr

higher.register_optim(EmptySimpleMetaLstm, SimpleMetaLstm)

def test_parametrized_inner_optimizer():
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from collections import OrderedDict

    ## training config
    #device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    track_higher_grads = True # if True, the graph is retained during unrolled optimization and the fast weights bear grad_fns, which permits backpropagation through the optimization process. Set to False at test time for efficiency
    copy_initial_weights = False # if False, we train the base model's initial weights (i.e. the base model's initialization)
    episodes = 5
    nb_inner_train_steps = 5
    ## get base model
    base_mdl = nn.Sequential(OrderedDict([
        ('fc', nn.Linear(1,1, bias=False)),
        ('relu', nn.ReLU())
        ]))
    ## parametrization/mdl for the inner optimizer
    opt_mdl = nn.Sequential(OrderedDict([
        ('fc', nn.Linear(3,1, bias=False)), # 3 inputs 1 for parameter, 1 for gradient, 1 for previous lr
        ('sigmoid', nn.Sigmoid())
        ]))
    ## get outer optimizer (not differentiable nor trainable)
    outer_opt = optim.Adam([{'params': base_mdl.parameters()},{'params': opt_mdl.parameters()}], lr=0.01)
    for episode in range(episodes):
        ## get fake support & query data (from a single task and 1 data point)
        spt_x, spt_y, qry_x, qry_y = torch.randn(1), torch.randn(1), torch.randn(1), torch.randn(1)
        ## get differentiable & trainable (parametrized) inner optimizer
        inner_opt = EmptySimpleMetaLstm(base_mdl.parameters(), trainable_opt_model={'eta': opt_mdl}, trainable_opt_state={'prev_lr': 0.9*torch.randn(1)})
        with higher.innerloop_ctx(base_mdl, inner_opt, copy_initial_weights=copy_initial_weights, track_higher_grads=track_higher_grads) as (fmodel, diffopt):
            for i_inner in range(nb_inner_train_steps): # full-batch gradient descent on the k_shot examples (k is usually small, e.g. 5)
                fmodel.train()
                # base/child model forward pass
                inner_loss = 0.5*((fmodel(spt_x) - spt_y))**2
                # inner-opt update
                diffopt.step(inner_loss)
            ## Evaluate on query set for current task
            qry_loss = 0.5*((fmodel(qry_x) - qry_y))**2
            qry_loss.backward() # for memory efficient computation
        ## outer update
        print(f'episode = {episode}')
        print(f'base_mdl.grad = {base_mdl.fc.weight.grad}')
        print(f'opt_mdl.grad = {opt_mdl.fc.weight.grad}')
        outer_opt.step()
        outer_opt.zero_grad()

if __name__ == '__main__':
    test_parametrized_inner_optimizer()
    print('Done \a')

"""
output when the deep copy line is commented out (parametrized optimizer trains properly):
episode = 0
base_mdl.grad = tensor([[-0.0351]])
opt_mdl.grad = tensor([[0.0085, 0.0000, 0.0204]])
episode = 1
base_mdl.grad = tensor([[0.0311]])
opt_mdl.grad = tensor([[-0.0086, -0.0100,  0.0358]])
episode = 2
base_mdl.grad = tensor([[0.]])
opt_mdl.grad = tensor([[0., 0., 0.]])
episode = 3
base_mdl.grad = tensor([[0.0066]])
opt_mdl.grad = tensor([[-0.0016,  0.0000, -0.0032]])
episode = 4
base_mdl.grad = tensor([[-0.0311]])
opt_mdl.grad = tensor([[0.0077, 0.0000, 0.0130]])
Done 
output when the deep copy is on (parameters of the inner optimizer are not trained, sad!):
episode = 0
base_mdl.grad = tensor([[0.]])
opt_mdl.grad = None
episode = 1
base_mdl.grad = tensor([[0.]])
opt_mdl.grad = None
episode = 2
base_mdl.grad = tensor([[0.0069]])
opt_mdl.grad = None
episode = 3
base_mdl.grad = tensor([[0.]])
opt_mdl.grad = None
episode = 4
base_mdl.grad = tensor([[0.]])
opt_mdl.grad = None
Done
The deep copy line in higher I am referencing:
        self.param_groups = _copy.deepcopy(other.param_groups)
        #self.param_groups = other.param_groups
"""
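A likely reason for opt_mdl.grad = None when the deep copy is on: EmptySimpleMetaLstm stores opt_mdl inside its defaults, so the module ends up inside param_groups, and copy.deepcopy on those groups builds a brand-new module with its own leaf parameters. The inner updates then use the copy, and the query loss backpropagates into the copy rather than into opt_mdl. The sketch below reproduces only that effect in plain PyTorch; the param_groups list is a hand-made stand-in, not higher's actual internal structure.

```python
import copy

import torch
from torch import nn

torch.manual_seed(0)
opt_mdl = nn.Linear(3, 1, bias=False)

# Stand-in for an optimizer param group that keeps the module among its settings.
param_groups = [{'params': [], 'trainable_opt_model': {'eta': opt_mdl}}]

# This mimics the deep copy of the param groups referenced above.
copied_groups = copy.deepcopy(param_groups)
eta_copy = copied_groups[0]['trainable_opt_model']['eta']

# Use the copied module, as the differentiable optimizer would, and backprop.
out = eta_copy(torch.randn(1, 3)).sum()
out.backward()

print(eta_copy is opt_mdl)               # False: deepcopy built a new module
print(eta_copy.weight.grad is not None)  # True: gradients land on the copy
print(opt_mdl.weight.grad)               # None: the original never receives them
```

So anything that must receive meta-gradients has to reach the differentiable optimizer by reference, not through a structure that gets deep-copied.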

Cross-posted: https://stackoverflow.com/questions/62459891/how-does-one-implemented-a-parametrized-meta-learner-in-pytorchs-higher-library

renesax14 commented 4 years ago

Perhaps this can be implemented with the override argument:

override (optional) – a dictionary mapping optimizer settings (i.e. those which would be passed to the optimizer constructor or provided within parameter groups) to either singleton lists of override values, or to a list of override values of length equal to the number of parameter groups. If a single override is provided for a keyword, it is used for all parameter groups. If a list is provided, the ith element of the list overrides the corresponding setting in the ith parameter group. This permits the passing of tensors requiring gradient to differentiable optimizers for use as optimizer settings.
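The override semantics quoted above amount to a broadcasting rule. A value list of length 1 is shared by every parameter group, a list whose length equals the number of groups is applied group by group, and anything else is an error. Below is a small pure-Python sketch of that rule. expand_override is a hypothetical helper written for illustration, not higher's implementation.

```python
from typing import Any, Dict, List, Sequence


def expand_override(override: Dict[str, Sequence[Any]], num_groups: int) -> List[Dict[str, Any]]:
    """Hypothetical helper: resolve override values into one settings dict per param group."""
    per_group: List[Dict[str, Any]] = [{} for _ in range(num_groups)]
    for key, values in override.items():
        if len(values) == 1:
            # A singleton list is shared by every parameter group.
            values = list(values) * num_groups
        elif len(values) != num_groups:
            raise ValueError(
                f"Mismatch between the number of override values for {key!r} "
                "and the number of parameter groups."
            )
        # Otherwise the i-th value goes to the i-th parameter group.
        for settings, value in zip(per_group, values):
            settings[key] = value
    return per_group


print(expand_override({'lr': [0.1]}, 2))       # [{'lr': 0.1}, {'lr': 0.1}]
print(expand_override({'lr': [0.1, 0.2]}, 2))  # [{'lr': 0.1}, {'lr': 0.2}]
```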

renesax14 commented 4 years ago

Didn't work with override:

Exception has occurred: ValueError 
Mismatch between the number of override tensors for optimizer parameter trainable_opt_model and the number of parameter groups.
It seems to check that these lengths match:

def _apply_override(self, override: _OverrideType) -> None:
    for k, v in override.items():
        # Sanity check
        if (len(v) != 1) and (len(v) != len(self.param_groups)):
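My reading of the error: override values are expected to be lists, but opt_mdl was passed bare. nn.Sequential defines __len__ as its number of submodules, so len(opt_mdl) is 2 (fc and sigmoid). That is neither 1 nor the single parameter group, so the check above raises. The trainable_opt_state dict happens to have one key, so it slips through. Wrapping each value in a singleton list should get past the length check. This sketch only reproduces the length logic. I have not verified what higher does afterwards with a non-tensor override value, and failing_keys is a hypothetical helper.

```python
from collections import OrderedDict

import torch
from torch import nn

opt_mdl = nn.Sequential(OrderedDict([
    ('fc', nn.Linear(3, 1, bias=False)),
    ('sigmoid', nn.Sigmoid()),
]))
num_param_groups = 1  # EmptySimpleMetaLstm(base_mdl.parameters()) creates a single group

# As passed in the thread: values are not wrapped in lists.
bare = {'trainable_opt_model': opt_mdl,
        'trainable_opt_state': {'prev_lr': 0.9 * torch.randn(1)}}
# Singleton lists, as the override docs describe.
wrapped = {'trainable_opt_model': [opt_mdl],
           'trainable_opt_state': [{'prev_lr': 0.9 * torch.randn(1)}]}


def failing_keys(override, n_groups):
    # Same condition as the sanity check in _apply_override quoted above.
    return [k for k, v in override.items() if len(v) != 1 and len(v) != n_groups]


print(len(opt_mdl))                               # 2: one entry per submodule
print(failing_keys(bare, num_param_groups))       # ['trainable_opt_model']
print(failing_keys(wrapped, num_param_groups))    # []
```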
renesax14 commented 4 years ago

Override version:

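import torch
from torch.optim import Optimizer

import higher
from higher.optim import DifferentiableOptimizer

class EmptySimpleMetaLstm(Optimizer):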

    def __init__(self, params, *args, **kwargs):
        defaults = { 'args':args, 'kwargs':kwargs}
        super().__init__(params, defaults)

class SimpleMetaLstm(DifferentiableOptimizer):

    def _update(self, grouped_grads, **kwargs):
        prev_lr = self.override['trainable_opt_state']['prev_lr']
        simp_meta_lstm = self.override['trainable_opt_model']
        # start differentiable & trainable update
        zipped = zip(self.param_groups, grouped_grads)
        for group_idx, (group, grads) in enumerate(zipped):
            for p_idx, (p, g) in enumerate(zip(group['params'], grads)):
                if g is None:
                    continue
                # get gradient as "data"
                g = g.detach() # gradients of gradients are not used (no hessians)
                ## very simplified version of meta-lstm meta-learner
                input_metalstm = torch.stack([p, g, prev_lr.view(1,1)]).view(1,3) # [p, g, prev_lr] note it's missing loss, normalization etc. see original paper
                lr = simp_meta_lstm(input_metalstm).view(1)
                fg = 1 - lr # learnable forget rate
                ## update suggested by meta-lstm meta-learner
                p_new = fg*p - lr*g
                group['params'][p_idx] = p_new
        # remember this step's lr so the next step can use it as prev_lr
        self.override['trainable_opt_state']['prev_lr'] = lr

higher.register_optim(EmptySimpleMetaLstm, SimpleMetaLstm)

####
####

def test_parametrized_inner_optimizer():
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from collections import OrderedDict

    ## training config
    #device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    track_higher_grads = True # if True, the graph is retained during unrolled optimization and the fast weights bear grad_fns, which permits backpropagation through the optimization process. Set to False at test time for efficiency
    copy_initial_weights = False # if False, we train the base model's initial weights (i.e. the base model's initialization)
    episodes = 5
    nb_inner_train_steps = 5
    ## get base model
    base_mdl = nn.Sequential(OrderedDict([
        ('fc', nn.Linear(1,1, bias=False)),
        ('relu', nn.ReLU())
        ]))
    ## parametrization/mdl for the inner optimizer
    opt_mdl = nn.Sequential(OrderedDict([
        ('fc', nn.Linear(3,1, bias=False)), # 3 inputs [p, g, prev_lr] 1 for parameter, 1 for gradient, 1 for previous lr
        ('sigmoid', nn.Sigmoid())
        ]))
    ## get outer optimizer (not differentiable nor trainable)
    outer_opt = optim.Adam([{'params': base_mdl.parameters()},{'params': opt_mdl.parameters()}], lr=0.01)
    for episode in range(episodes):
        ## get fake support & query data (from a single task and 1 data point)
        spt_x, spt_y, qry_x, qry_y = torch.randn(1), torch.randn(1), torch.randn(1), torch.randn(1)
        ## get differentiable & trainable (parametrized) inner optimizer
        override = {'trainable_opt_model': opt_mdl, 'trainable_opt_state': {'prev_lr': 0.9*torch.randn(1)} }
        inner_opt = EmptySimpleMetaLstm(base_mdl.parameters())
        with higher.innerloop_ctx(base_mdl, inner_opt, copy_initial_weights=copy_initial_weights, track_higher_grads=track_higher_grads, override=override) as (fmodel, diffopt):
            for i_inner in range(nb_inner_train_steps): # full-batch gradient descent on the k_shot examples (k is usually small, e.g. 5)
                fmodel.train()
                # base/child model forward pass
                inner_loss = 0.5*((fmodel(spt_x) - spt_y))**2
                # inner-opt update
                diffopt.step(inner_loss)
            ## Evaluate on query set for current task
            qry_loss = 0.5*((fmodel(qry_x) - qry_y))**2
            qry_loss.backward() # for memory efficient computation
        ## outer update
        print(f'episode = {episode}')
        print(f'base_mdl.grad = {base_mdl.fc.weight.grad}')
        print(f'opt_mdl.grad = {opt_mdl.fc.weight.grad}')
        outer_opt.step()
        outer_opt.zero_grad()
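To pin down which gradients the outer step should receive, here is a plain-PyTorch sketch, without higher, of the same simplified meta-LSTM inner loop, unrolled by hand with a functional base model. This is roughly what innerloop_ctx with copy_initial_weights=False and track_higher_grads=True is meant to do for you. The following choices are assumptions for reproducibility: fixed support/query data, a base weight initialized to 0.5, and a fixed initial prev_lr of 0.1 instead of 0.9*torch.randn(1). The names w0, base_forward and query_loss_after_unroll are hypothetical.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)

# Meta-learned initialization of the base model's single weight (Linear(1, 1, bias=False)).
w0 = nn.Parameter(torch.full((1, 1), 0.5))
# Optimizer model as in the thread: [p, g, prev_lr] -> lr in (0, 1).
opt_mdl = nn.Sequential(nn.Linear(3, 1, bias=False), nn.Sigmoid())
outer_opt = torch.optim.Adam([w0, *opt_mdl.parameters()], lr=0.01)


def base_forward(w, x):
    # Functional version of Linear(1, 1, bias=False) followed by ReLU.
    return F.relu(F.linear(x, w))


def query_loss_after_unroll(spt_x, spt_y, qry_x, qry_y, steps=5, prev_lr0=0.1):
    w = w0  # the fast weight starts at the meta-learned initialization
    prev_lr = torch.full((1, 1), prev_lr0)
    for _ in range(steps):
        inner_loss = 0.5 * (base_forward(w, spt_x) - spt_y).pow(2).sum()
        # The gradient is treated as data (like g.detach() in the thread): no Hessian terms.
        (g,) = torch.autograd.grad(inner_loss, w, retain_graph=True)
        g = g.detach()
        lr = opt_mdl(torch.cat([w, g, prev_lr], dim=1))  # shape (1, 1)
        w = (1 - lr) * w - lr * g  # the new fast weight depends on w0 and on opt_mdl
        prev_lr = lr
    return 0.5 * (base_forward(w, qry_x) - qry_y).pow(2).sum()


spt_x, spt_y = torch.ones(1, 1), torch.full((1, 1), 2.0)
qry_x, qry_y = torch.ones(1, 1), torch.zeros(1, 1)

outer_opt.zero_grad()
qry_loss = query_loss_after_unroll(spt_x, spt_y, qry_x, qry_y)
qry_loss.backward()
print(w0.grad, opt_mdl[0].weight.grad)  # both populated: the meta-gradients reach both models
outer_opt.step()
```

The key design point is that opt_mdl is used by reference inside the unrolled loop, so nothing copies it and its parameters stay in the graph of the query loss.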
renesax14 commented 4 years ago

The real solution would be to pass an arbitrary dictionary to a differentiable optimizer and do whatever I want with it.

renesax14 commented 4 years ago

Perhaps all I need is to create my own attribute on diffopt once it has been created?

so this line:

            diffopt.override = {'trainable_opt_model': opt_mdl, 'trainable_opt_state': {'prev_lr': 0.9*torch.randn(1)} }

The whole test:

def test_parametrized_inner_optimizer():
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from collections import OrderedDict

    ## training config
    #device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    track_higher_grads = True # if True, the graph is retained during unrolled optimization and the fast weights bear grad_fns, which permits backpropagation through the optimization process. Set to False at test time for efficiency
    copy_initial_weights = False # if False, we train the base model's initial weights (i.e. the base model's initialization)
    episodes = 5
    nb_inner_train_steps = 5
    ## get base model
    base_mdl = nn.Sequential(OrderedDict([
        ('fc', nn.Linear(1,1, bias=False)),
        ('act', nn.ReLU())
        ]))
    ## parametrization/mdl for the inner optimizer
    opt_mdl = nn.Sequential(OrderedDict([
        ('fc', nn.Linear(3,1, bias=False)), # 3 inputs [p, g, prev_lr] 1 for parameter, 1 for gradient, 1 for previous lr
        ('act', nn.LeakyReLU())
        ]))
    ## get outer optimizer (not differentiable nor trainable)
    outer_opt = optim.Adam([{'params': base_mdl.parameters()},{'params': opt_mdl.parameters()}], lr=0.01)
    for episode in range(episodes):
        ## get fake support & query data (from a single task and 1 data point)
        spt_x, spt_y, qry_x, qry_y = torch.randn(1), torch.randn(1), torch.randn(1), torch.randn(1)
        ## get differentiable & trainable (parametrized) inner optimizer
        inner_opt = EmptySimpleMetaLstm( base_mdl.parameters() )
        with higher.innerloop_ctx(base_mdl, inner_opt, copy_initial_weights=copy_initial_weights, track_higher_grads=track_higher_grads) as (fmodel, diffopt):
            diffopt.override = {'trainable_opt_model': opt_mdl, 'trainable_opt_state': {'prev_lr': 0.9*torch.randn(1)} }
            for i_inner in range(nb_inner_train_steps): # full-batch gradient descent on the k_shot examples (k is usually small, e.g. 5)
                fmodel.train()
                # base/child model forward pass
                inner_loss = 0.5*((fmodel(spt_x) - spt_y))**2
                # inner-opt update
                diffopt.step(inner_loss)
            ## Evaluate on query set for current task
            qry_loss = 0.5*((fmodel(qry_x) - qry_y))**2
            qry_loss.backward() # for memory efficient computation
        ## outer update
        print(f'episode = {episode}')
        print(f'base_mdl.grad = {base_mdl.fc.weight.grad}')
        print(f'opt_mdl.grad = {opt_mdl.fc.weight.grad}')
        outer_opt.step()
        outer_opt.zero_grad()
renesax14 commented 4 years ago

But the grads are zero in some episodes...?

episode = 0
base_mdl.grad = tensor([[-0.4019]])
opt_mdl.grad = tensor([[0.0165, 0.7733, 0.2050]])
episode = 1
base_mdl.grad = tensor([[0.]])
opt_mdl.grad = tensor([[0., 0., 0.]])
episode = 2
base_mdl.grad = tensor([[0.]])
opt_mdl.grad = tensor([[0., 0., 0.]])
episode = 3
base_mdl.grad = tensor([[0.]])
opt_mdl.grad = tensor([[0., 0., 0.]])
episode = 4
base_mdl.grad = tensor([[-0.1466]])
opt_mdl.grad = tensor([[ 0.0300,  0.0081, -0.0763]])
Done 
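A probable cause of the all-zero episodes (here and in episode 2 of the first output above): base_mdl is a single bias-free weight followed by a ReLU. Whenever w*x < 0, the ReLU is inactive, its output is 0, and the gradient through it is exactly 0. With torch.randn inputs this happens in a random subset of episodes. If it happens on the query point, every gradient in that episode vanishes, for both base_mdl and opt_mdl. The sketch below shows the effect with the weight fixed at 1.0. weight_grad is a hypothetical helper.

```python
import torch
from torch import nn

base_fc = nn.Linear(1, 1, bias=False)
with torch.no_grad():
    base_fc.weight.fill_(1.0)
base_mdl = nn.Sequential(base_fc, nn.ReLU())


def weight_grad(x_value, y_value):
    # Gradient of 0.5 * (relu(w * x) - y)^2 with respect to w, for w = 1.
    base_mdl.zero_grad()
    x = torch.tensor([x_value])
    y = torch.tensor([y_value])
    loss = 0.5 * (base_mdl(x) - y) ** 2
    loss.sum().backward()
    return base_fc.weight.grad.item()


dead = weight_grad(-1.0, 0.5)   # w*x = -1 < 0: ReLU inactive, gradient is exactly zero
alive = weight_grad(1.0, 0.5)   # w*x = 1 > 0: gradient is (1 - 0.5) * 1 = 0.5
print(dead, alive)
```

Swapping the base model's ReLU for nn.LeakyReLU (as was already done for opt_mdl) keeps a nonzero slope for negative inputs, so these episodes would no longer produce exactly zero gradients.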
renesax14 commented 4 years ago

I think this is all I need:

        inner_opt = EmptySimpleMetaLstm( base_mdl.parameters() )
        with higher.innerloop_ctx(base_mdl, inner_opt, copy_initial_weights=copy_initial_weights, track_higher_grads=track_higher_grads) as (fmodel, diffopt):
            diffopt.override = {'trainable_opt_model': opt_mdl, 'trainable_opt_state': {'prev_lr': 0.9*torch.randn(1)} }
egrefen commented 4 years ago

Okay, I've read through what you've done and skimmed the paper, and I think you're getting a little off track. I am confident that higher supports meta-LSTMs as-is, without needing to change anything in higher.optim. If it turns out we need to change something there, we will, but let's try to solve the problem collaboratively first.

Let's pretend for a second that higher doesn't exist, and not worry about how to deal with training a meta-lstm for now. Let's just implement a class which implements eq (2) of section 3.1 of the paper.
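For reference, here are eq (2) of section 3.1 and the gate definitions that accompany it, transcribed from the paper (please double-check against the PDF). The cell state c_t plays the role of the learner's parameters. Setting f_t = 1 and i_t = alpha_t recovers the plain gradient step theta_t = theta_{t-1} - alpha_t * grad L_t, which is the analogy the paper builds on.

```latex
\begin{aligned}
c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c}_t,
  \qquad c_t = \theta_t,\quad \tilde{c}_t = -\nabla_{\theta_{t-1}} \mathcal{L}_t \\
i_t &= \sigma\!\left(\mathbf{W}_I \cdot \left[\nabla_{\theta_{t-1}} \mathcal{L}_t,\ \mathcal{L}_t,\ \theta_{t-1},\ i_{t-1}\right] + \mathbf{b}_I\right) \\
f_t &= \sigma\!\left(\mathbf{W}_F \cdot \left[\nabla_{\theta_{t-1}} \mathcal{L}_t,\ \mathcal{L}_t,\ \theta_{t-1},\ f_{t-1}\right] + \mathbf{b}_F\right)
\end{aligned}
```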

import torch
import torch.nn.functional as F
from torch import nn, optim

class MetaLSTM(optim.Optimizer):
    r"""Implements a meta-LSTM optimizer."""

    def __init__(self, params):
        # continue implementation here
        raise NotImplementedError

I could do this myself, but I'll confess I'm a little tight on time. Could you give it a shot and then assign it back to me? Again, don't worry about anything other than defining the meta-parameters of the meta-LSTM and implementing the "forward pass" in the step method. As a general hint, you'll need to turn the parameters of each group into a single vector, compute f_t and i_t, and then reshape/split the updated parameters from that group into their original form and do the in-place assignment. This should be fairly easy, but ping me if you get stuck.
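Following the hint above, here is one possible test-time sketch of MetaLSTM (no higher, no second-order gradients). The following assumptions are mine, not from the thread. Each coordinate gets a scalar input gate and forget gate, each computed from [grad, loss, theta, previous gate] with 4 weights and 1 bias shared across all coordinates. The gradient and loss are fed in raw, so the paper's preprocessing is skipped. The biases start at -5 and +5 so that the first updates are close to plain SGD with a small step. The names w_i, b_i, w_f, b_f, meta_parameters, input_bias and forget_bias are hypothetical.

```python
import torch
from torch import nn, optim


class MetaLSTM(optim.Optimizer):
    r"""Test-time sketch of the meta-LSTM update (no second-order gradients)."""

    def __init__(self, params, input_bias=-5.0, forget_bias=5.0):
        super().__init__(params, defaults={})
        # Hypothetical meta-parameters, shared across coordinates; an outer loop would train them.
        self.w_i = nn.Parameter(torch.zeros(4))
        self.b_i = nn.Parameter(torch.tensor(input_bias))
        self.w_f = nn.Parameter(torch.zeros(4))
        self.b_f = nn.Parameter(torch.tensor(forget_bias))

    def meta_parameters(self):
        return [self.w_i, self.b_i, self.w_f, self.b_f]

    @torch.no_grad()
    def step(self, loss):
        for group in self.param_groups:
            params = group['params']
            # 1) Flatten the group's parameters and gradients into single vectors.
            theta = torch.cat([p.reshape(-1) for p in params])
            grad = torch.cat([
                (p.grad if p.grad is not None else torch.zeros_like(p)).reshape(-1)
                for p in params
            ])
            n = theta.numel()
            loss_vec = torch.full((n,), float(loss))
            i_prev = group.get('i_prev', torch.zeros(n))
            f_prev = group.get('f_prev', torch.ones(n))
            # 2) Gates: sigmoid(W . [grad, loss, theta, previous gate] + b), per coordinate.
            i_t = torch.sigmoid(torch.stack([grad, loss_vec, theta, i_prev], dim=1) @ self.w_i + self.b_i)
            f_t = torch.sigmoid(torch.stack([grad, loss_vec, theta, f_prev], dim=1) @ self.w_f + self.b_f)
            # 3) Eq (2) with c_t = theta_t and c~_t = -grad.
            new_theta = f_t * theta - i_t * grad
            # 4) Split back into the original shapes and assign in place.
            offset = 0
            for p in params:
                k = p.numel()
                p.copy_(new_theta[offset:offset + k].view_as(p))
                offset += k
            group['i_prev'], group['f_prev'] = i_t, f_t
        return loss


torch.manual_seed(0)
model = nn.Linear(3, 1)
opt = MetaLSTM(model.parameters())
x, y = torch.randn(8, 3), torch.randn(8, 1)
loss = ((model(x) - y) ** 2).mean()
loss.backward()
before = [p.detach().clone() for p in model.parameters()]
grads = [p.grad.detach().clone() for p in model.parameters()]
opt.step(loss.detach())  # step takes the loss because the gates consume it
```

Meta-training w_i, b_i, w_f and b_f through the unrolled updates would be the DiffOpt part still to be written with higher.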

egrefen commented 4 years ago

Hello @renesax14. Just checking whether you have any interest in providing a non-higher implementation of what the MetaLSTM does at test time (without second-order gradients, see the comment above). If you provide that, I can help you write the DiffOpt version with higher. If not, I will close this issue in one month.