facebookresearch / higher

higher is a pytorch library allowing users to obtain higher order gradients over losses spanning training loops rather than individual training steps.
Apache License 2.0

example of trainable optimizer? #32

Closed renesax14 closed 4 years ago

renesax14 commented 4 years ago

I suggest a full example that implements the optimizer, but a trainable step size could be a good example too...

https://discuss.pytorch.org/t/implement-a-meta-trainable-step-size/70396

egrefen commented 4 years ago

@denisyarats wrote examples of optimizing the learning rates for the GIMLi paper. Pinging him here so that maybe we can integrate some of that code in the examples folder (if he has time).

If you mean learning the entire optimizer as a parametric function, as in e.g. learning to learn by gradient descent by gradient descent then that would make an excellent example to add to the examples folder. We would welcome a pull request doing this, but don't have the cycles to do it ourselves at the moment.
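
For the simpler case of only learning the step size, a minimal sketch using higher's override mechanism could look something like the following (a toy model and toy data, purely to illustrate the idea; the override argument is the one documented in the DifferentiableOptimizer.step docstring quoted later in this thread):

import torch
import torch.nn as nn
import higher

torch.manual_seed(0)
model = nn.Linear(4, 1)                        # toy inner model
lr = torch.tensor(0.1, requires_grad=True)     # meta-learned step size
meta_opt = torch.optim.Adam([lr], lr=1e-2)
x, y = torch.randn(8, 4), torch.randn(8, 1)    # toy data

# the inner optimizer's lr value is a placeholder; it is overridden below
inner_opt = torch.optim.SGD(model.parameters(), lr=1.0)

for _ in range(10):
    meta_opt.zero_grad()
    with higher.innerloop_ctx(model, inner_opt) as (fmodel, diffopt):
        for _ in range(3):                     # differentiable inner steps
            inner_loss = nn.functional.mse_loss(fmodel(x), y)
            # pass the trainable lr tensor so it stays in the graph
            diffopt.step(inner_loss, override={'lr': [lr]})
        meta_loss = nn.functional.mse_loss(fmodel(x), y)
        meta_loss.backward()                   # populates lr.grad through the inner steps
    meta_opt.step()

print(lr)                                      # the step size has been meta-updated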

egrefen commented 4 years ago

Closing this issue for now, but we always welcome pull requests providing new examples!

renesax14 commented 4 years ago

@denisyarats do you have an example of learning the learning rate that we could add to the examples in this library?

denisyarats commented 4 years ago

yes, I do have this example, you can find it here: https://github.com/denisyarats/densenet_cifar10

Feel free to integrate it into the examples folder of higher.

renesax14 commented 4 years ago

yes, I do have this example, you can find it here: https://github.com/denisyarats/densenet_cifar10

Feel free to integrate it into the examples folder of higher.

Will check it out, thanks!

renesax14 commented 4 years ago

@egrefen sorry for bothering you again, but I thought I was so close and I still got an error. It thinks my step-size NN is not in the graph, but it should be, because of this line of code:

                p_new = p + lr*g
                group['params'][p_idx] = p_new

but somehow that is not enough to have gradients...
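
One quick way to confirm whether eta is reachable from the meta-loss at all is to ask autograd directly, along the lines of the commented-out grad_of_grads line in the script below (the names refer to that script):

# diagnostic only: returns tensors if eta is connected to meta_loss,
# and (with allow_unused=True) Nones instead of an error if it is not
grads = torch.autograd.grad(meta_loss, list(eta.parameters()),
                            allow_unused=True, retain_graph=True)
print(grads)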

Full self-contained script:

import torch
import torch.nn as nn
from torch.optim.optimizer import Optimizer

import higher
from higher.optim import DifferentiableOptimizer
from higher.optim import DifferentiableSGD

import torchvision
import torchvision.transforms as transforms

from torchviz import make_dot

import copy

import itertools

from collections import OrderedDict

#mini class to add a flatten layer to the ordered dictionary
class Flatten(nn.Module):
    def forward(self, input):
        '''
        Note that input.size(0) is usually the batch size.
        So what it does is that given any input with input.size(0) # of batches,
        will flatten to be 1 * nb_elements.
        '''
        batch_size = input.size(0)
        out = input.view(batch_size,-1)
        return out # (batch_size, *size)

def get_cifar10():
    transform = transforms.Compose(
        [transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                            shuffle=True, num_workers=2)

    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                        download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                            shuffle=False, num_workers=2)
    return trainloader, testloader

class MySGD(Optimizer):

    def __init__(self, params, eta, prev_lr):
        defaults = {'eta':eta, 'prev_lr':prev_lr}
        super().__init__(params, defaults)

class TrainableSGD(DifferentiableOptimizer):

    def _update(self, grouped_grads, **kwargs):
        prev_lr = self.param_groups[0]['prev_lr']
        eta = self.param_groups[0]['eta']
        # start differentiable & trainable update
        zipped = zip(self.param_groups, grouped_grads)
        lr = 0.1*eta(prev_lr).view(1)
        for group_idx, (group, grads) in enumerate(zipped):
            for p_idx, (p, g) in enumerate(zip(group['params'], grads)):
                if g is None:
                    continue
                #group['params'][p_idx] = _add(p, -group['lr'], g)
                p_new = p + lr*g
                group['params'][p_idx] = p_new
        # fake returns
        self.param_groups[0]['prev_lr'] = lr

higher.register_optim(MySGD, TrainableSGD)

def main():
    # get dataloaders
    trainloader, testloader = get_cifar10()
    criterion = nn.CrossEntropyLoss()

    child_model = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(in_channels=3,out_channels=2,kernel_size=5)),
            ('relu1', nn.ReLU()),
            ('Flatten', Flatten()),
            ('fc', nn.Linear(in_features=28*28*2,out_features=10) )
        ]))

    hidden = torch.randn(size=(1,1),requires_grad=True)
    print(f'-> hidden = {hidden}')
    eta = nn.Sequential(OrderedDict([
        ('fc', nn.Linear(1,1)),
        ('sigmoid', nn.Sigmoid())
    ]))
    inner_opt = MySGD(child_model.parameters(), eta=eta, prev_lr=hidden)
    meta_params = itertools.chain(child_model.parameters(),eta.parameters())
    #meta_params = itertools.chain(eta.parameters(),[hidden])
    meta_opt = torch.optim.Adam(meta_params, lr=1e-3)
    # do meta-training/outer training minimize outerloop: min_{theta} sum_t L^val( theta^{T} - eta* Grad L^train(theta^{T}) ) 
    print()
    nb_outer_steps = 1 # note: in this case it's the same as the number of meta-train steps (but it need not be, depending on how you loop through the val set)
    for outer_i, (outer_inputs, outer_targets) in enumerate(testloader, 0):
        meta_opt.zero_grad()
        if outer_i >= nb_outer_steps:
            break
        # do inner-training/MAML; minimize innerloop: theta^{T} - eta * Grad L^train(theta^{T}) ~ argmin L^train(theta)
        nb_inner_steps = 3
        #with higher.innerloop_ctx(child_model, inner_opt, copy_initial_weights=False) as (fmodel, diffopt):
        with higher.innerloop_ctx(child_model, inner_opt) as (fmodel, diffopt):
            for inner_i, (inner_inputs, inner_targets) in enumerate(trainloader, 0):
                if inner_i >= nb_inner_steps:
                    break
                logits = fmodel(inner_inputs)
                inner_loss = criterion(logits, inner_targets)
                print(f'--> inner_i = {inner_i}')
                print(f'inner_loss^<{inner_i}>: {inner_loss}')
                print(f'lr^<{inner_i-1}> = {diffopt.param_groups[0]["prev_lr"]}') 
                diffopt.step(inner_loss) # changes params P[t+1] using P[t] and loss[t] in a differentiable manner
                print(f'lr^<{inner_i}> = {diffopt.param_groups[0]["prev_lr"]}')
                print()
            # compute the meta-loss L^val( theta^{T} - eta* Grad L^train(theta^{T}) ) 
            outer_outputs = fmodel(outer_inputs)
            meta_loss = criterion(outer_outputs, outer_targets) # L^val
            make_dot(meta_loss).render('meta_loss',format='png')
            meta_loss.backward()
            #grad_of_grads = torch.autograd.grad(outputs=meta_loss, inputs=eta.parameters()) # dmeta_loss/dw0
            print(f'----> outer_i = {outer_i}')
            print(f'-> outer_loss/meta_loss^<{outer_i}>: {meta_loss}')
            print(f'child_model.fc.weight.grad = {child_model.fc.weight.grad}')
            print(f'hidden.grad = {hidden.grad}')
            print(f'eta.fc.weight = {eta.fc.weight.grad}')
            meta_opt.step() # meta-optimizer step: more or less theta^<t> := theta^<t> - meta_eta * Grad L^val( theta^{T} - eta* Grad L^train(theta^{T}) )

if __name__ == "__main__":
    main()
    print('---> Done\a')

Notice the Nones:

Files already downloaded and verifiedFiles already downloaded and verified
-> hidden = tensor([[0.8459]], requires_grad=True)

--> inner_i = 0
inner_loss^<0>: 2.2696359157562256
lr^<-1> = tensor([[0.8459]], requires_grad=True)
lr^<0> = tensor([0.0567], grad_fn=<MulBackward0>)

--> inner_i = 1
inner_loss^<1>: 2.0114920139312744
lr^<0> = tensor([0.0567], grad_fn=<MulBackward0>)
lr^<1> = tensor([0.0720], grad_fn=<MulBackward0>)

--> inner_i = 2
inner_loss^<2>: 2.3866422176361084
lr^<1> = tensor([0.0720], grad_fn=<MulBackward0>)
lr^<2> = tensor([0.0717], grad_fn=<MulBackward0>)

----> outer_i = 0
-> outer_loss/meta_loss^<0>: 4.021303176879883
child_model.fc.weight.grad = None
hidden.grad = None
eta.fc.weight = None
---> Done

I just saw Denis responded, so I will check his code too...

renesax14 commented 4 years ago

@egrefen Sorry for the spam, but this has to be some sort of bug, because when I add up all the parameters inside _update, call backward on the sum, and then print the gradients, the gradients I expect to be non-zero are indeed non-zero:

==> hidden.grad = tensor([[0.0373]])
==> eta.fc.weight.grad = tensor([[-0.0882]])

but when I do it outside of _update (in the inner loop, after diffopt.step(inner_loss)), they are incorrectly None:

===> hidden.grad = None
===> eta.fc.weight.grad = None

This must be some sort of bug somewhere, because I have not done anything with the weights after step and they should be the same as they were inside the _update function.


For reference I will paste the new code with the new print statements:

'''
Single task MAML:

MAML: min_{theta} sum_t L^val( theta - eta* Grad L^train(theta) )

T-step MAML: min_{theta} sum_t L^val( theta^{T} - eta* Grad L^train(theta^{T}) )
outerloop: min_{theta} sum_t L^val( theta^{T} - eta* Grad L^train(theta^{T}) ) ~ min_{theta} sum_t L^val( argmin L^train(theta) )
Innerloop: theta^{T} - eta* Grad L^train(theta^{T}) ~ argmin L^train(theta)

single task MAML: min_{theta} L^val( theta - eta* Grad L^train(theta) )

based on MAML example: https://github.com/facebookresearch/higher/blob/master/examples/maml-omniglot.py
'''

import torch
import torch.nn as nn
from torch.optim.optimizer import Optimizer

import higher
from higher.optim import DifferentiableOptimizer
from higher.optim import DifferentiableSGD

import torchvision
import torchvision.transforms as transforms

from torchviz import make_dot

import copy

import itertools

from collections import OrderedDict

from pdb import set_trace as st

#mini class to add a flatten layer to the ordered dictionary
class Flatten(nn.Module):
    def forward(self, input):
        '''
        Note that input.size(0) is usually the batch size.
        So what it does is that given any input with input.size(0) # of batches,
        will flatten to be 1 * nb_elements.
        '''
        batch_size = input.size(0)
        out = input.view(batch_size,-1)
        return out # (batch_size, *size)

def get_cifar10():
    transform = transforms.Compose(
        [transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                            shuffle=True, num_workers=2)

    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                        download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                            shuffle=False, num_workers=2)
    return trainloader, testloader

class MySGD(Optimizer):

    def __init__(self, params, eta, prev_lr, hidden, meta_opt):
        defaults = {'eta':eta, 'prev_lr':prev_lr, 'hidden':hidden, 'meta_opt':meta_opt}
        super().__init__(params, defaults)

class TrainableSGD(DifferentiableOptimizer):

    def _update(self, grouped_grads, **kwargs):
        meta_opt = self.param_groups[0]['meta_opt']
        hidden = self.param_groups[0]['hidden']
        prev_lr = self.param_groups[0]['prev_lr']
        eta = self.param_groups[0]['eta']
        # start differentiable & trainable update
        zipped = zip(self.param_groups, grouped_grads)
        lr = 0.1*eta(prev_lr).view(1)
        p_tot = 0
        for group_idx, (group, grads) in enumerate(zipped):
            for p_idx, (p, g) in enumerate(zip(group['params'], grads)):
                if g is None:
                    continue
                #group['params'][p_idx] = _add(p, -group['lr'], g)
                p_new = p + lr*g
                group['params'][p_idx] = p_new
                p_tot += p_new.sum()
                #make_dot(p_new.sum()).render('p_new',format='png')
                #print()
        # fake returns
        self.param_groups[0]['prev_lr'] = lr
        #p_tot.backward()
        print(f'==> hidden.grad = {hidden.grad}')
        print(f'==> eta.fc.weight.grad = {eta.fc.weight.grad}')
        #meta_opt.zero_grad()
        # print(f'==> hidden.grad = {hidden.grad}')
        # print(f'==> eta.fc.weight.grad = {eta.fc.weight.grad}')
        print()

higher.register_optim(MySGD, TrainableSGD)

def main():
    # get dataloaders
    trainloader, testloader = get_cifar10()
    criterion = nn.CrossEntropyLoss()

    child_model = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(in_channels=3,out_channels=2,kernel_size=5,bias=False)),
            ('relu1', nn.ReLU()),
            ('Flatten', Flatten()),
            ('fc', nn.Linear(in_features=28*28*2,out_features=10,bias=False) )
        ]))

    hidden = torch.randn(size=(1,1),requires_grad=True)
    print(f'-> hidden = {hidden}')
    eta = nn.Sequential(OrderedDict([
        ('fc', nn.Linear(1,1,bias=False)),
        ('sigmoid', nn.Sigmoid())
    ]))
    #meta_params = itertools.chain(child_model.parameters(),eta.parameters(),[hidden])
    meta_params = itertools.chain(eta.parameters(),[hidden])
    meta_opt = torch.optim.Adam(meta_params, lr=1e-3)
    inner_opt = MySGD(child_model.parameters(), eta=eta, prev_lr=hidden, hidden=hidden, meta_opt=meta_opt)
    # do meta-training/outer training minimize outerloop: min_{theta} sum_t L^val( theta^{T} - eta* Grad L^train(theta^{T}) ) 
    print()
    nb_outer_steps = 1 # note: in this case it's the same as the number of meta-train steps (but it need not be, depending on how you loop through the val set)
    for outer_i, (outer_inputs, outer_targets) in enumerate(testloader, 0):
        meta_opt.zero_grad()
        if outer_i >= nb_outer_steps:
            break
        # do inner-training/MAML; minimize innerloop: theta^{T} - eta * Grad L^train(theta^{T}) ~ argmin L^train(theta)
        nb_inner_steps = 3
        with higher.innerloop_ctx(child_model, inner_opt, copy_initial_weights=False) as (fmodel, diffopt):
        #with higher.innerloop_ctx(child_model, inner_opt) as (fmodel, diffopt):
            for inner_i, (inner_inputs, inner_targets) in enumerate(trainloader, 0):
                meta_opt.zero_grad()
                if inner_i >= nb_inner_steps:
                    break
                logits = fmodel(inner_inputs)
                inner_loss = criterion(logits, inner_targets)
                print(f'--> inner_i = {inner_i}')
                print(f'inner_loss^<{inner_i}>: {inner_loss}')
                print(f'lr^<{inner_i-1}> = {diffopt.param_groups[0]["prev_lr"]}') 
                diffopt.step(inner_loss) # changes params P[t+1] using P[t] and loss[t] in a differentiable manner
                print(f'lr^<{inner_i}> = {diffopt.param_groups[0]["prev_lr"]}')
                ##
                p_tot = sum([ p.sum() for p in fmodel.parameters() ])
                p_tot.backward()
                print(f'===> hidden.grad = {hidden.grad}')
                print(f'===> eta.fc.weight.grad = {eta.fc.weight.grad}')
                print()
            # compute the meta-loss L^val( theta^{T} - eta* Grad L^train(theta^{T}) )
            print()
            p_tot = sum([ p.sum() for p in fmodel.parameters() ])
            p_tot.backward()
            print(f'eta.fc.weight.grad = {eta.fc.weight.grad}')
            outer_outputs = fmodel(outer_inputs)
            meta_loss = criterion(outer_outputs, outer_targets) # L^val
            #meta_loss = meta_loss + inner_loss
            #make_dot(meta_loss).render('meta_loss',format='png')
            #meta_loss.backward()
            #grad_of_grads = torch.autograd.grad(outputs=meta_loss, inputs=eta.parameters()) # dmeta_loss/dw0
            print(f'----> outer_i = {outer_i}')
            print(f'-> outer_loss/meta_loss^<{outer_i}>: {meta_loss}')
            print(f'child_model.fc.weight.grad = {child_model.fc.weight.grad}')
            print(f'hidden.grad = {hidden.grad}')
            print(f'eta.fc.weight = {eta.fc.weight.grad}')
            meta_opt.step() # meta-optimizer step: more or less theta^<t> := theta^<t> - meta_eta * Grad L^val( theta^{T} - eta* Grad L^train(theta^{T}) )

if __name__ == "__main__":
    main()
    print('---> Done\a')

renesax14 commented 4 years ago

Perhaps it's this line of code:

https://github.com/facebookresearch/higher/blob/8f0716fb1663218324c02dabdba26b639959cfb6/higher/optim.py#L260

right after the _update call:

        self._update(grouped_grads)

        new_params = params[:]
        for group, mapping in zip(self.param_groups, self._group_to_param_list):
            for p, index in zip(group['params'], mapping):
                if self._track_higher_grads:
                    new_params[index] = p
                else:
                    new_params[index] = p.detach().requires_grad_()

        if self._fmodel is not None:
            self._fmodel.update_params(new_params)

renesax14 commented 4 years ago

probably this function, since it's the only part that re-assigns the attributes:

https://github.com/facebookresearch/higher/blob/8f0716fb1663218324c02dabdba26b639959cfb6/higher/patch.py#L297

egrefen commented 4 years ago

Re-opening this issue so I remember to check it this week (if I can find the time).

renesax14 commented 4 years ago

I feel so close... but something about the way fmodules (and fmodel) are being updated is breaking my computation graph...

renesax14 commented 4 years ago

I will be making comments to record/track my progress in the debugging as I learn stuff about the bug.

renesax14 commented 4 years ago

I tried printing the parameters in the update function _update_patched_params and displaying the computation graph to see if higher is breaking the computation graph. It seems that function is not breaking it, as shown in the two pics (note: I've removed biases to make the graphs simpler):

1)

param_sum

2)

param_sum1

Code:

def _update_patched_params(
    fmodule: _MonkeyPatchBase,
    params_box: _typing.Sequence[_typing.List[_torch.Tensor]],
    params_offset: int
) -> int:
    num_params = len([1 for p in fmodule._parameters.values() if p is not None])
    child_params_offset = params_offset + num_params
    for name, child in fmodule._modules.items():
        child_params_offset = _update_patched_params(
            child, params_box, child_params_offset
        )
    #p_tot = 0
    for name, param in zip(fmodule._param_names,params_box[0][params_offset:params_offset + num_params]):
        #delattr(fmodule, name)
        setattr(fmodule, name, param)
        make_dot(param.sum()).render(
            filename='param_sum1',
            format='png'
        )
        #st()
        #print(name)
        #p_tot += param.sum()
    #st()
    #print(f'> p_tot = {p_tot}')
    #p_tot.backward()
    #print(f'===> hidden.grad = {hidden.grad}')
    #print(f'===> eta.fc.weight.grad = {eta.fc.weight.grad}')
    return child_params_offset

renesax14 commented 4 years ago

My code should work even if the original model or fmodel is not trainable:

    child_model = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(in_channels=3,out_channels=2,kernel_size=5,bias=False)),
            ('relu1', nn.ReLU()),
            ('Flatten', Flatten()),
            ('fc', nn.Linear(in_features=28*28*2,out_features=10,bias=False) )
        ]))
    child_model.conv1.weight.requires_grad = False
    child_model.fc.weight.requires_grad = False

although this is not super important...

renesax14 commented 4 years ago

It seems that it does not backpropagate to my neural-net step size (named eta) if I try to call backward inside step (but it does inside my implementation of _update).

Code and output:

    def step(
        self,
        loss: _torch.Tensor,
        params: _typing.Iterable[_torch.Tensor] = None,
        override: _typing.Optional[_OverrideType] = None,
        grad_callback: _typing.Optional[_GradCallbackType] = None,
        eta=None,
        **kwargs
    ) -> _typing.Iterable[_torch.Tensor]:
        r"""Perform a model update.

        This would be used by replacing the normal sequence::

            opt.zero_grad()
            loss.backward()
            opt.step()

        with::

            diffopt.step(loss)

        Args:
            loss: the loss tensor.
            params (optional): the parameters with regard to which we measure
                the loss. These must be provided if the differentiable optimizer
                did not receive a patched model with a view over its own fast
                weights at initialisation. If there is such a model, and params
                are provided, they will overwrite the params of the encapsulated
                model.
            override (optional): a dictionary mapping optimizer settings (i.e.
                those which would be passed to the optimizer constructor or
                provided within parameter groups) to either singleton lists of
                override values, or to a list of override values of length equal
                to the number of parameter groups. If a single override is
                provided for a keyword, it is used for all parameter groups. If
                a list is provided, the ``i``\ th element of the list overrides
                the corresponding setting in the ``i``\ th parameter group. This
                permits the passing of tensors requiring gradient to
                differentiable optimizers for use as optimizer settings. Setting
                override here has highest precedence, i.e. it will override any
                tensors provided as override during the creation of the
                differentiable optimizer, where there is name clash.
            grad_callback: (optional) a single argument function which will be
                applied to a list of gradients of parameters, which respects the
                order specified by ``reference_params``. This can be used to
                apply a function, such as gradient clipping, to all (or a
                subset) of these gradients every time the step function is
                called. This callback overrides the default provided when
                constructing the differentiable optimizer.

        Returns:
            The updated parameters, which will individually have ``grad_fn``\ s
            of their own. If the optimizer has an encapsulated patched model,
            its view over its own fast weights will be updated with these
            params.
        """
        print('---------> IN .step(loss)')
        #st()
        #eta = eta[0]
        # Deal with override
        if override is not None:
            self._apply_override(override)

        if self._fmodel is None or self._fmodel.fast_params is None:
            if params is None:
                raise ValueError(
                    "params kwarg must be passed to step if the differentiable "
                    "optimizer doesn't have a view on a patched model with "
                    "params."
                )
        else:
            params = self._fmodel.fast_params if params is None else params

        params = list(params)

        # This allows us to gracefully deal with cases where params are frozen.
        grad_targets = [
            p if p.requires_grad else _torch.tensor([], requires_grad=True)
            for p in params
        ]

        all_grads = _torch.autograd.grad(
            loss,
            grad_targets,
            create_graph=self._track_higher_grads,
            allow_unused=True  # boo
        )

        if grad_callback is not None:
            all_grads = grad_callback(all_grads)
        elif self._grad_callback is not None:
            all_grads = self._grad_callback(all_grads)

        grouped_grads = []
        for group, mapping in zip(self.param_groups, self._group_to_param_list):
            grads = []
            for i, index in enumerate(mapping):
                group['params'][i] = params[index]
                grads.append(all_grads[index])
            grouped_grads.append(grads)

        self._update(grouped_grads)

        new_params = params[:]
        print(f'self._track_higher_grads = {self._track_higher_grads}')
        for group, mapping in zip(self.param_groups, self._group_to_param_list):
            for p, index in zip(group['params'], mapping):
                if self._track_higher_grads:
                    new_params[index] = p
                else:
                    new_params[index] = p.detach().requires_grad_()

        p_tot = 0
        for p in new_params:
            p_tot += p.sum()
        p_tot.backward()

        # if self._fmodel is not None:
        #     self._fmodel.update_params(new_params)
        #     print()
        #     st()
        #     set_attr(self._fmodel, names, val)
        #     del_attr(self._fmodel, names)

        print(f'eta.fc.grad = {eta.fc.weight.grad}')
        return new_params

Output:

---------> IN .step(loss)
self._track_higher_grads = True
eta.fc.grad = None

renesax14 commented 4 years ago

It seems it doesn't work right outside of _update(grouped_grads) either:

Code and output:

    def step(
        self,
        loss: _torch.Tensor,
        params: _typing.Iterable[_torch.Tensor] = None,
        override: _typing.Optional[_OverrideType] = None,
        grad_callback: _typing.Optional[_GradCallbackType] = None,
        eta=None,
        **kwargs
    ) -> _typing.Iterable[_torch.Tensor]:
        r"""Perform a model update.

        This would be used by replacing the normal sequence::

            opt.zero_grad()
            loss.backward()
            opt.step()

        with::

            diffopt.step(loss)

        Args:
            loss: the loss tensor.
            params (optional): the parameters with regard to which we measure
                the loss. These must be provided if the differentiable optimizer
                did not receive a patched model with a view over its own fast
                weights at initialisation. If there is such a model, and params
                are provided, they will overwrite the params of the encapsulated
                model.
            override (optional): a dictionary mapping optimizer settings (i.e.
                those which would be passed to the optimizer constructor or
                provided within parameter groups) to either singleton lists of
                override values, or to a list of override values of length equal
                to the number of parameter groups. If a single override is
                provided for a keyword, it is used for all parameter groups. If
                a list is provided, the ``i``\ th element of the list overrides
                the corresponding setting in the ``i``\ th parameter group. This
                permits the passing of tensors requiring gradient to
                differentiable optimizers for use as optimizer settings. Setting
                override here has highest precedence, i.e. it will override any
                tensors provided as override during the creation of the
                differentiable optimizer, where there is name clash.
            grad_callback: (optional) a single argument function which will be
                applied to a list of gradients of parameters, which respects the
                order specified by ``reference_params``. This can be used to
                apply a function, such as gradient clipping, to all (or a
                subset) of these gradients every time the step function is
                called. This callback overrides the default provided when
                constructing the differentiable optimizer.

        Returns:
            The updated parameters, which will individually have ``grad_fn``\ s
            of their own. If the optimizer has an encapsulated patched model,
            its view over its own fast weights will be updated with these
            params.
        """
        print('---------> IN .step(loss)')
        #st()
        #eta = eta[0]
        # Deal with override
        if override is not None:
            self._apply_override(override)

        if self._fmodel is None or self._fmodel.fast_params is None:
            if params is None:
                raise ValueError(
                    "params kwarg must be passed to step if the differentiable "
                    "optimizer doesn't have a view on a patched model with "
                    "params."
                )
        else:
            params = self._fmodel.fast_params if params is None else params

        params = list(params)

        # This allows us to gracefully deal with cases where params are frozen.
        grad_targets = [
            p if p.requires_grad else _torch.tensor([], requires_grad=True)
            for p in params
        ]

        all_grads = _torch.autograd.grad(
            loss,
            grad_targets,
            create_graph=self._track_higher_grads,
            allow_unused=True  # boo
        )

        if grad_callback is not None:
            all_grads = grad_callback(all_grads)
        elif self._grad_callback is not None:
            all_grads = self._grad_callback(all_grads)

        grouped_grads = []
        for group, mapping in zip(self.param_groups, self._group_to_param_list):
            grads = []
            for i, index in enumerate(mapping):
                group['params'][i] = params[index]
                grads.append(all_grads[index])
            grouped_grads.append(grads)

        self._update(grouped_grads)

        p_tot = 0
        for p in params[:]:
            p_tot += p.sum()
        p_tot.backward()

        # new_params = params[:]
        # print(f'self._track_higher_grads = {self._track_higher_grads}')
        # for group, mapping in zip(self.param_groups, self._group_to_param_list):
        #     for p, index in zip(group['params'], mapping):
        #         if self._track_higher_grads:
        #             new_params[index] = p
        #         else:
        #             new_params[index] = p.detach().requires_grad_()

        # p_tot = 0
        # for p in new_params:
        #     p_tot += p.sum()
        # p_tot.backward()

        # if self._fmodel is not None:
        #     self._fmodel.update_params(new_params)
        #     print()
        #     st()
        #     set_attr(self._fmodel, names, val)
        #     del_attr(self._fmodel, names)

        print(f'eta.fc.grad = {eta.fc.weight.grad}')
        #return new_params

output:

> p_tot = 3.033494472503662
---------> IN .step(loss)
eta.fc.grad = None

renesax14 commented 4 years ago

I checked whether the gradients of the learnable nn step size eta get populated when I call backward inside my custom _update, and they do:

output:

---------> IN .step(loss)
==> eta.fc.weight.grad = tensor([[-0.0260]])

code:

class TrainableSGD(DifferentiableOptimizer):

    def _update(self, grouped_grads, **kwargs):
        meta_opt = self.param_groups[0]['meta_opt']
        hidden = self.param_groups[0]['hidden']
        prev_lr = self.param_groups[0]['prev_lr']
        eta = self.param_groups[0]['eta']
        # start differentiable & trainable update
        zipped = zip(self.param_groups, grouped_grads)
        lr = 0.1*eta(prev_lr).view(1)
        p_tot = 0
        for group_idx, (group, grads) in enumerate(zipped):
            for p_idx, (p, g) in enumerate(zip(group['params'], grads)):
                if g is None:
                    continue
                #group['params'][p_idx] = _add(p, -group['lr'], g)
                p_new = p + lr*g
                group['params'][p_idx] = p_new
                p_tot += p_new.sum()
                #make_dot(p_new.sum()).render('p_new',format='png')
                #print()
        # fake returns
        self.param_groups[0]['prev_lr'] = lr
        p_tot.backward()
        #print(f'==> hidden.grad = {hidden.grad}')
        print(f'==> eta.fc.weight.grad = {eta.fc.weight.grad}')
        #meta_opt.zero_grad()
        # print(f'==> hidden.grad = {hidden.grad}')
        # print(f'==> eta.fc.weight.grad = {eta.fc.weight.grad}')
        #print()

renesax14 commented 4 years ago

I tried returning the new parameters from the _update function to step and then computing the backward pass there, but it did not work. This is the code that failed:

        new_params = self._update(grouped_grads)

        # p_tot = 0
        # for p in params[:]:
        #     p_tot += p.sum()
        # p_tot.backward()

        #new_params = params[:]
        #st()
        # print(f'self._track_higher_grads = {self._track_higher_grads}')
        # for group, mapping in zip(self.param_groups, self._group_to_param_list):
        #     for p, index in zip(group['params'], mapping):
        #         if self._track_higher_grads:
        #             new_params[index] = p
        #         else:
        #             new_params[index] = p.detach().requires_grad_()

        p_tot = 0
        for p in new_params:
            p_tot += p.sum()
        p_tot.backward()

my _update:

class TrainableSGD(DifferentiableOptimizer):

    def _update(self, grouped_grads, **kwargs):
        meta_opt = self.param_groups[0]['meta_opt']
        hidden = self.param_groups[0]['hidden']
        prev_lr = self.param_groups[0]['prev_lr']
        eta = self.param_groups[0]['eta']
        # start differentiable & trainable update
        zipped = zip(self.param_groups, grouped_grads)
        lr = 0.1*eta(prev_lr).view(1)
        p_tot = 0
        for group_idx, (group, grads) in enumerate(zipped):
            for p_idx, (p, g) in enumerate(zip(group['params'], grads)):
                if g is None:
                    continue
                #group['params'][p_idx] = _add(p, -group['lr'], g)
                p_new = p + lr*g
                group['params'][p_idx] = p_new
                p_tot += p_new.sum()
                #make_dot(p_new.sum()).render('p_new',format='png')
                #print()
        # fake returns
        self.param_groups[0]['prev_lr'] = lr
        #p_tot.backward()
        #print(f'==> hidden.grad = {hidden.grad}')
        #print(f'==> eta.fc.weight.grad = {eta.fc.weight.grad}')
        #meta_opt.zero_grad()
        # print(f'==> hidden.grad = {hidden.grad}')
        # print(f'==> eta.fc.weight.grad = {eta.fc.weight.grad}')
        #print()
        new_params = group['params']
        return new_params

My next attempt is to append the params to a fresh list and bypass the grouped PyTorch structures, because my hypothesis is that PyTorch might be doing something under the hood somewhere.

renesax14 commented 4 years ago

Appending the new parameters myself and bypassing the groups list did not work:

class TrainableSGD(DifferentiableOptimizer):

    def _update(self, grouped_grads, **kwargs):
        meta_opt = self.param_groups[0]['meta_opt']
        hidden = self.param_groups[0]['hidden']
        prev_lr = self.param_groups[0]['prev_lr']
        eta = self.param_groups[0]['eta']
        # start differentiable & trainable update
        zipped = zip(self.param_groups, grouped_grads)
        lr = 0.1*eta(prev_lr).view(1)
        p_tot = 0
        new_params = []
        for group_idx, (group, grads) in enumerate(zipped):
            for p_idx, (p, g) in enumerate(zip(group['params'], grads)):
                if g is None:
                    continue
                #group['params'][p_idx] = _add(p, -group['lr'], g)
                p_new = p + lr*g
                group['params'][p_idx] = p_new
                p_tot += p_new.sum()
                #make_dot(p_new.sum()).render('p_new',format='png')
                #print()
                new_params.append( p_new )
        # fake returns
        self.param_groups[0]['prev_lr'] = lr
        #p_tot.backward()
        #print(f'==> hidden.grad = {hidden.grad}')
        #print(f'==> eta.fc.weight.grad = {eta.fc.weight.grad}')
        #meta_opt.zero_grad()
        # print(f'==> hidden.grad = {hidden.grad}')
        # print(f'==> eta.fc.weight.grad = {eta.fc.weight.grad}')
        #print()
        #new_params = group['params']
        return new_params

step is:

        new_params = self._update(grouped_grads)

        # p_tot = 0
        # for p in params[:]:
        #     p_tot += p.sum()
        # p_tot.backward()

        #new_params = params[:]
        #st()
        # print(f'self._track_higher_grads = {self._track_higher_grads}')
        # for group, mapping in zip(self.param_groups, self._group_to_param_list):
        #     for p, index in zip(group['params'], mapping):
        #         if self._track_higher_grads:
        #             new_params[index] = p
        #         else:
        #             new_params[index] = p.detach().requires_grad_()

        p_tot = 0
        for p in new_params:
            p_tot += p.sum()
        p_tot.backward()

        # if self._fmodel is not None:
        #     self._fmodel.update_params(new_params)
        #     print()
        #     st()
        #     set_attr(self._fmodel, names, val)
        #     del_attr(self._fmodel, names)

        print(f'eta.fc.grad = {eta.fc.weight.grad}')
        #return new_params

output:

---------> IN .step(loss)
eta.fc.grad = None

renesax14 commented 4 years ago

I thought that returning p_tot directly and calling backward on it outside of _update (inside step) would populate the gradients of eta, since doing that inside _update works. So I did that, and it still didn't populate the gradients, despite them being populated inside _update.

Code:

class TrainableSGD(DifferentiableOptimizer):

    def _update(self, grouped_grads, **kwargs):
        meta_opt = self.param_groups[0]['meta_opt']
        hidden = self.param_groups[0]['hidden']
        prev_lr = self.param_groups[0]['prev_lr']
        eta = self.param_groups[0]['eta']
        # start differentiable & trainable update
        zipped = zip(self.param_groups, grouped_grads)
        lr = 0.1*eta(prev_lr).view(1)
        p_tot = 0
        new_params = []
        for group_idx, (group, grads) in enumerate(zipped):
            for p_idx, (p, g) in enumerate(zip(group['params'], grads)):
                if g is None:
                    continue
                #group['params'][p_idx] = _add(p, -group['lr'], g)
                p_new = p + lr*g
                group['params'][p_idx] = p_new
                p_tot += p_new.sum()
                #make_dot(p_new.sum()).render('p_new',format='png')
                #print()
                new_params.append( p_new )
        # fake returns
        self.param_groups[0]['prev_lr'] = lr
        #p_tot.backward()
        #print(f'==> hidden.grad = {hidden.grad}')
        #print(f'==> eta.fc.weight.grad = {eta.fc.weight.grad}')
        #meta_opt.zero_grad()
        # print(f'==> hidden.grad = {hidden.grad}')
        # print(f'==> eta.fc.weight.grad = {eta.fc.weight.grad}')
        #print()
        #new_params = group['params']
        #return new_params
        return p_tot

Code inside step:

        #new_params = self._update(grouped_grads)
        p_tot = self._update(grouped_grads)

        # p_tot = 0
        # for p in params[:]:
        #     p_tot += p.sum()
        # p_tot.backward()

        #new_params = params[:]
        #st()
        # print(f'self._track_higher_grads = {self._track_higher_grads}')
        # for group, mapping in zip(self.param_groups, self._group_to_param_list):
        #     for p, index in zip(group['params'], mapping):
        #         if self._track_higher_grads:
        #             new_params[index] = p
        #         else:
        #             new_params[index] = p.detach().requires_grad_()

        # p_tot = 0
        # for p in new_params:
        #     p_tot += p.sum()
        p_tot.backward()

        # if self._fmodel is not None:
        #     self._fmodel.update_params(new_params)
        #     print()
        #     st()
        #     set_attr(self._fmodel, names, val)
        #     del_attr(self._fmodel, names)

        print(f'p_tot = {p_tot}')
        print(p_tot)
        print(f'eta.fc.grad = {eta.fc.weight.grad}')
        #return new_params

As we can see, p_tot does have a grad_fn, which makes this bug really mysterious to me.
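
Since p_tot clearly has a grad_fn, one thing worth checking at this point is whether the eta and hidden held inside diffopt.param_groups are still the same Python objects as the ones created in main, e.g.:

# if these print False, backward is populating grads on copies of eta/hidden,
# not on the originals defined in main()
print(diffopt.param_groups[0]['eta'] is eta)
print(diffopt.param_groups[0]['hidden'] is hidden)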

renesax14 commented 4 years ago

OK, so the last thing that occurred to me was to print the computation graphs inside _update and right outside it (inside step). I expected the graphs to be different, but they are exactly the same, which puzzles me even more:

Inside our _update:

p_tot_inside_update

Outside our _update (inside step):

p_tot_inside_step


code for reference:

class TrainableSGD(DifferentiableOptimizer):

    def _update(self, grouped_grads, **kwargs):
        meta_opt = self.param_groups[0]['meta_opt']
        hidden = self.param_groups[0]['hidden']
        prev_lr = self.param_groups[0]['prev_lr']
        eta = self.param_groups[0]['eta']
        # start differentiable & trainable update
        zipped = zip(self.param_groups, grouped_grads)
        lr = 0.1*eta(prev_lr).view(1)
        p_tot = 0
        new_params = []
        for group_idx, (group, grads) in enumerate(zipped):
            for p_idx, (p, g) in enumerate(zip(group['params'], grads)):
                if g is None:
                    continue
                #group['params'][p_idx] = _add(p, -group['lr'], g)
                p_new = p + lr*g
                group['params'][p_idx] = p_new
                p_tot += p_new.sum()
                #make_dot(p_new.sum()).render('p_new',format='png')
                #print()
                new_params.append( p_new )
        # fake returns
        self.param_groups[0]['prev_lr'] = lr
        #p_tot.backward()
        #print(f'==> hidden.grad = {hidden.grad}')
        #print(f'==> eta.fc.weight.grad = {eta.fc.weight.grad}')
        #meta_opt.zero_grad()
        # print(f'==> hidden.grad = {hidden.grad}')
        # print(f'==> eta.fc.weight.grad = {eta.fc.weight.grad}')
        #print()
        #new_params = group['params']
        #return new_params
        make_dot(p_tot).render('p_tot_inside_update', format='png')
        return p_tot 

outside:

        #new_params = self._update(grouped_grads)
        p_tot = self._update(grouped_grads)
        make_dot(p_tot).render('p_tot_inside_step', format='png')

renesax14 commented 4 years ago

I made my own step function and commented out nearly everything, and the gradients for eta are still not populated.

My suspicion is that the error might be here:


        grouped_grads = []
        for group, mapping in zip(self.param_groups, self._group_to_param_list):
            grads = []
            for i, index in enumerate(mapping):
                #group['params'][i] = params[index].T
                group['params'][i] = params[index]
                grads.append(all_grads[index])
            grouped_grads.append(grads)

because params seems to contain nn.Parameters, and that has caused me issues in the past.


Code for my_step and my _update:

class TrainableSGD(DifferentiableOptimizer):

    def _update(self, grouped_grads, **kwargs):
        meta_opt = self.param_groups[0]['meta_opt']
        hidden = self.param_groups[0]['hidden']
        prev_lr = self.param_groups[0]['prev_lr']
        eta = self.param_groups[0]['eta']
        # start differentiable & trainable update
        zipped = zip(self.param_groups, grouped_grads)
        lr = 0.1*eta(prev_lr).view(1)
        p_tot = 0
        new_params = []
        for group_idx, (group, grads) in enumerate(zipped):
            for p_idx, (p, g) in enumerate(zip(group['params'], grads)):
                if g is None:
                    continue
                #group['params'][p_idx] = _add(p, -group['lr'], g)
                p_new = p + lr*g
                group['params'][p_idx] = p_new
                p_tot += p_new.sum()
                #make_dot(p_new.sum()).render('p_new',format='png')
                #print()
                new_params.append( p_new )
        # fake returns
        self.param_groups[0]['prev_lr'] = lr
        #p_tot.backward()
        #print(f'==> hidden.grad = {hidden.grad}')
        #print(f'==> eta.fc.weight.grad = {eta.fc.weight.grad}')
        #meta_opt.zero_grad()
        # print(f'==> hidden.grad = {hidden.grad}')
        # print(f'==> eta.fc.weight.grad = {eta.fc.weight.grad}')
        #print()
        #new_params = group['params']
        #return new_params
        #make_dot(p_tot).render('p_tot_inside_update', format='png')
        return p_tot, new_params

    def my_step(
        self,
        loss,
        params = None,
        override = None,
        grad_callback = None,
        eta=None,
        **kwargs
    ):
        print('---------> IN MY .step(loss)')
        #st()
        #eta = eta[0]
        # Deal with override
        # if override is not None:
        #     self._apply_override(override)

        # if self._fmodel is None or self._fmodel.fast_params is None:
        #     if params is None:
        #         raise ValueError(
        #             "params kwarg must be passed to step if the differentiable "
        #             "optimizer doesn't have a view on a patched model with "
        #             "params."
        #         )
        # else:
        #     params = self._fmodel.fast_params if params is None else params

        #params = self._fmodel.fast_params if params is None else params
        params = self._fmodel.fast_params

        params = list(params)

        # This allows us to gracefully deal with cases where params are frozen.
        # grad_targets = [
        #     p if p.requires_grad else torch.tensor([], requires_grad=True)
        #     for p in params
        # ]
        grad_targets = params

        # all_grads = torch.autograd.grad(
        #     loss,
        #     grad_targets,
        #     create_graph=self._track_higher_grads,
        #     allow_unused=True  # boo
        # )

        all_grads = torch.autograd.grad(
            loss,
            grad_targets
        )

        # if grad_callback is not None:
        #     all_grads = grad_callback(all_grads)
        # elif self._grad_callback is not None:
        #     all_grads = self._grad_callback(all_grads)

        grouped_grads = []
        for group, mapping in zip(self.param_groups, self._group_to_param_list):
            grads = []
            for i, index in enumerate(mapping):
                #group['params'][i] = params[index].T
                group['params'][i] = params[index]
                grads.append(all_grads[index])
            grouped_grads.append(grads)

        #new_params = self._update(grouped_grads)
        p_tot, new_params = self._update(grouped_grads)
        #make_dot(p_tot).render('p_tot_inside_step', format='png')

        # p_tot = 0
        # for p in params[:]:
        #     p_tot += p.sum()
        #p_tot.backward()

        #new_params = params[:]
        #st()
        # print(f'self._track_higher_grads = {self._track_higher_grads}')
        # for group, mapping in zip(self.param_groups, self._group_to_param_list):
        #     for p, index in zip(group['params'], mapping):
        #         if self._track_higher_grads:
        #             new_params[index] = p
        #         else:
        #             new_params[index] = p.detach().requires_grad_()

        # p_tot = 0
        # for p in new_params:
        #     p_tot += p.sum()
        #p_tot.backward()

        # if self._fmodel is not None:
        #     self._fmodel.update_params(new_params)
        #     print()
        #     st()
        #     set_attr(self._fmodel, names, val)
        #     del_attr(self._fmodel, names)

        print(f'p_tot = {p_tot}')
        print(p_tot)
        print(f'+++>>> eta.fc.grad = {eta.fc.weight.grad}')
        #return new_params

output:

---------> IN MY .step(loss)
p_tot = -1.2309236526489258
tensor(-1.2309, grad_fn=<AddBackward0>)
+++>>> eta.fc.grad = None

renesax14 commented 4 years ago

OK, some progress: I was able to get non-zero gradients inside my custom step function by indexing self.param_groups[0][trainable_opt_param] directly AND updating fmodel inside my _update function. My suspicion is that self.param_groups is being deep-copied somewhere without my knowledge.

Code:

class TrainableSGD(DifferentiableOptimizer):

    def _update(self, grouped_grads, **kwargs):
        meta_opt = self.param_groups[0]['meta_opt']
        hidden = self.param_groups[0]['hidden']
        prev_lr = self.param_groups[0]['prev_lr']
        eta = self.param_groups[0]['eta']
        # start differentiable & trainable update
        zipped = zip(self.param_groups, grouped_grads)
        lr = 0.1*eta(prev_lr).view(1)
        p_tot = 0
        new_params = []
        for group_idx, (group, grads) in enumerate(zipped):
            for p_idx, (p, g) in enumerate(zip(group['params'], grads)):
                if g is None:
                    continue
                #group['params'][p_idx] = _add(p, -group['lr'], g)
                p_new = p + lr*g
                group['params'][p_idx] = p_new
                p_tot += p_new.sum()
                #make_dot(p_new.sum()).render('p_new',format='png')
                #print()
                new_params.append( p_new )
        # fake returns
        self.param_groups[0]['prev_lr'] = lr
        #
        self._fmodel.update_params(new_params)
        #x = torch.randn([4,3,32,32])
        #y = self._fmodel(x)
        #y.sum().backward()
        #p_tot.backward()
        #print(f'==> hidden.grad = {hidden.grad}')
        #print(f'==> eta.fc.weight.grad = {eta.fc.weight.grad}')
        #meta_opt.zero_grad()
        # print(f'==> hidden.grad = {hidden.grad}')
        # print(f'==> eta.fc.weight.grad = {eta.fc.weight.grad}')
        #print()
        #new_params = group['params']
        #return new_params
        #make_dot(p_tot).render('p_tot_inside_update', format='png')
        return p_tot, new_params

    def my_step(
        self,
        loss,
        params = None,
        override = None,
        grad_callback = None,
        eta=None,
        **kwargs
    ):
        print('---------> IN MY .step(loss)')
        eta = self.param_groups[0]['eta']
        hidden = self.param_groups[0]['hidden']
        #st()
        #eta = eta[0]
        # Deal with override
        # if override is not None:
        #     self._apply_override(override)

        # if self._fmodel is None or self._fmodel.fast_params is None:
        #     if params is None:
        #         raise ValueError(
        #             "params kwarg must be passed to step if the differentiable "
        #             "optimizer doesn't have a view on a patched model with "
        #             "params."
        #         )
        # else:
        #     params = self._fmodel.fast_params if params is None else params

        #params = self._fmodel.fast_params if params is None else params
        params = self._fmodel.fast_params

        params = list(params)

        # This allows us to gracefully deal with cases where params are frozen.
        # grad_targets = [
        #     p if p.requires_grad else torch.tensor([], requires_grad=True)
        #     for p in params
        # ]
        grad_targets = params

        # all_grads = torch.autograd.grad(
        #     loss,
        #     grad_targets,
        #     create_graph=self._track_higher_grads,
        #     allow_unused=True  # boo
        # )

        all_grads = torch.autograd.grad(
            loss,
            grad_targets
        )

        # if grad_callback is not None:
        #     all_grads = grad_callback(all_grads)
        # elif self._grad_callback is not None:
        #     all_grads = self._grad_callback(all_grads)

        grouped_grads = []
        for group, mapping in zip(self.param_groups, self._group_to_param_list):
            grads = []
            for i, index in enumerate(mapping):
                #group['params'][i] = params[index].T
                group['params'][i] = params[index]
                grads.append(all_grads[index])
            grouped_grads.append(grads)

        #new_params = self._update(grouped_grads)
        p_tot, new_params = self._update(grouped_grads)
        #make_dot(p_tot).render('p_tot_inside_step', format='png')

        # p_tot = 0
        # for p in params[:]:
        #     p_tot += p.sum()
        #p_tot.backward()

        #new_params = params[:]
        #st()
        # print(f'self._track_higher_grads = {self._track_higher_grads}')
        # for group, mapping in zip(self.param_groups, self._group_to_param_list):
        #     for p, index in zip(group['params'], mapping):
        #         if self._track_higher_grads:
        #             new_params[index] = p
        #         else:
        #             new_params[index] = p.detach().requires_grad_()

        # p_tot = 0
        # for p in new_params:
        #     p_tot += p.sum()
        #p_tot.backward()

        # if self._fmodel is not None:
        #     self._fmodel.update_params(new_params)
        #     print()
        #     st()
        #     set_attr(self._fmodel, names, val)
        #     del_attr(self._fmodel, names)

        #self._fmodel.update_params(new_params)
        x = torch.randn([4,3,32,32])
        y = self._fmodel(x)
        y.sum().backward()

        #print(f'p_tot = {p_tot}')
        #print(p_tot)
        print(f'+++>>> hidden.grad = {hidden.grad}')
        print(f'+++>>> eta.fc.grad = {eta.fc.weight.grad}')
        #return new_params
        st()
        return

output:

----- DEBUGGING print statements after this line ----
---------> IN MY .step(loss)
+++>>> hidden.grad = tensor([[0.0102]])
+++>>> eta.fc.grad = tensor([[-0.0225]])

renesax14 commented 4 years ago

OK, so it seems that I can see the gradients only if I index my params from .param_groups, but my original models have somehow been detached or deep-copied or something... dissociated from the original definition...

    nb_outer_steps = 1 # note: in this case it's the same as the number of meta-train steps (but it need not be, depending on how you loop through the val set)
    for outer_i, (outer_inputs, outer_targets) in enumerate(testloader, 0):
        meta_opt.zero_grad()
        if outer_i >= nb_outer_steps:
            break
        # do inner-training/MAML; minimize innerloop: theta^{T} - eta * Grad L^train(theta^{T}) ~ argmin L^train(theta)
        nb_inner_steps = 5
        with higher.innerloop_ctx(child_model, inner_opt, copy_initial_weights=False) as (fmodel, diffopt):
        #with higher.innerloop_ctx(child_model, inner_opt) as (fmodel, diffopt):
            for inner_i, (inner_inputs, inner_targets) in enumerate(trainloader, 0):
                meta_opt.zero_grad()
                if inner_i >= nb_inner_steps:
                    break
                logits = fmodel(inner_inputs)
                print(type(fmodel))
                print(fmodel)
                inner_loss = criterion(logits, inner_targets)
                print(f'--> inner_i = {inner_i}')
                print(f'inner_loss^<{inner_i}>: {inner_loss}')
                print(f'lr^<{inner_i-1}> = {diffopt.param_groups[0]["prev_lr"]}')
                p_tot = sum([ p.sum() for p in fmodel.parameters() ])
                #print(f'> p_tot = {p_tot}')

                print('\n----- DEBUGGING print statements after this line ----')
                #diffopt.step(inner_loss,eta=eta)
                #diffopt.my_step(inner_loss, eta=eta)
                fmodel = diffopt.my_step(inner_loss)
                #step(diffopt, inner_loss,eta=eta)
                x = torch.randn([4,3,32,32])
                y = fmodel(x)
                #y = diffopt._fmodel(x)
                y.sum().backward()
                #sys.exit()
                #new_params = diffopt.step(inner_loss, eta=eta) # changes params P[t+1] using P[t] and loss[t] in a differentiable manner
                print(f'lr^<{inner_i}> = {diffopt.param_groups[0]["prev_lr"]}')
                #p = list(fmodel.parameters())[1]
                #params = list(child_model.parameters())+list(eta.parameters())+list([hidden])
                #c_params = list(child_model.parameters())
                #params = {'child_model[0]':c_params[0], 'child_model[1]':c_params[1], 'eta_params':eta.fc.weight, 'hidden':hidden }
                # make_dot(p.sum(),params=params).render(
                #     filename='param_sum_inner_loop',
                #     format='png'
                # )
                # make_dot(p.sum()).render(
                #     filename='param_sum_inner_loop_no_names_5',
                #     format='png'
                # )
                ##
                #p_tot_new = sum([ p.sum() for p in new_params ])
                #p_tot = sum([ p.sum() for p in fmodel.parameters() ])
                #print(f'> p_tot = {p_tot}')
                #print(f'same?: {p_tot_new == p_tot}')
                #p_tot.backward()
                #p_tot.backward()
                #p.sum().backward()
                eta = diffopt.param_groups[0]['eta']
                hidden = diffopt.param_groups[0]['hidden']
                print(f'===> hidden.grad = {hidden.grad}')
                print(f'===> eta.fc.weight.grad = {eta.fc.weight.grad}')
                #print()
                st()

works:

----- DEBUGGING print statements after this line ----
---------> IN MY .step(loss)
lr^<0> = tensor([0.0389], grad_fn=<MulBackward0>)
===> hidden.grad = tensor([[0.0011]])
===> eta.fc.weight.grad = tensor([[-0.0016]])
renesax14 commented 4 years ago

Ok, so the issue is definitely some sort of deep copy. Code inside my inner loop (inside the context manager):

                eta2 = diffopt.param_groups[0]['eta']
                hidden2 = diffopt.param_groups[0]['hidden']
                h = hidden
                print(f'hidden is hidden2 = {hidden is hidden2}')
                print(f'hidden is h = {hidden is h}')

output:

----- DEBUGGING print statements after this line ----
---------> IN MY .step(loss)
lr^<0> = tensor([0.0450], grad_fn=<MulBackward0>)
hidden is hidden2 = False
hidden is h = True
===> hidden.grad = None
===> eta.fc.weight.grad = None

My current fix is to only index the optimizer's parameters from diffopt.param_groups AND to update the fmodel inside your own custom _update function. Outside of that, it no longer points to the original parameters...

renesax14 commented 4 years ago

Ok indeed that does fix it. So my current solution is:

1) update the parameters of the trainable optimizer inside your own _update function, and
2) inside the context manager and inner loop, re-assign your optimizer variables each time so that they don't get lost:

                eta = diffopt.param_groups[0]['eta']
                hidden = diffopt.param_groups[0]['hidden']

@egrefen when you have time it would be nice if you could take a look at this, because I am afraid there might be some subtle thing I have missed. But at the very least the gradients of my learning rate are now being populated.

renesax14 commented 4 years ago

Current code that seems to work (unsure if there is some subtle bug I might not know about):

import torch
import torch.nn as nn
from torch.optim.optimizer import Optimizer

import higher
from higher.optim import DifferentiableOptimizer
from higher.optim import DifferentiableSGD

import torchvision
import torchvision.transforms as transforms

from torchviz import make_dot

import copy

import itertools

import sys

from collections import OrderedDict

from pdb import set_trace as st

#mini class to add a flatten layer to the ordered dictionary
class Flatten(nn.Module):
    def forward(self, input):
        '''
        Note that input.size(0) is usually the batch size.
        So what it does is that given any input with input.size(0) # of batches,
        will flatten to be 1 * nb_elements.
        '''
        batch_size = input.size(0)
        out = input.view(batch_size,-1)
        return out # (batch_size, *size)

def get_cifar10():
    transform = transforms.Compose(
        [transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                            shuffle=True, num_workers=2)

    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                        download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                            shuffle=False, num_workers=2)
    return trainloader, testloader

class MySGD(Optimizer):

    def __init__(self, params, eta, prev_lr, hidden, meta_opt):
        defaults = {'eta':eta, 'prev_lr':prev_lr, 'hidden':hidden, 'meta_opt':meta_opt}
        super().__init__(params, defaults)

class TrainableSGD(DifferentiableOptimizer):

    def _update(self, grouped_grads, **kwargs):
        meta_opt = self.param_groups[0]['meta_opt']
        hidden = self.param_groups[0]['hidden']
        prev_lr = self.param_groups[0]['prev_lr']
        eta = self.param_groups[0]['eta']
        # start differentiable & trainable update
        zipped = zip(self.param_groups, grouped_grads)
        lr = 0.1*eta(prev_lr).view(1)
        p_tot = 0
        new_params = []
        for group_idx, (group, grads) in enumerate(zipped):
            for p_idx, (p, g) in enumerate(zip(group['params'], grads)):
                if g is None:
                    continue
                #group['params'][p_idx] = _add(p, -group['lr'], g)
                p_new = p + lr*g
                group['params'][p_idx] = p_new
                p_tot += p_new.sum()
                #make_dot(p_new.sum()).render('p_new',format='png')
                #print()
                new_params.append( p_new )
        # fake returns
        self.param_groups[0]['prev_lr'] = lr
        #
        self._fmodel.update_params(new_params)
        #x = torch.randn([4,3,32,32])
        #y = self._fmodel(x)
        #y.sum().backward()
        #p_tot.backward()
        #print(f'==> hidden.grad = {hidden.grad}')
        #print(f'==> eta.fc.weight.grad = {eta.fc.weight.grad}')
        #meta_opt.zero_grad()
        # print(f'==> hidden.grad = {hidden.grad}')
        # print(f'==> eta.fc.weight.grad = {eta.fc.weight.grad}')
        #print()
        #new_params = group['params']
        #return new_params
        #make_dot(p_tot).render('p_tot_inside_update', format='png')
        return p_tot, new_params

    def my_step(
        self,
        loss,
        params = None,
        override = None,
        grad_callback = None,
        eta=None,
        **kwargs
    ):
        print('---------> IN MY .step(loss)')
        eta = self.param_groups[0]['eta']
        hidden = self.param_groups[0]['hidden']
        #st()
        #eta = eta[0]
        # Deal with override
        # if override is not None:
        #     self._apply_override(override)

        # if self._fmodel is None or self._fmodel.fast_params is None:
        #     if params is None:
        #         raise ValueError(
        #             "params kwarg must be passed to step if the differentiable "
        #             "optimizer doesn't have a view on a patched model with "
        #             "params."
        #         )
        # else:
        #     params = self._fmodel.fast_params if params is None else params

        #params = self._fmodel.fast_params if params is None else params
        params = self._fmodel.fast_params

        params = list(params)

        # This allows us to gracefully deal with cases where params are frozen.
        # grad_targets = [
        #     p if p.requires_grad else torch.tensor([], requires_grad=True)
        #     for p in params
        # ]
        grad_targets = params

        # all_grads = torch.autograd.grad(
        #     loss,
        #     grad_targets,
        #     create_graph=self._track_higher_grads,
        #     allow_unused=True  # boo
        # )

        all_grads = torch.autograd.grad(
            loss,
            grad_targets
        )

        # if grad_callback is not None:
        #     all_grads = grad_callback(all_grads)
        # elif self._grad_callback is not None:
        #     all_grads = self._grad_callback(all_grads)

        grouped_grads = []
        for group, mapping in zip(self.param_groups, self._group_to_param_list):
            grads = []
            for i, index in enumerate(mapping):
                #group['params'][i] = params[index].T
                group['params'][i] = params[index]
                grads.append(all_grads[index])
            grouped_grads.append(grads)

        #new_params = self._update(grouped_grads)
        p_tot, new_params = self._update(grouped_grads)
        #make_dot(p_tot).render('p_tot_inside_step', format='png')

        # p_tot = 0
        # for p in params[:]:
        #     p_tot += p.sum()
        #p_tot.backward()

        #new_params = params[:]
        #st()
        # print(f'self._track_higher_grads = {self._track_higher_grads}')
        # for group, mapping in zip(self.param_groups, self._group_to_param_list):
        #     for p, index in zip(group['params'], mapping):
        #         if self._track_higher_grads:
        #             new_params[index] = p
        #         else:
        #             new_params[index] = p.detach().requires_grad_()

        # p_tot = 0
        # for p in new_params:
        #     p_tot += p.sum()
        #p_tot.backward()

        # if self._fmodel is not None:
        #     self._fmodel.update_params(new_params)
        #     print()
        #     st()
        #     set_attr(self._fmodel, names, val)
        #     del_attr(self._fmodel, names)

        #self._fmodel.update_params(new_params)
        # x = torch.randn([4,3,32,32])
        # y = self._fmodel(x)
        # y.sum().backward()

        #print(f'p_tot = {p_tot}')
        #print(p_tot)
        # print(f'+++>>> hidden.grad = {hidden.grad}')
        # print(f'+++>>> eta.fc.grad = {eta.fc.weight.grad}')
        #return new_params
        #st()
        return self._fmodel

higher.register_optim(MySGD, TrainableSGD)

def main():
    # get dataloaders
    trainloader, testloader = get_cifar10()
    criterion = nn.CrossEntropyLoss()

    child_model = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(in_channels=3,out_channels=2,kernel_size=5,bias=False)),
            ('relu1', nn.ReLU()),
            ('Flatten', Flatten()),
            ('fc', nn.Linear(in_features=28*28*2,out_features=10,bias=False) )
        ]))
    #child_model.conv1.weight.requires_grad = False
    #child_model.fc.weight.requires_grad = False

    hidden = torch.randn(size=(1,1),requires_grad=True)
    print(f'-> hidden = {hidden}')
    eta = nn.Sequential(OrderedDict([
        ('fc', nn.Linear(1,1,bias=False)),
        ('sigmoid', nn.Sigmoid())
    ]))
    #meta_params = itertools.chain(child_model.parameters(),eta.parameters(),[hidden])
    meta_params = itertools.chain(eta.parameters(),[hidden])
    meta_opt = torch.optim.Adam(meta_params, lr=1e-3)
    inner_opt = MySGD(child_model.parameters(), eta=eta, prev_lr=hidden, hidden=hidden, meta_opt=meta_opt)
    # do meta-training/outer training minimize outerloop: min_{theta} sum_t L^val( theta^{T} - eta* Grad L^train(theta^{T}) ) 
    print()
    nb_outer_steps = 1 # note: in this case it's the same as the number of meta-train steps (but it need not be, depending on how you loop through the val set)
    for outer_i, (outer_inputs, outer_targets) in enumerate(testloader, 0):
        meta_opt.zero_grad()
        if outer_i >= nb_outer_steps:
            break
        # do inner-training/MAML; minimize innerloop: theta^{T} - eta * Grad L^train(theta^{T}) ~ argmin L^train(theta)
        nb_inner_steps = 5
        with higher.innerloop_ctx(child_model, inner_opt, copy_initial_weights=False) as (fmodel, diffopt):
        #with higher.innerloop_ctx(child_model, inner_opt) as (fmodel, diffopt):
            for inner_i, (inner_inputs, inner_targets) in enumerate(trainloader, 0):
                meta_opt.zero_grad()
                if inner_i >= nb_inner_steps:
                    break
                logits = fmodel(inner_inputs)
                print(type(fmodel))
                print(fmodel)
                inner_loss = criterion(logits, inner_targets)
                print(f'--> inner_i = {inner_i}')
                print(f'inner_loss^<{inner_i}>: {inner_loss}')
                print(f'lr^<{inner_i-1}> = {diffopt.param_groups[0]["prev_lr"]}')
                p_tot = sum([ p.sum() for p in fmodel.parameters() ])
                #print(f'> p_tot = {p_tot}')

                print('\n----- DEBUGGING print statements after this line ----')
                #diffopt.step(inner_loss,eta=eta)
                #diffopt.my_step(inner_loss, eta=eta)
                #fmodel = diffopt.my_step(inner_loss)
                diffopt.my_step(inner_loss)
                #step(diffopt, inner_loss,eta=eta)
                #x = torch.randn([4,3,32,32])
                #y = fmodel(x)
                #y = diffopt._fmodel(x)
                #y.sum().backward()
                #sys.exit()
                #new_params = diffopt.step(inner_loss, eta=eta) # changes params P[t+1] using P[t] and loss[t] in a differentiable manner
                print(f'lr^<{inner_i}> = {diffopt.param_groups[0]["prev_lr"]}')
                #p = list(fmodel.parameters())[1]
                #params = list(child_model.parameters())+list(eta.parameters())+list([hidden])
                #c_params = list(child_model.parameters())
                #params = {'child_model[0]':c_params[0], 'child_model[1]':c_params[1], 'eta_params':eta.fc.weight, 'hidden':hidden }
                # make_dot(p.sum(),params=params).render(
                #     filename='param_sum_inner_loop',
                #     format='png'
                # )
                # make_dot(p.sum()).render(
                #     filename='param_sum_inner_loop_no_names_5',
                #     format='png'
                # )
                ##
                #p_tot_new = sum([ p.sum() for p in new_params ])
                #p_tot = sum([ p.sum() for p in fmodel.parameters() ])
                #print(f'> p_tot = {p_tot}')
                #print(f'same?: {p_tot_new == p_tot}')
                #p_tot.backward()
                #p_tot.backward()
                #p.sum().backward()
                eta = diffopt.param_groups[0]['eta']
                hidden = diffopt.param_groups[0]['hidden']
                #h = hidden
                #print(f'hidden is hidden2 = {hidden is hidden2}')
                #print(f'hidden is h = {hidden is h}')
                print(f'===> hidden.grad = {hidden.grad}')
                print(f'===> eta.fc.weight.grad = {eta.fc.weight.grad}')
                #print()
                #st()
            # compute the meta-loss L^val( theta^{T} - eta* Grad L^train(theta^{T}) )
            print()
            #p_tot = sum([ p.sum() for p in fmodel.parameters() ])
            #p_tot.backward()
            #print(f'eta.fc.weight.grad = {eta.fc.weight.grad}')
            outer_outputs = fmodel(outer_inputs)
            meta_loss = criterion(outer_outputs, outer_targets) # L^val
            #meta_loss = meta_loss + inner_loss
            #make_dot(meta_loss).render('meta_loss',format='png')
            meta_loss.backward()
            #grad_of_grads = torch.autograd.grad(outputs=meta_loss, inputs=eta.parameters()) # dmeta_loss/dw0
            print(f'----> outer_i = {outer_i}')
            print(f'-> outer_loss/meta_loss^<{outer_i}>: {meta_loss}')
            #print(f'child_model.fc.weight.grad = {child_model.fc.weight.grad}')
            print(f'hidden.grad = {hidden.grad}')
            print(f'eta.fc.weight = {eta.fc.weight.grad}')
            meta_opt.step() # meta-optimizer step: more or less theta^<t> := theta^<t> - meta_eta * Grad L^val( theta^{T} - eta* Grad L^train(theta^{T}) )

if __name__ == "__main__":
    main()
    print('---> Done\a')
renesax14 commented 4 years ago

Commenting out this line https://github.com/facebookresearch/higher/blob/8f0716fb1663218324c02dabdba26b639959cfb6/higher/optim.py#L101

removes the bug where my meta-parameters (the learning rate) were not being updated, but it introduces new issues... why does it have to be a deep copy?

renesax14 commented 4 years ago

Notice that the fact that things are deep copied means that if those parameters are given to the meta-optimizer, the meta-optimizer will not update the right version of the optimizer networks (in fact they won't even have gradients, since they aren't pointing to the same objects...).

Trying to find a workaround; it seems that re-assigning the param_groups list does the trick... though a less messy solution would be nice.
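A tiny sketch of what I mean, with plain torch and copy (nothing higher-specific; hidden stands in for my meta-parameter):

import copy
import torch

hidden = torch.randn(1, 1, requires_grad=True)
meta_opt = torch.optim.SGD([hidden], lr=0.1)  # the meta-optimizer holds the ORIGINAL tensor

# roughly what the deep copy inside the differentiable optimizer does:
copied_groups = copy.deepcopy([{'params': [hidden]}])
hidden_copy = copied_groups[0]['params'][0]

print(hidden_copy is hidden)  # False: a different tensor object
hidden_copy.sum().backward()
print(hidden_copy.grad)       # populated on the copy...
print(hidden.grad)            # ...but None on the tensor the meta-optimizer will step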

egrefen commented 4 years ago

Thanks for the detailed comments. You've given me a lot to look at so it's probably very helpful. I would encourage you to put stack traces in comments when you see crashes too as this helps debug things.

Just to manage expectations: I'm currently the only person supporting this project internally and have other responsibilities so cannot guarantee I can resolve this immediately (as there's a more pressing issue with second order backprop being broken for higher with pytorch v1.4 and also some memory leak issues). But I'll get to it as soon as I can!

renesax14 commented 4 years ago

Thanks for the detailed comments. You've given me a lot to look at so it's probably very helpful. I would encourage you to put stack traces in comments when you see crashes too as this helps debug things.

Just to manage expectations: I'm currently the only person supporting this project internally and have other responsibilities so cannot guarantee I can resolve this immediately (as there's a more pressing issue with second order backprop being broken for higher with pytorch v1.4 and also some memory leak issues). But I'll get to it as soon as I can!

No worries, I understand. I think I got a temporary fix that will work for what I want to do.

I am happy to answer questions if you need help.

I will try to summarize the issue so that it's easier for you to go through once you decide to fix it (and I'll post my temporary solution, since it shows how I avoided the issue).

egrefen commented 4 years ago

That would be super helpful. Thanks for understanding, and for all the detailed comments.

renesax14 commented 4 years ago

Part of my current solution is to define this function (to reload the meta-optimizer so that it points at the new copies of the parameters that higher deep copied):

    def load_new_params(self, params):
        self.param_groups = []

        param_groups = list(params)
        if len(param_groups) == 0:
            raise ValueError("optimizer got an empty parameter list")
        if not isinstance(param_groups[0], dict):
            param_groups = [{'params': param_groups}]

        for param_group in param_groups:
            self.add_param_group(param_group)

but it's a copy-paste of the original __init__ method... which also has something called self.state, and I am unsure what it's for or whether it breaks my optimizer in some unknown way...
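As far as I can tell, self.state is a defaultdict keyed by the parameter tensors themselves and holds the per-parameter state (e.g. Adam's moments), so swapping in new tensor objects orphans that state. A small check with stock PyTorch:

import torch

w = torch.nn.Parameter(torch.randn(2))
opt = torch.optim.Adam([w], lr=0.1)
w.sum().backward()
opt.step()
print(list(opt.state.keys())[0] is w)  # True: state is keyed by the parameter object itself

w2 = w.detach().clone().requires_grad_()
opt.param_groups[0]['params'] = [w2]   # swap in a new tensor object
print(w2 in opt.state)                 # False: Adam's moments for w are now orphaned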


Perhaps I should create a new param_groups list and load it with:

self.__setstate__({'state': state, 'param_groups': param_groups})


renesax14 commented 4 years ago

Ok, it seems this works for me for now (unless there is a subtle bug from higher or optim that I am not aware of...e.g. how self.state works or something else):


import torch
import torch.nn as nn
from torch.optim.optimizer import Optimizer, required

import higher
from higher.optim import DifferentiableOptimizer
from higher.optim import DifferentiableSGD

import torchvision
import torchvision.transforms as transforms

from torchviz import make_dot

import copy

import itertools

import sys

from collections import OrderedDict

from pdb import set_trace as st

#mini class to add a flatten layer to the ordered dictionary
class Flatten(nn.Module):
    def forward(self, input):
        '''
        Note that input.size(0) is usually the batch size.
        So what it does is that given any input with input.size(0) # of batches,
        will flatten to be 1 * nb_elements.
        '''
        batch_size = input.size(0)
        out = input.view(batch_size,-1)
        return out # (batch_size, *size)

def get_cifar10():
    transform = transforms.Compose(
        [transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                            shuffle=True, num_workers=2)

    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                        download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                            shuffle=False, num_workers=2)
    return trainloader, testloader

def load_new_params(optimizer, params):
    optimizer.param_groups = []

    param_groups = list(params)
    if len(param_groups) == 0:
        raise ValueError("optimizer got an empty parameter list")
    if not isinstance(param_groups[0], dict):
        param_groups = [{'params': param_groups}]
    for param_group in param_groups:
        optimizer.add_param_group(param_group)

class MySGD(Optimizer):

    def __init__(self, params, trainable_opt_params, trainable_opt_state):
        defaults = {'trainable_opt_params':trainable_opt_params, 'trainable_opt_state':trainable_opt_state}
        super().__init__(params, defaults)

class TrainableSGD(DifferentiableOptimizer):

    def _update(self, grouped_grads, **kwargs):
        prev_lr = self.param_groups[0]['trainable_opt_state']['prev_lr']
        eta = self.param_groups[0]['trainable_opt_params']['eta']
        # start differentiable & trainable update
        zipped = zip(self.param_groups, grouped_grads)
        lr = 0.01*eta(prev_lr).view(1)
        for group_idx, (group, grads) in enumerate(zipped):
            for p_idx, (p, g) in enumerate(zip(group['params'], grads)):
                if g is None:
                    continue
                p_new = p - lr*g
                group['params'][p_idx] = p_new
        # fake returns
        self.param_groups[0]['trainable_opt_state']['prev_lr'] = lr
        # update model
        new_params = self.param_groups[0]['params'] 
        new_params = self._track_higher_grads_for_new_params(new_params, self._track_higher_grads)
        self._fmodel.update_params(new_params)

    def my_step(
        self,
        loss,
        params = None,
        override = None,
        grad_callback = None,
        eta=None,
        **kwargs
    ):
        params = self._fmodel.fast_params

        params = list(params)

        grad_targets = params

        all_grads = torch.autograd.grad(
            loss,
            grad_targets,
            create_graph=self._track_higher_grads,
            allow_unused=True  # boo
        )

        grouped_grads = []
        for group, mapping in zip(self.param_groups, self._group_to_param_list):
            grads = []
            for i, index in enumerate(mapping):
                #group['params'][i] = params[index].T
                group['params'][i] = params[index]
                grads.append(all_grads[index])
            grouped_grads.append(grads)

        self._update(grouped_grads)

        # WARNING DON'T UPDATE PARAMETERS IN STEP
        return self._fmodel

    def _track_higher_grads_for_new_params(self, new_params, track_higher_grads):
        '''
        For the new params, set if we are tracking higher order grads for them or detaching them for the computation graph.
        '''
        for group, mapping in zip(self.param_groups, self._group_to_param_list):
            for p, index in zip(group['params'], mapping):
                if track_higher_grads:
                    new_params[index] = p
                else:
                    new_params[index] = p.detach().requires_grad_()
        return new_params

higher.register_optim(MySGD, TrainableSGD)

def main():
    # get dataloaders
    trainloader, testloader = get_cifar10()
    criterion = nn.CrossEntropyLoss()

    child_model = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(in_channels=3,out_channels=2,kernel_size=5,bias=False)),
            ('relu1', nn.ReLU()),
            ('Flatten', Flatten()),
            ('fc', nn.Linear(in_features=28*28*2,out_features=10,bias=False) )
        ]))

    hidden = torch.randn(size=(1,1),requires_grad=True)
    print(f'-> hidden = {hidden}')
    eta = nn.Sequential(OrderedDict([
        ('fc', nn.Linear(1,1,bias=False)),
        ('sigmoid', nn.Sigmoid())
    ]))

    lr = 0.01
    meta_params = []
    meta_params.append( {'params': hidden, 'lr':lr} )
    meta_params.append( {'params': eta.parameters(), 'lr':lr} )
    #meta_opt = torch.optim.SGD(meta_params)
    meta_opt = torch.optim.Adam(meta_params)
    # do meta-training/outer training minimize outerloop: min_{theta} sum_t L^val( theta^{T} - eta* Grad L^train(theta^{T}) ) 
    nb_outer_steps = 3 # note: in this case it's the same as the number of meta-train steps (but it need not be, depending on how you loop through the val set)
    for outer_i, (outer_inputs, outer_targets) in enumerate(testloader, 0):
        meta_opt.zero_grad()
        if outer_i >= nb_outer_steps:
            break
        # do inner-training: ~ argmin L^train(theta)
        nb_inner_steps = 3
        trainable_opt_params = {'eta':eta, 'hidden':hidden}
        trainable_opt_state = {'prev_lr':hidden}
        inner_opt = MySGD(child_model.parameters(), trainable_opt_params=trainable_opt_params, trainable_opt_state=trainable_opt_state)
        print('==== Inner Loop ====')
        with higher.innerloop_ctx(child_model, inner_opt, copy_initial_weights=False) as (fmodel, diffopt):
            for inner_i, (inner_inputs, inner_targets) in enumerate(trainloader, 0):
                meta_opt.zero_grad()
                if inner_i >= nb_inner_steps:
                    break                
                print(f'-> inner_i = {inner_i}')
                print(f'hidden^<{inner_i}> = {hidden}')
                #print(f'eta.fc.weight^<{inner_i}> = {eta.fc.weight}')
                logits = fmodel(inner_inputs)
                inner_loss = criterion(logits, inner_targets)
                print(f'lr^<{inner_i-1}> = {diffopt.param_groups[0]["trainable_opt_state"]["prev_lr"]}')
                diffopt.my_step(inner_loss)
                print(f'lr^<{inner_i}> = {diffopt.param_groups[0]["trainable_opt_state"]["prev_lr"]}')
                eta = diffopt.param_groups[0]['trainable_opt_params']['eta']
                hidden = diffopt.param_groups[0]['trainable_opt_params']['hidden']
                print(f'hidden^<{inner_i}> = {hidden}')
            # compute the meta-loss L^val( theta^{T} - eta* Grad L^train(theta^{T}) )
            outer_outputs = fmodel(outer_inputs)
            meta_loss = criterion(outer_outputs, outer_targets) # L^val
            meta_loss.backward()
            #grad_of_grads = torch.autograd.grad(outputs=meta_loss, inputs=eta.parameters()) # dmeta_loss/dw0
            print('\n---- Outer loop print statements ----')
            print(f'----> outer_i = {outer_i}')
            print(f'-> outer_loss/meta_loss^<{outer_i}>: {meta_loss}')
            #print(f'child_model.fc.weight.grad = {child_model.fc.weight.grad}')
            print(f'hidden.grad = {hidden.grad}')
            assert hidden.grad is not None 
            print(f'eta.fc.weight.grad = {eta.fc.weight.grad}')
            print(f'> hidden^<{outer_i-1}> = {hidden}') # before update
            print(f'> eta.fc.weight^<{outer_i-1}> = {eta.fc.weight.T}')
            # param_groups = meta_opt.param_groups
            # print(f'hidden == param_groups[0][params][0] : {hidden == param_groups[0]["params"][0]}')
            # print(f'hidden is param_groups[0][params][0] : {hidden is param_groups[0]["params"][0]}')
            # print(f'mutate meta_opt')
            #param_groups[0]['params'] = [hidden]
            new_meta_params = []
            new_meta_params.append( {'params': hidden, 'lr':lr} )
            new_meta_params.append( {'params': eta.parameters(), 'lr':lr} )
            # print(f'new_meta_params = {new_meta_params}')
            # print(f'param_groups = {param_groups}')
            load_new_params(meta_opt, new_meta_params)
            # print(f'meta_opt.param_groups = {meta_opt.param_groups}')
            # print(meta_opt.param_groups is param_groups)
            # print(param_groups[0]['params'][0] is meta_opt.param_groups[0]['params'][0])
            # print(hidden is param_groups[0]['params'][0])
            # print(hidden is meta_opt.param_groups[0]['params'][0])
            # param_groups[1]['params'] = eta.fc.weight
            # print(f'hidden == param_groups[0][params][0] : {hidden == param_groups[0]["params"][0]}')
            # print(f'hidden is param_groups[0][params][0] : {hidden is param_groups[0]["params"][0]}')
            meta_opt.step() # meta-optimizer step: more or less theta^<t> := theta^<t> - meta_eta * Grad L^val( theta^{T} - eta* Grad L^train(theta^{T}) )
            print(f'> hidden^<{outer_i}> = {meta_opt.param_groups[0]["params"][0]}') # after update
            print(f'> eta.fc.weight^<{outer_i-1}> = {eta.fc.weight.T}')
            # print(f'hidden == param_groups[0][params][0] : {hidden == param_groups[0]["params"][0]}')
            # print(f'hidden is param_groups[0][params] : {hidden is param_groups[0]["params"][0]}')
            #print(f'> eta.fc.weight^<{outer_i}> = {meta_opt.param_groups[1]["params"].T}')
            print()

if __name__ == "__main__":
    main()
    print('---> Done\a')
renesax14 commented 4 years ago

Ok, nearly got it to work. I removed the deep copy and the context manager, but now it's complaining that I'm trying to call backward twice on the same computation graph...

import torch
import torch.nn as nn
from torch.optim.optimizer import Optimizer, required

import higher
from higher.optim import DifferentiableOptimizer
from higher.optim import DifferentiableSGD

import torchvision
import torchvision.transforms as transforms

from torchviz import make_dot

import copy

import itertools

import sys

from collections import OrderedDict

from pdb import set_trace as st

#mini class to add a flatten layer to the ordered dictionary
class Flatten(nn.Module):
    def forward(self, input):
        '''
        Note that input.size(0) is usually the batch size.
        So what it does is that given any input with input.size(0) # of batches,
        will flatten to be 1 * nb_elements.
        '''
        batch_size = input.size(0)
        out = input.view(batch_size,-1)
        return out # (batch_size, *size)

def get_cifar10():
    transform = transforms.Compose(
        [transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                            shuffle=True, num_workers=2)

    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                        download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                            shuffle=False, num_workers=2)
    return trainloader, testloader

def load_new_params(optimizer, params):
    optimizer.param_groups = []

    param_groups = list(params)
    if len(param_groups) == 0:
        raise ValueError("optimizer got an empty parameter list")
    if not isinstance(param_groups[0], dict):
        param_groups = [{'params': param_groups}]
    for param_group in param_groups:
        optimizer.add_param_group(param_group)

def reload_param_groups(opt, params):
    if isinstance(params, torch.Tensor):
        raise TypeError("params argument given to the optimizer should be "
                        "an iterable of Tensors or dicts, but got " +
                        torch.typename(params))
    # replace params
    params = list(params)
    if isinstance(params[0], dict):
        raise ValueError(f'The hacked higher version does not support proper pytorch grouped params yet.')
    opt.param_groups[0]['params'] = params
    # opt.param_groups = []

    # param_groups = list(params)
    # if len(param_groups) == 0:
    #     raise ValueError("optimizer got an empty parameter list")
    # if not isinstance(param_groups[0], dict):
    #     param_groups = [{'params': param_groups}]

    # for param_group in param_groups:
    #     opt.add_param_group(param_group)

class MySGD(Optimizer):

    def __init__(self, params, trainable_opt_params, trainable_opt_state):
        defaults = {'trainable_opt_params':trainable_opt_params, 'trainable_opt_state':trainable_opt_state}
        super().__init__(params, defaults)

class TrainableSGD(DifferentiableOptimizer):

    def _update(self, grouped_grads, **kwargs):
        prev_lr = self.param_groups[0]['trainable_opt_state']['prev_lr']
        eta = self.param_groups[0]['trainable_opt_params']['eta']
        # start differentiable & trainable update
        zipped = zip(self.param_groups, grouped_grads)
        lr = 0.01*eta(prev_lr).view(1)
        for group_idx, (group, grads) in enumerate(zipped):
            for p_idx, (p, g) in enumerate(zip(group['params'], grads)):
                if g is None:
                    continue
                p_new = p - lr*g
                group['params'][p_idx] = p_new
        # fake returns
        self.param_groups[0]['trainable_opt_state']['prev_lr'] = lr
        # update model
        # new_params = self.param_groups[0]['params'] 
        # new_params = self._track_higher_grads_for_new_params(new_params, self._track_higher_grads)
        # self._fmodel.update_params(new_params)

    # def my_step2(
    #     self,
    #     loss,
    #     params = None,
    #     override = None,
    #     grad_callback = None,
    #     eta=None,
    #     **kwargs
    # ):
    #     # Deal with override
    #     if override is not None:
    #         self._apply_override(override)

    #     if self._fmodel is None or self._fmodel.fast_params is None:
    #         if params is None:
    #             raise ValueError(
    #                 "params kwarg must be passed to step if the differentiable "
    #                 "optimizer doesn't have a view on a patched model with "
    #                 "params."
    #             )
    #     else:
    #         params = self._fmodel.fast_params if params is None else params

    #     params = list(params)

    #     # This allows us to gracefully deal with cases where params are frozen.
    #     grad_targets = [
    #         p if p.requires_grad else torch.tensor([], requires_grad=True)
    #         for p in params
    #     ]

    #     all_grads = torch.autograd.grad(
    #         loss,
    #         grad_targets,
    #         create_graph=self._track_higher_grads,
    #         allow_unused=True  # boo
    #     )

    #     if grad_callback is not None:
    #         all_grads = grad_callback(all_grads)
    #     elif self._grad_callback is not None:
    #         all_grads = self._grad_callback(all_grads)

    #     grouped_grads = []
    #     for group, mapping in zip(self.param_groups, self._group_to_param_list):
    #         grads = []
    #         for i, index in enumerate(mapping):
    #             group['params'][i] = params[index]
    #             grads.append(all_grads[index])
    #         grouped_grads.append(grads)

    #     self._update(grouped_grads)

    #     # ---> WARNING DON'T UPDATE PARAMETERS IN STEP <---
    #     # the code bellow is now done inside of your _update function
    #     # new_params = params[:]
    #     # for group, mapping in zip(self.param_groups, self._group_to_param_list):
    #     #     for p, index in zip(group['params'], mapping):
    #     #         if self._track_higher_grads:
    #     #             new_params[index] = p
    #     #         else:
    #     #             new_params[index] = p.detach().requires_grad_()

    #     # if self._fmodel is not None:
    #     #     self._fmodel.update_params(new_params)
    #     return self._fmodel

    # def _track_higher_grads_for_new_params(self, new_params, track_higher_grads):
    #     '''
    #     For the new params, set if we are tracking higher order grads for them or detaching them for the computation graph.
    #     '''
    #     for group, mapping in zip(self.param_groups, self._group_to_param_list):
    #         for p, index in zip(group['params'], mapping):
    #             if track_higher_grads:
    #                 new_params[index] = p
    #             else:
    #                 new_params[index] = p.detach().requires_grad_()
    #     return new_params

higher.register_optim(MySGD, TrainableSGD)

def main():
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")    
    # get dataloaders
    trainloader, testloader = get_cifar10()
    criterion = nn.CrossEntropyLoss()
    # get trainable opt params
    hidden = torch.randn(size=(1,1),requires_grad=True)
    print(f'-> hidden = {hidden}')
    eta = nn.Sequential(OrderedDict([
        ('fc', nn.Linear(1,1,bias=False)),
        ('sigmoid', nn.Sigmoid())
    ]))
    lr = 0.01
    meta_params = []
    meta_params.append( {'params': hidden, 'lr':lr} )
    meta_params.append( {'params': eta.parameters(), 'lr':lr} )
    # get meta optimizer
    #meta_opt = torch.optim.SGD(meta_params)
    meta_opt = torch.optim.Adam(meta_params)
    #
    trainable_opt_params = {'eta':eta, 'hidden':hidden}
    trainable_opt_state = {'prev_lr':hidden}
    #inner_opt = MySGD(eta.parameters(), trainable_opt_params=trainable_opt_params, trainable_opt_state=trainable_opt_state)
    # diffopt = higher.optim.get_diff_optim(
    #     inner_opt,
    #     eta.parameters(), # for this hack it can be anything
    #     fmodel=None, # None
    #     device=device,
    #     override=None, # None default
    #     track_higher_grads=True # True default
    # )
    # do meta-training/ outerloop argmin L^val(theta)
    nb_outer_steps = 2 # note: in this case it's the same as the number of meta-train steps (but it need not be, depending on how you loop through the val set)
    for outer_i, (outer_inputs, outer_targets) in enumerate(testloader, 0):
        meta_opt.zero_grad()
        if outer_i >= nb_outer_steps:
            break
        # sample child_model
        child_model = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(in_channels=3,out_channels=2,kernel_size=5,bias=False)),
            ('relu1', nn.ReLU()),
            ('Flatten', Flatten()),
            ('fc', nn.Linear(in_features=28*28*2,out_features=10,bias=False) )
        ]))
        # do inner-training: ~ argmin L^train(psi)
        nb_inner_steps = 3   
        print('==== Inner Loop ====')
        fmodel = higher.patch.monkeypatch(
            child_model, 
            device, 
            copy_initial_weights=True # True default
        )
        inner_opt = MySGD(child_model.parameters(), trainable_opt_params=trainable_opt_params, trainable_opt_state=trainable_opt_state)
        diffopt = higher.optim.get_diff_optim(
            inner_opt,
            child_model.parameters(), # for this hack it can be anything
            fmodel=fmodel, # None
            device=device,
            override=None, # None default
            track_higher_grads=True # True default
        )
        for inner_i, (inner_inputs, inner_targets) in enumerate(trainloader, 0):
            if inner_i >= nb_inner_steps:
                break
            print(f'-> outer_i = {outer_i}')                
            print(f'-> inner_i = {inner_i}')
            print(f'hidden^<{inner_i}> = {hidden}')
            logits = fmodel(inner_inputs)
            inner_loss = criterion(logits, inner_targets)
            print(f'lr^<{inner_i-1}> = {diffopt.param_groups[0]["trainable_opt_state"]["prev_lr"]}')
            #child_model_params = [{'params':child_model.parameters()}]
            child_model_params = child_model.parameters()
            reload_param_groups(diffopt, child_model_params)
            diffopt._fmodel = fmodel
            diffopt.step(inner_loss)
            print(f'lr^<{inner_i}> = {diffopt.param_groups[0]["trainable_opt_state"]["prev_lr"]}')
            print(f'hidden^<{inner_i}> = {hidden}')
        # compute the meta-loss L^val( theta^{T} - eta* Grad L^train(theta^{T}) )
        outer_outputs = fmodel(outer_inputs)
        meta_loss = criterion(outer_outputs, outer_targets) # L^val
        #grad_of_grads = torch.autograd.grad(outputs=meta_loss, inputs=eta.parameters()) # dmeta_loss/dw0
        print('\n---- Outer loop print statements ----')
        print(f'----> outer_i = {outer_i}')
        print(f'-> outer_loss/meta_loss^<{outer_i}>: {meta_loss}')
        #print(f'child_model.fc.weight.grad = {child_model.fc.weight.grad}')
        meta_loss.backward()
        print(f'hidden.grad = {hidden.grad}')
        assert hidden.grad is not None 
        print(f'eta.fc.weight.grad = {eta.fc.weight.grad}')
        print(f'> hidden^<{outer_i-1}> = {hidden}') # before update
        print(f'> eta.fc.weight^<{outer_i-1}> = {eta.fc.weight.T}')
        meta_opt.step() # meta-optimizer step: more or less theta^<t> := theta^<t> - meta_eta * Grad L^val( theta^{T} - eta* Grad L^train(theta^{T}) )
        print(f'>> hidden^<{outer_i}> = {meta_opt.param_groups[0]["params"][0]}') # after update
        print(f'>> eta.fc.weight^<{outer_i-1}> = {eta.fc.weight.T}')
        print()

if __name__ == "__main__":
    main()
    print('---> Done\a')

error:

Traceback (most recent call last):
  File "trainaible_step_no_deep_copy.py", line 305, in <module>
    main()
  File "trainaible_step_no_deep_copy.py", line 293, in main
    meta_loss.backward()
  File "/Users/rene/miniconda3/envs/automl-meta-learning/lib/python3.7/site-packages/torch/tensor.py", line 195, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/Users/rene/miniconda3/envs/automl-meta-learning/lib/python3.7/site-packages/torch/autograd/__init__.py", line 99, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
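For context, the error itself just means part of the graph is being backpropagated through a second time without retain_graph=True, e.g. a subgraph shared across outer iterations. A toy reproduction with nothing from higher:

import torch

theta = torch.randn(3, requires_grad=True)
inner = theta.exp().sum()       # stands in for the unrolled inner loop
for outer_i in range(2):
    meta_loss = inner * 1.0     # every outer step reuses the same inner graph
    meta_loss.backward()        # second iteration raises the RuntimeError above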
renesax14 commented 4 years ago

I tried deleting the output node meta_loss with:

del meta_loss

because I was told that this frees the computation graph, but it did not.

renesax14 commented 4 years ago

So the issue is that this line of code in higher

self.param_groups = _copy.deepcopy(other.param_groups)

breaks the trainable step size I am trying to build.

I tried commenting it out before, but my code was still breaking.

With a lot of exploration, it seems that the code only works when I re-instantiate/rebuild the inner optimizer and the differentiable optimizer before every inner loop (I think...)

import torch
import torch.nn as nn
from torch.optim.optimizer import Optimizer, required

import higher
from higher.optim import DifferentiableOptimizer
from higher.optim import DifferentiableSGD

import torchvision
import torchvision.transforms as transforms

from torchviz import make_dot

import copy

import itertools

import sys

from collections import OrderedDict

from pdb import set_trace as st

#mini class to add a flatten layer to the ordered dictionary
class Flatten(nn.Module):
    def forward(self, input):
        '''
        Note that input.size(0) is usually the batch size.
        So what it does is that given any input with input.size(0) # of batches,
        will flatten to be 1 * nb_elements.
        '''
        batch_size = input.size(0)
        out = input.view(batch_size,-1)
        return out # (batch_size, *size)

def get_cifar10():
    transform = transforms.Compose(
        [transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                            shuffle=True, num_workers=2)

    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                        download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                            shuffle=False, num_workers=2)
    return trainloader, testloader

class MySGD(Optimizer):

    def __init__(self, params, trainable_opt_params, trainable_opt_state):
        defaults = {'trainable_opt_params':trainable_opt_params, 'trainable_opt_state':trainable_opt_state}
        super().__init__(params, defaults)

class TrainableSGD(DifferentiableOptimizer):

    def _update(self, grouped_grads, **kwargs):
        prev_lr = self.param_groups[0]['trainable_opt_state']['prev_lr']
        eta = self.param_groups[0]['trainable_opt_params']['eta']
        # start differentiable & trainable update
        zipped = zip(self.param_groups, grouped_grads)
        lr = 0.01*eta(prev_lr).view(1)
        for group_idx, (group, grads) in enumerate(zipped):
            for p_idx, (p, g) in enumerate(zip(group['params'], grads)):
                if g is None:
                    continue
                p_new = p - lr*g
                group['params'][p_idx] = p_new
        # fake returns
        self.param_groups[0]['trainable_opt_state']['prev_lr'] = lr

higher.register_optim(MySGD, TrainableSGD)

def main():
    # get dataloaders
    trainloader, testloader = get_cifar10()
    criterion = nn.CrossEntropyLoss()

    hidden = torch.randn(size=(1,1),requires_grad=True)
    print(f'-> hidden = {hidden}')
    eta = nn.Sequential(OrderedDict([
        ('fc', nn.Linear(1,1,bias=False)),
        ('sigmoid', nn.Sigmoid())
    ]))

    lr = 0.01
    meta_params = []
    meta_params.append( {'params': hidden, 'lr':lr} )
    meta_params.append( {'params': eta.parameters(), 'lr':lr} )
    #meta_opt = torch.optim.SGD(meta_params)
    meta_opt = torch.optim.Adam(meta_params)
    # do meta-training/outer training minimize outerloop: min_{theta} sum_t L^val( theta^{T} - eta* Grad L^train(theta^{T}) ) 
    nb_outer_steps = 5 # note: in this case it's the same as the number of meta-train steps (but it need not be, depending on how you loop through the val set)
    for outer_i, (outer_inputs, outer_targets) in enumerate(testloader, 0):
        meta_opt.zero_grad()
        if outer_i >= nb_outer_steps:
            break
        #
        child_model = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(in_channels=3,out_channels=2,kernel_size=5,bias=False)),
            ('relu1', nn.ReLU()),
            ('Flatten', Flatten()),
            ('fc', nn.Linear(in_features=28*28*2,out_features=10,bias=False) )
        ]))
        # do inner-training: ~ argmin L^train(theta)
        nb_inner_steps = 3
        trainable_opt_params = {'eta':eta, 'hidden':hidden}
        trainable_opt_state = {'prev_lr':hidden}
        child_model_params = [{'params':child_model.parameters()}]
        inner_opt = MySGD(child_model_params, trainable_opt_params=trainable_opt_params, trainable_opt_state=trainable_opt_state)
        print('==== Inner Loop ====')
        with higher.innerloop_ctx(child_model, inner_opt, copy_initial_weights=False) as (fmodel, diffopt):
            for inner_i, (inner_inputs, inner_targets) in enumerate(trainloader, 0):
                if inner_i >= nb_inner_steps:
                    break
                print(f'-> outer_i = {outer_i}')                
                print(f'-> inner_i = {inner_i}')
                print(f'hidden^<{inner_i}> = {hidden}')
                logits = fmodel(inner_inputs)
                inner_loss = criterion(logits, inner_targets)
                print(f'lr^<{inner_i-1}> = {diffopt.param_groups[0]["trainable_opt_state"]["prev_lr"]}')
                diffopt.step(inner_loss)
                print(f'lr^<{inner_i}> = {diffopt.param_groups[0]["trainable_opt_state"]["prev_lr"]}')
                print(f'hidden^<{inner_i}> = {hidden}')
            # compute the meta-loss L^val( theta^{T} - eta* Grad L^train(theta^{T}) )
            outer_outputs = fmodel(outer_inputs)
            meta_loss = criterion(outer_outputs, outer_targets) # L^val
            meta_loss.backward()
            #grad_of_grads = torch.autograd.grad(outputs=meta_loss, inputs=eta.parameters()) # dmeta_loss/dw0
            print('\n---- Outer loop print statements ----')
            print(f'----> outer_i = {outer_i}')
            print(f'-> outer_loss/meta_loss^<{outer_i}>: {meta_loss}')
            #print(f'child_model.fc.weight.grad = {child_model.fc.weight.grad}')
            print(f'hidden.grad = {hidden.grad}')
            assert hidden.grad is not None
            assert eta.fc.weight is not None
            print(f'eta.fc.weight.grad = {eta.fc.weight.grad}')
            print(f'> hidden^<{outer_i-1}> = {hidden}') # before update
            print(f'> eta.fc.weight^<{outer_i-1}> = {eta.fc.weight.T}')
            meta_opt.step() # meta-optimizer step: more or less theta^<t> := theta^<t> - meta_eta * Grad L^val( theta^{T} - eta* Grad L^train(theta^{T}) )
            print(f'>> hidden^<{outer_i}> = {meta_opt.param_groups[0]["params"][0]}') # after update
            print(f'>> eta.fc.weight^<{outer_i-1}> = {eta.fc.weight.T}')
            print()

if __name__ == "__main__":
    main()
    print('---> Done\a')
egrefen commented 4 years ago

My main feedback from looking at your code is that you are not using the differentiable optimizers as intended. The way you should be doing it is:

  1. Write a non-differentiable version of TrainableSGD (imagine it is just this class with frozen parameters) which subclasses torch.optim.Optimizer.
  2. Write a DifferentiableOptimizer version of this class.
  3. Read the docstring for the override kwarg here: https://higher.readthedocs.io/en/latest/optim.html
  4. Use the override kwarg to pass your trainable parameters when constructing the differentiable optimizer (you can do this when using innerloop_ctx).

I'm sorry if this answer seems like "you're using it wrong", but you're using it wrong. Please take a close look at @denisyarats' linked source code for a working example.

If this mode of usage does not fit your needs, that's another matter. If that's the case, please explain what you'd want to do that you can't do this way. A significantly simpler minimal example than what's been provided so far would be helpful to this end.

In the meantime, closing this issue as it's not clear there's a bug underlying it.

renesax14 commented 4 years ago

@egrefen thanks for taking a look at my discussion.

So my issue can be resolved if this line of code:

https://github.com/facebookresearch/higher/blob/c6a79805b6d51c0571c04467a53cb1dc4e3754e4/higher/optim.py#L101

is changed to:

self.param_groups = other.param_groups

Is there a reason why things have to be deep copied? Can higher function without it?

egrefen commented 4 years ago

That line of code is important, as we want to safely branch off the state of the optimizer as used in the outer loop and return to it (or not touch it in the first place, which is what we do with the copy here) at the end of the unrolled inner loop.
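To illustrate the kind of leakage the copy avoids (a toy sketch, nothing higher-specific):

import copy

outer_groups = [{'params': ['w0'], 'lr': 0.1}]
inner_groups = copy.deepcopy(outer_groups)  # branch off for the unrolled inner loop
inner_groups[0]['lr'] = 0.5                 # inner-loop bookkeeping changes...
print(outer_groups[0]['lr'])                # ...do not leak back into the user's optimizer: 0.1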

Use override, please.

egrefen commented 4 years ago

Override is a kwarg for differentiable optims (at creation, or step time, and you can also use it with the context manager) which allows you to use arbitrary tensors instead of values held in the optimizer state. For example, you could override the learning rate with a tensor which requires grad, which would allow you to unroll your loops, take gradient of the meta-loss with regard to the learning rate, and update this tensor.

See https://higher.readthedocs.io/en/latest/optim.html for details, https://github.com/facebookresearch/higher/issues/32#issuecomment-594466772 for a similar explanation, and https://github.com/denisyarats/densenet_cifar10 for an example.
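A minimal sketch of that usage with plain SGD and a scalar learning-rate tensor (illustrative only, assuming the documented form of override as a singleton list per setting):

import torch
import torch.nn as nn
import higher

model = nn.Linear(3, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
lr_tensor = torch.tensor(0.1, requires_grad=True)  # the learnable learning rate

x, y = torch.randn(8, 3), torch.randn(8, 1)
with higher.innerloop_ctx(model, opt, override={'lr': [lr_tensor]}) as (fmodel, diffopt):
    for _ in range(3):
        inner_loss = ((fmodel(x) - y) ** 2).mean()
        diffopt.step(inner_loss)              # unrolled, differentiable inner steps
    meta_loss = ((fmodel(x) - y) ** 2).mean()
    meta_loss.backward()                      # gradients flow back to lr_tensor

print(lr_tensor.grad)                         # d(meta_loss) / d(lr)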

On Fri, Mar 6, 2020 at 5:26 AM brando90 notifications@github.com wrote:

Use override, please.

My apologies if this is a dense question, what does that mean?


renesax14 commented 4 years ago

Override is a kwarg for differentiable optims (at creation, or step time, and you can also use it with the context manager) which allows you to use arbitrary tensors instead of values held in the optimizer state. For example, you could override the learning rate with a tensor which requires grad, which would allow you to unroll your loops, take gradient of the meta-loss with regard to the learning rate, and update this tensor. See https://higher.readthedocs.io/en/latest/optim.html for details, #32 (comment) for a similar explanation, and https://github.com/denisyarats/densenet_cifar10 for an example.

This is not what I want. I am not trying to train the learning rate. I'm trying to have the inner optimizer itself be parametrized, for example the way it's used in the meta-LSTM meta-learner: https://openreview.net/pdf?id=rJY0-Kcll

I will provide a minimal example that makes it easier to help me.
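For now, a toy, self-contained sketch of the kind of update I mean (not wired into higher; the gate network is just a stand-in for the meta-learner):

import torch
import torch.nn as nn

gate_net = nn.Sequential(nn.Linear(2, 2), nn.Sigmoid())  # toy meta-learner producing forget/input gates

theta = torch.randn(5, requires_grad=True)  # inner-model parameters
grad = torch.randn(5)                       # pretend gradient dL^train/dtheta
features = torch.stack([grad.mean(), theta.detach().mean()]).unsqueeze(0)
f_t, i_t = gate_net(features).squeeze(0)    # learned gates
theta_next = f_t * theta - i_t * grad       # theta_{t+1} = f_t * theta_t - i_t * grad_t
theta_next.sum().backward()                 # gradients reach the meta-learner's parameters
print(gate_net[0].weight.grad)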

renesax14 commented 4 years ago

That line of code is important, as we want to safely branch off the state of the optimizer as used in the outer loop and return to it (or not touch it in the first place, which is what we do with the copy here) at the end of the unrolled inner loop.

Can you explain what that means?