Closed renesax14 closed 4 years ago
@denisyarats wrote examples of optimizing the learning rates for the GIMLi paper. Pinging him here so that maybe we can integrate some of that code in the examples folder (if he has time).
If you mean learning the entire optimizer as a parametric function, as in e.g. "Learning to learn by gradient descent by gradient descent", then that would make an excellent example to add to the examples folder. We would welcome a pull request doing this, but don't have the cycles to do it ourselves at the moment.
Closing this issue for now, but we always welcome pull requests providing new examples!
@denisyarats do you have an example of learning the learning rate that we could add to the examples in this library?
yes, I do have this example, you can find it here: https://github.com/denisyarats/densenet_cifar10
Feel free to integrate it into the examples folder of higher.
will check it out thanks!
@egrefen sorry for bothering you again. I thought I was close, but I still get an error. It behaves as if my step-size NN is not in the graph, although it should be because of these lines of code:
p_new = p + lr*g
group['params'][p_idx] = p_new
but somehow that is not enough to have gradients...
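To make explicit why gradients are expected here, the two lines above can be written as a recurrence. This is only a sketch under assumptions: the symbols $\theta_t$ (child parameters), $\alpha_t$ (the learned step size, `lr` in the code), $w$ and $b$ (weight and bias of the linear layer inside `eta`), $h$ (the initial `hidden` value) and $\sigma$ (the sigmoid) are introduced for illustration, and the plus sign mirrors the posted `p + lr*g`.

```latex
\begin{aligned}
\alpha_t &= 0.1\,\sigma(z_t), \qquad z_t = w\,\alpha_{t-1} + b, \qquad \alpha_{-1} = h,\\
\theta_{t+1} &= \theta_t + \alpha_t\, g_t, \qquad g_t = \nabla_{\theta} L^{\mathrm{train}}(\theta_t),\\
\frac{\partial \alpha_t}{\partial w} &= 0.1\,\sigma'(z_t)\left(\alpha_{t-1} + w\,\frac{\partial \alpha_{t-1}}{\partial w}\right), \qquad \frac{\partial \alpha_{-1}}{\partial w} = 0,\\
\frac{\partial \theta_{t+1}}{\partial w} &= \frac{\partial \theta_t}{\partial w} + \frac{\partial \alpha_t}{\partial w}\, g_t + \alpha_t\, \frac{\partial g_t}{\partial w}.
\end{aligned}
```

The middle term $\frac{\partial \alpha_t}{\partial w}\, g_t$ is non-zero whenever $g_t$ is, so after a single inner step the updated parameters, and therefore $L^{\mathrm{val}}(\theta^{T})$, should already depend on $w$ and $h$.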
Full self-contained script:
import torch
import torch.nn as nn
from torch.optim.optimizer import Optimizer
import higher
from higher.optim import DifferentiableOptimizer
from higher.optim import DifferentiableSGD
import torchvision
import torchvision.transforms as transforms
from torchviz import make_dot
import copy
import itertools
from collections import OrderedDict

#mini class to add a flatten layer to the ordered dictionary
class Flatten(nn.Module):
    def forward(self, input):
        '''
        Note that input.size(0) is usually the batch size.
        So what it does is that given any input with input.size(0) # of batches,
        will flatten to be 1 * nb_elements.
        '''
        batch_size = input.size(0)
        out = input.view(batch_size,-1)
        return out # (batch_size, *size)

def get_cifar10():
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                              shuffle=True, num_workers=2)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                             shuffle=False, num_workers=2)
    return trainloader, testloader

class MySGD(Optimizer):
    def __init__(self, params, eta, prev_lr):
        defaults = {'eta':eta, 'prev_lr':prev_lr}
        super().__init__(params, defaults)

class TrainableSGD(DifferentiableOptimizer):
    def _update(self, grouped_grads, **kwargs):
        prev_lr = self.param_groups[0]['prev_lr']
        eta = self.param_groups[0]['eta']
        # start differentiable & trainable update
        zipped = zip(self.param_groups, grouped_grads)
        lr = 0.1*eta(prev_lr).view(1)
        for group_idx, (group, grads) in enumerate(zipped):
            for p_idx, (p, g) in enumerate(zip(group['params'], grads)):
                if g is None:
                    continue
                #group['params'][p_idx] = _add(p, -group['lr'], g)
                p_new = p + lr*g
                group['params'][p_idx] = p_new
        # fake returns
        self.param_groups[0]['prev_lr'] = lr

higher.register_optim(MySGD, TrainableSGD)

def main():
    # get dataloaders
    trainloader, testloader = get_cifar10()
    criterion = nn.CrossEntropyLoss()
    child_model = nn.Sequential(OrderedDict([
        ('conv1', nn.Conv2d(in_channels=3,out_channels=2,kernel_size=5)),
        ('relu1', nn.ReLU()),
        ('Flatten', Flatten()),
        ('fc', nn.Linear(in_features=28*28*2,out_features=10) )
    ]))
    hidden = torch.randn(size=(1,1),requires_grad=True)
    print(f'-> hidden = {hidden}')
    eta = nn.Sequential(OrderedDict([
        ('fc', nn.Linear(1,1)),
        ('sigmoid', nn.Sigmoid())
    ]))
    inner_opt = MySGD(child_model.parameters(), eta=eta, prev_lr=hidden)
    meta_params = itertools.chain(child_model.parameters(),eta.parameters())
    #meta_params = itertools.chain(eta.parameters(),[hidden])
    meta_opt = torch.optim.Adam(meta_params, lr=1e-3)
    # do meta-training/outer training minimize outerloop: min_{theta} sum_t L^val( theta^{T} - eta* Grad L^train(theta^{T}) )
    print()
    nb_outer_steps = 1 # note, in this case it's the same as number of meta-train steps (but it's could not be the same depending how you loop through the val set)
    for outer_i, (outer_inputs, outer_targets) in enumerate(testloader, 0):
        meta_opt.zero_grad()
        if outer_i >= nb_outer_steps:
            break
        # do inner-training/MAML; minimize innerloop: theta^{T} - eta * Grad L^train(theta^{T}) ~ argmin L^train(theta)
        nb_inner_steps = 3
        #with higher.innerloop_ctx(child_model, inner_opt, copy_initial_weights=False) as (fmodel, diffopt):
        with higher.innerloop_ctx(child_model, inner_opt) as (fmodel, diffopt):
            for inner_i, (inner_inputs, inner_targets) in enumerate(trainloader, 0):
                if inner_i >= nb_inner_steps:
                    break
                logits = fmodel(inner_inputs)
                inner_loss = criterion(logits, inner_targets)
                print(f'--> inner_i = {inner_i}')
                print(f'inner_loss^<{inner_i}>: {inner_loss}')
                print(f'lr^<{inner_i-1}> = {diffopt.param_groups[0]["prev_lr"]}')
                diffopt.step(inner_loss) # changes params P[t+1] using P[t] and loss[t] in a differentiable manner
                print(f'lr^<{inner_i}> = {diffopt.param_groups[0]["prev_lr"]}')
                print()
            # compute the meta-loss L^val( theta^{T} - eta* Grad L^train(theta^{T}) )
            outer_outputs = fmodel(outer_inputs)
            meta_loss = criterion(outer_outputs, outer_targets) # L^val
            make_dot(meta_loss).render('meta_loss',format='png')
            meta_loss.backward()
            #grad_of_grads = torch.autograd.grad(outputs=meta_loss, inputs=eta.parameters()) # dmeta_loss/dw0
        print(f'----> outer_i = {outer_i}')
        print(f'-> outer_loss/meta_loss^<{outer_i}>: {meta_loss}')
        print(f'child_model.fc.weight.grad = {child_model.fc.weight.grad}')
        print(f'hidden.grad = {hidden.grad}')
        print(f'eta.fc.weight = {eta.fc.weight.grad}')
        meta_opt.step() # meta-optimizer step: more or less theta^<t> := theta^<t> - meta_eta * Grad L^val( theta^{T} - eta* Grad L^train(theta^{T}) )

if __name__ == "__main__":
    main()
    print('---> Done\a')
notice the None's:
Files already downloaded and verified
Files already downloaded and verified
-> hidden = tensor([[0.8459]], requires_grad=True)
--> inner_i = 0
inner_loss^<0>: 2.2696359157562256
lr^<-1> = tensor([[0.8459]], requires_grad=True)
lr^<0> = tensor([0.0567], grad_fn=<MulBackward0>)
--> inner_i = 1
inner_loss^<1>: 2.0114920139312744
lr^<0> = tensor([0.0567], grad_fn=<MulBackward0>)
lr^<1> = tensor([0.0720], grad_fn=<MulBackward0>)
--> inner_i = 2
inner_loss^<2>: 2.3866422176361084
lr^<1> = tensor([0.0720], grad_fn=<MulBackward0>)
lr^<2> = tensor([0.0717], grad_fn=<MulBackward0>)
----> outer_i = 0
-> outer_loss/meta_loss^<0>: 4.021303176879883
child_model.fc.weight.grad = None
hidden.grad = None
eta.fc.weight = None
---> Done
I just saw denis responded so I will check his code too...
@egrefen Sorry for the spam, but this has to be some sort of bug. When I sum all the parameters inside _update, call backward on the sum, and then print the gradients, the ones I expect to be non-zero are indeed non-zero:
==> hidden.grad = tensor([[0.0373]])
==> eta.fc.weight.grad = tensor([[-0.0882]])
but when I do the same outside of _update (in the inner loop, after diffopt.step(inner_loss)), they are incorrectly None:
===> hidden.grad = None
===> eta.fc.weight.grad = None
This must be some sort of bug somewhere, because I have not done anything with the weights after step, so they should be the same as they were inside the _update function.
For reference I will paste the new code with the new print statements:
'''
Single task MAML:
MAML: min_{theta} sum_t L^val( theta - eta* Grad L^train(theta) )
T-step MAML: min_{theta} sum_t L^val( theta^{T} - eta* Grad L^train(theta^{T}) )
outerloop: min_{theta} sum_t L^val( theta^{T} - eta* Grad L^train(theta^{T}) ) ~ min_{theta} sum_t L^val( argmin L^train(theta) )
Innerloop: theta^{T} - eta* Grad L^train(theta^{T}) ~ argmin L^train(theta)
single task MAML: min_{theta} L^val( theta - eta* Grad L^train(theta) )
based on MAML example: https://github.com/facebookresearch/higher/blob/master/examples/maml-omniglot.py
'''
import torch
import torch.nn as nn
from torch.optim.optimizer import Optimizer
import higher
from higher.optim import DifferentiableOptimizer
from higher.optim import DifferentiableSGD
import torchvision
import torchvision.transforms as transforms
from torchviz import make_dot
import copy
import itertools
from collections import OrderedDict
from pdb import set_trace as st
#mini class to add a flatten layer to the ordered dictionary
class Flatten(nn.Module):
def forward(self, input):
'''
Note that input.size(0) is usually the batch size.
So what it does is that given any input with input.size(0) # of batches,
will flatten to be 1 * nb_elements.
'''
batch_size = input.size(0)
out = input.view(batch_size,-1)
return out # (batch_size, *size)
def get_cifar10():
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
shuffle=False, num_workers=2)
return trainloader, testloader
class MySGD(Optimizer):
def __init__(self, params, eta, prev_lr, hidden, meta_opt):
defaults = {'eta':eta, 'prev_lr':prev_lr, 'hidden':hidden, 'meta_opt':meta_opt}
super().__init__(params, defaults)
class TrainableSGD(DifferentiableOptimizer):
def _update(self, grouped_grads, **kwargs):
meta_opt = self.param_groups[0]['meta_opt']
hidden = self.param_groups[0]['hidden']
prev_lr = self.param_groups[0]['prev_lr']
eta = self.param_groups[0]['eta']
# start differentiable & trainable update
zipped = zip(self.param_groups, grouped_grads)
lr = 0.1*eta(prev_lr).view(1)
p_tot = 0
for group_idx, (group, grads) in enumerate(zipped):
for p_idx, (p, g) in enumerate(zip(group['params'], grads)):
if g is None:
continue
#group['params'][p_idx] = _add(p, -group['lr'], g)
p_new = p + lr*g
group['params'][p_idx] = p_new
p_tot += p_new.sum()
#make_dot(p_new.sum()).render('p_new',format='png')
#print()
# fake returns
self.param_groups[0]['prev_lr'] = lr
#p_tot.backward()
print(f'==> hidden.grad = {hidden.grad}')
print(f'==> eta.fc.weight.grad = {eta.fc.weight.grad}')
#meta_opt.zero_grad()
# print(f'==> hidden.grad = {hidden.grad}')
# print(f'==> eta.fc.weight.grad = {eta.fc.weight.grad}')
print()
higher.register_optim(MySGD, TrainableSGD)
def main():
# get dataloaders
trainloader, testloader = get_cifar10()
criterion = nn.CrossEntropyLoss()
child_model = nn.Sequential(OrderedDict([
('conv1', nn.Conv2d(in_channels=3,out_channels=2,kernel_size=5,bias=False)),
('relu1', nn.ReLU()),
('Flatten', Flatten()),
('fc', nn.Linear(in_features=28*28*2,out_features=10,bias=False) )
]))
hidden = torch.randn(size=(1,1),requires_grad=True)
print(f'-> hidden = {hidden}')
eta = nn.Sequential(OrderedDict([
('fc', nn.Linear(1,1,bias=False)),
('sigmoid', nn.Sigmoid())
]))
#meta_params = itertools.chain(child_model.parameters(),eta.parameters(),[hidden])
meta_params = itertools.chain(eta.parameters(),[hidden])
meta_opt = torch.optim.Adam(meta_params, lr=1e-3)
inner_opt = MySGD(child_model.parameters(), eta=eta, prev_lr=hidden, hidden=hidden, meta_opt=meta_opt)
# do meta-training/outer training minimize outerloop: min_{theta} sum_t L^val( theta^{T} - eta* Grad L^train(theta^{T}) )
print()
nb_outer_steps = 1 # note, in this case it's the same as number of meta-train steps (but it's could not be the same depending how you loop through the val set)
for outer_i, (outer_inputs, outer_targets) in enumerate(testloader, 0):
meta_opt.zero_grad()
if outer_i >= nb_outer_steps:
break
# do inner-training/MAML; minimize innerloop: theta^{T} - eta * Grad L^train(theta^{T}) ~ argmin L^train(theta)
nb_inner_steps = 3
with higher.innerloop_ctx(child_model, inner_opt, copy_initial_weights=False) as (fmodel, diffopt):
#with higher.innerloop_ctx(child_model, inner_opt) as (fmodel, diffopt):
for inner_i, (inner_inputs, inner_targets) in enumerate(trainloader, 0):
meta_opt.zero_grad()
if inner_i >= nb_inner_steps:
break
logits = fmodel(inner_inputs)
inner_loss = criterion(logits, inner_targets)
print(f'--> inner_i = {inner_i}')
print(f'inner_loss^<{inner_i}>: {inner_loss}')
print(f'lr^<{inner_i-1}> = {diffopt.param_groups[0]["prev_lr"]}')
diffopt.step(inner_loss) # changes params P[t+1] using P[t] and loss[t] in a differentiable manner
print(f'lr^<{inner_i}> = {diffopt.param_groups[0]["prev_lr"]}')
##
p_tot = sum([ p.sum() for p in fmodel.parameters() ])
p_tot.backward()
print(f'===> hidden.grad = {hidden.grad}')
print(f'===> eta.fc.weight.grad = {eta.fc.weight.grad}')
print()
# compute the meta-loss L^val( theta^{T} - eta* Grad L^train(theta^{T}) )
print()
p_tot = sum([ p.sum() for p in fmodel.parameters() ])
p_tot.backward()
print(f'eta.fc.weight.grad = {eta.fc.weight.grad}')
outer_outputs = fmodel(outer_inputs)
meta_loss = criterion(outer_outputs, outer_targets) # L^val
#meta_loss = meta_loss + inner_loss
#make_dot(meta_loss).render('meta_loss',format='png')
#meta_loss.backward()
#grad_of_grads = torch.autograd.grad(outputs=meta_loss, inputs=eta.parameters()) # dmeta_loss/dw0
print(f'----> outer_i = {outer_i}')
print(f'-> outer_loss/meta_loss^<{outer_i}>: {meta_loss}')
print(f'child_model.fc.weight.grad = {child_model.fc.weight.grad}')
print(f'hidden.grad = {hidden.grad}')
print(f'eta.fc.weight = {eta.fc.weight.grad}')
meta_opt.step() # meta-optimizer step: more or less theta^<t> := theta^<t> - meta_eta * Grad L^val( theta^{T} - eta* Grad L^train(theta^{T}) )
if __name__ == "__main__":
main()
print('---> Done\a')
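One thing that may be worth ruling out before treating the mismatch described above as a library bug is whether the eta read back from self.param_groups[0]['eta'] inside _update is the same Python object as the eta created in main(). The sketch below does not use higher at all. It only shows, under the assumption that an optimizer wrapper deep-copies its parameter groups (not verified here against higher's source), that gradients then land on the copy while the original module's .grad stays None. The names group and copied_group are hypothetical.

```python
import copy

import torch
import torch.nn as nn

# Stand-in for the learned step-size network stored in a param group.
eta = nn.Linear(1, 1, bias=False)
group = {'eta': eta}

# Assumption for illustration: the wrapper keeps a deep copy of the group.
copied_group = copy.deepcopy(group)
eta_copy = copied_group['eta']

# Run a forward/backward pass through the copy only.
x = torch.ones(1, 1)
eta_copy(x).sum().backward()

same_object = eta_copy is eta
copy_has_grad = eta_copy.weight.grad is not None
original_has_grad = eta.weight.grad is not None
print(same_object, copy_has_grad, original_has_grad)
```

In the script above, printing `diffopt.param_groups[0]['eta'] is eta` inside the inner loop would show whether this situation applies.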
perhaps it's this code right after update:
self._update(grouped_grads)
new_params = params[:]
for group, mapping in zip(self.param_groups, self._group_to_param_list):
    for p, index in zip(group['params'], mapping):
        if self._track_higher_grads:
            new_params[index] = p
        else:
            new_params[index] = p.detach().requires_grad_()
if self._fmodel is not None:
    self._fmodel.update_params(new_params)
probably this function, since it's the only part that re-assigns the attributes:
Re-opening this issue so I remember to check it this week (if I can find the time).
I feel so close... but something about the way fmodules (and fmodel) are being updated is breaking my computation graph...
I will be making comments to record/track my progress in the debugging as I learn stuff about the bug.
I tried printing the parameters in the update function _update_patched_params and displaying the computation graph to see whether higher is breaking it. That function does not seem to break it, as the two pictures show (note that I've removed the biases to make the graphs simpler):
1) [image: computation graph]
2) [image: computation graph]
Code:
def _update_patched_params(
    fmodule: _MonkeyPatchBase,
    params_box: _typing.Sequence[_typing.List[_torch.Tensor]],
    params_offset: int
) -> int:
    num_params = len([1 for p in fmodule._parameters.values() if p is not None])
    child_params_offset = params_offset + num_params
    for name, child in fmodule._modules.items():
        child_params_offset = _update_patched_params(
            child, params_box, child_params_offset
        )
    #p_tot = 0
    for name, param in zip(fmodule._param_names,params_box[0][params_offset:params_offset + num_params]):
        #delattr(fmodule, name)
        setattr(fmodule, name, param)
        make_dot(param.sum()).render(
            filename='param_sum1',
            format='png'
        )
        #st()
        #print(name)
        #p_tot += param.sum()
    #st()
    #print(f'> p_tot = {p_tot}')
    #p_tot.backward()
    #print(f'===> hidden.grad = {hidden.grad}')
    #print(f'===> eta.fc.weight.grad = {eta.fc.weight.grad}')
    return child_params_offset
My code should work even if the original model or fmodel is not trainable:
child_model = nn.Sequential(OrderedDict([
    ('conv1', nn.Conv2d(in_channels=3,out_channels=2,kernel_size=5,bias=False)),
    ('relu1', nn.ReLU()),
    ('Flatten', Flatten()),
    ('fc', nn.Linear(in_features=28*28*2,out_features=10,bias=False) )
]))
child_model.conv1.weight.requires_grad = False
child_model.fc.weight.requires_grad = False
although this is not super important...
It seems that it does not backpropagate to my neural net step size (named eta) if I call backward inside step (but it does inside my implementation of _update).
Code and output:
def step(
self,
loss: _torch.Tensor,
params: _typing.Iterable[_torch.Tensor] = None,
override: _typing.Optional[_OverrideType] = None,
grad_callback: _typing.Optional[_GradCallbackType] = None,
eta=None,
**kwargs
) -> _typing.Iterable[_torch.Tensor]:
r"""Perform a model update.
This would be used by replacing the normal sequence::
opt.zero_grad()
loss.backward()
opt.step()
with::
diffopt.step(loss)
Args:
loss: the loss tensor.
params (optional): the parameters with regard to which we measure
the loss. These must be provided if the differentiable optimizer
did not receive a patched model with a view over its own fast
weights at initialisation. If there is such a model, and params
are provided, they will overwrite the params of the encapsulated
model.
override (optional): a dictionary mapping optimizer settings (i.e.
those which would be passed to the optimizer constructor or
provided within parameter groups) to either singleton lists of
override values, or to a list of override values of length equal
to the number of parameter groups. If a single override is
provided for a keyword, it is used for all parameter groups. If
a list is provided, the ``i``\ th element of the list overrides
the corresponding setting in the ``i``\ th parameter group. This
permits the passing of tensors requiring gradient to
differentiable optimizers for use as optimizer settings. Setting
override here has highest precedence, i.e. it will override any
tensors provided as override during the creation of the
differentiable optimizer, where there is name clash.
grad_callback: (optional) a single argument function which will be
applied to a list of gradients of parameters, which respects the
order specified by ``reference_params``. This can be used to
apply a function, such as gradient clipping, to all (or a
subset) of these gradients every time the step function is
called. This callback overrides the default provided when
constructing the differentiable optimizer.
Returns:
The updated parameters, which will individually have ``grad_fn``\ s
of their own. If the optimizer has an encapsulated patched model,
its view over its own fast weights will be updated with these
params.
"""
print('---------> IN .step(loss)')
#st()
#eta = eta[0]
# Deal with override
if override is not None:
self._apply_override(override)
if self._fmodel is None or self._fmodel.fast_params is None:
if params is None:
raise ValueError(
"params kwarg must be passed to step if the differentiable "
"optimizer doesn't have a view on a patched model with "
"params."
)
else:
params = self._fmodel.fast_params if params is None else params
params = list(params)
# This allows us to gracefully deal with cases where params are frozen.
grad_targets = [
p if p.requires_grad else _torch.tensor([], requires_grad=True)
for p in params
]
all_grads = _torch.autograd.grad(
loss,
grad_targets,
create_graph=self._track_higher_grads,
allow_unused=True # boo
)
if grad_callback is not None:
all_grads = grad_callback(all_grads)
elif self._grad_callback is not None:
all_grads = self._grad_callback(all_grads)
grouped_grads = []
for group, mapping in zip(self.param_groups, self._group_to_param_list):
grads = []
for i, index in enumerate(mapping):
group['params'][i] = params[index]
grads.append(all_grads[index])
grouped_grads.append(grads)
self._update(grouped_grads)
new_params = params[:]
print(f'self._track_higher_grads = {self._track_higher_grads}')
for group, mapping in zip(self.param_groups, self._group_to_param_list):
for p, index in zip(group['params'], mapping):
if self._track_higher_grads:
new_params[index] = p
else:
new_params[index] = p.detach().requires_grad_()
p_tot = 0
for p in new_params:
p_tot += p.sum()
p_tot.backward()
# if self._fmodel is not None:
# self._fmodel.update_params(new_params)
# print()
# st()
# set_attr(self._fmodel, names, val)
# del_attr(self._fmodel, names)
print(f'eta.fc.grad = {eta.fc.weight.grad}')
return new_params
Output:
---------> IN .step(loss)
self._track_higher_grads = True
eta.fc.grad = None
It doesn't seem to work right outside of _update(grouped_grads) either.
Code and output:
def step(
self,
loss: _torch.Tensor,
params: _typing.Iterable[_torch.Tensor] = None,
override: _typing.Optional[_OverrideType] = None,
grad_callback: _typing.Optional[_GradCallbackType] = None,
eta=None,
**kwargs
) -> _typing.Iterable[_torch.Tensor]:
r"""Perform a model update.
This would be used by replacing the normal sequence::
opt.zero_grad()
loss.backward()
opt.step()
with::
diffopt.step(loss)
Args:
loss: the loss tensor.
params (optional): the parameters with regard to which we measure
the loss. These must be provided if the differentiable optimizer
did not receive a patched model with a view over its own fast
weights at initialisation. If there is such a model, and params
are provided, they will overwrite the params of the encapsulated
model.
override (optional): a dictionary mapping optimizer settings (i.e.
those which would be passed to the optimizer constructor or
provided within parameter groups) to either singleton lists of
override values, or to a list of override values of length equal
to the number of parameter groups. If a single override is
provided for a keyword, it is used for all parameter groups. If
a list is provided, the ``i``\ th element of the list overrides
the corresponding setting in the ``i``\ th parameter group. This
permits the passing of tensors requiring gradient to
differentiable optimizers for use as optimizer settings. Setting
override here has highest precedence, i.e. it will override any
tensors provided as override during the creation of the
differentiable optimizer, where there is name clash.
grad_callback: (optional) a single argument function which will be
applied to a list of gradients of parameters, which respects the
order specified by ``reference_params``. This can be used to
apply a function, such as gradient clipping, to all (or a
subset) of these gradients every time the step function is
called. This callback overrides the default provided when
constructing the differentiable optimizer.
Returns:
The updated parameters, which will individually have ``grad_fn``\ s
of their own. If the optimizer has an encapsulated patched model,
its view over its own fast weights will be updated with these
params.
"""
print('---------> IN .step(loss)')
#st()
#eta = eta[0]
# Deal with override
if override is not None:
self._apply_override(override)
if self._fmodel is None or self._fmodel.fast_params is None:
if params is None:
raise ValueError(
"params kwarg must be passed to step if the differentiable "
"optimizer doesn't have a view on a patched model with "
"params."
)
else:
params = self._fmodel.fast_params if params is None else params
params = list(params)
# This allows us to gracefully deal with cases where params are frozen.
grad_targets = [
p if p.requires_grad else _torch.tensor([], requires_grad=True)
for p in params
]
all_grads = _torch.autograd.grad(
loss,
grad_targets,
create_graph=self._track_higher_grads,
allow_unused=True # boo
)
if grad_callback is not None:
all_grads = grad_callback(all_grads)
elif self._grad_callback is not None:
all_grads = self._grad_callback(all_grads)
grouped_grads = []
for group, mapping in zip(self.param_groups, self._group_to_param_list):
grads = []
for i, index in enumerate(mapping):
group['params'][i] = params[index]
grads.append(all_grads[index])
grouped_grads.append(grads)
self._update(grouped_grads)
p_tot = 0
for p in params[:]:
p_tot += p.sum()
p_tot.backward()
# new_params = params[:]
# print(f'self._track_higher_grads = {self._track_higher_grads}')
# for group, mapping in zip(self.param_groups, self._group_to_param_list):
# for p, index in zip(group['params'], mapping):
# if self._track_higher_grads:
# new_params[index] = p
# else:
# new_params[index] = p.detach().requires_grad_()
# p_tot = 0
# for p in new_params:
# p_tot += p.sum()
# p_tot.backward()
# if self._fmodel is not None:
# self._fmodel.update_params(new_params)
# print()
# st()
# set_attr(self._fmodel, names, val)
# del_attr(self._fmodel, names)
print(f'eta.fc.grad = {eta.fc.weight.grad}')
#return new_params
output:
> p_tot = 3.033494472503662
---------> IN .step(loss)
eta.fc.grad = None
I checked whether the gradients of the learnable NN step size eta are populated inside my custom _update, and they are:
output:
---------> IN .step(loss)
==> eta.fc.weight.grad = tensor([[-0.0260]])
code:
class TrainableSGD(DifferentiableOptimizer):
    def _update(self, grouped_grads, **kwargs):
        meta_opt = self.param_groups[0]['meta_opt']
        hidden = self.param_groups[0]['hidden']
        prev_lr = self.param_groups[0]['prev_lr']
        eta = self.param_groups[0]['eta']
        # start differentiable & trainable update
        zipped = zip(self.param_groups, grouped_grads)
        lr = 0.1*eta(prev_lr).view(1)
        p_tot = 0
        for group_idx, (group, grads) in enumerate(zipped):
            for p_idx, (p, g) in enumerate(zip(group['params'], grads)):
                if g is None:
                    continue
                #group['params'][p_idx] = _add(p, -group['lr'], g)
                p_new = p + lr*g
                group['params'][p_idx] = p_new
                p_tot += p_new.sum()
                #make_dot(p_new.sum()).render('p_new',format='png')
                #print()
        # fake returns
        self.param_groups[0]['prev_lr'] = lr
        p_tot.backward()
        #print(f'==> hidden.grad = {hidden.grad}')
        print(f'==> eta.fc.weight.grad = {eta.fc.weight.grad}')
        #meta_opt.zero_grad()
        # print(f'==> hidden.grad = {hidden.grad}')
        # print(f'==> eta.fc.weight.grad = {eta.fc.weight.grad}')
        #print()
I tried returning the new parameters from the _update function to step and then computing the backward pass there, but it did not work. This is the code that failed:
new_params = self._update(grouped_grads)
# p_tot = 0
# for p in params[:]:
# p_tot += p.sum()
# p_tot.backward()
#new_params = params[:]
#st()
# print(f'self._track_higher_grads = {self._track_higher_grads}')
# for group, mapping in zip(self.param_groups, self._group_to_param_list):
# for p, index in zip(group['params'], mapping):
# if self._track_higher_grads:
# new_params[index] = p
# else:
# new_params[index] = p.detach().requires_grad_()
p_tot = 0
for p in new_params:
p_tot += p.sum()
p_tot.backward()
my _update:
class TrainableSGD(DifferentiableOptimizer):
def _update(self, grouped_grads, **kwargs):
meta_opt = self.param_groups[0]['meta_opt']
hidden = self.param_groups[0]['hidden']
prev_lr = self.param_groups[0]['prev_lr']
eta = self.param_groups[0]['eta']
# start differentiable & trainable update
zipped = zip(self.param_groups, grouped_grads)
lr = 0.1*eta(prev_lr).view(1)
p_tot = 0
for group_idx, (group, grads) in enumerate(zipped):
for p_idx, (p, g) in enumerate(zip(group['params'], grads)):
if g is None:
continue
#group['params'][p_idx] = _add(p, -group['lr'], g)
p_new = p + lr*g
group['params'][p_idx] = p_new
p_tot += p_new.sum()
#make_dot(p_new.sum()).render('p_new',format='png')
#print()
# fake returns
self.param_groups[0]['prev_lr'] = lr
#p_tot.backward()
#print(f'==> hidden.grad = {hidden.grad}')
#print(f'==> eta.fc.weight.grad = {eta.fc.weight.grad}')
#meta_opt.zero_grad()
# print(f'==> hidden.grad = {hidden.grad}')
# print(f'==> eta.fc.weight.grad = {eta.fc.weight.grad}')
#print()
new_params = group['params']
return new_params
My next attempt is to append the params to a fresh list and bypass the PyTorch param groups entirely, because my hypothesis is that PyTorch might be doing something under the hood.
Appending the new parameters myself and bypassing the groups list did not work:
class TrainableSGD(DifferentiableOptimizer):
def _update(self, grouped_grads, **kwargs):
meta_opt = self.param_groups[0]['meta_opt']
hidden = self.param_groups[0]['hidden']
prev_lr = self.param_groups[0]['prev_lr']
eta = self.param_groups[0]['eta']
# start differentiable & trainable update
zipped = zip(self.param_groups, grouped_grads)
lr = 0.1*eta(prev_lr).view(1)
p_tot = 0
new_params = []
for group_idx, (group, grads) in enumerate(zipped):
for p_idx, (p, g) in enumerate(zip(group['params'], grads)):
if g is None:
continue
#group['params'][p_idx] = _add(p, -group['lr'], g)
p_new = p + lr*g
group['params'][p_idx] = p_new
p_tot += p_new.sum()
#make_dot(p_new.sum()).render('p_new',format='png')
#print()
new_params.append( p_new )
# fake returns
self.param_groups[0]['prev_lr'] = lr
#p_tot.backward()
#print(f'==> hidden.grad = {hidden.grad}')
#print(f'==> eta.fc.weight.grad = {eta.fc.weight.grad}')
#meta_opt.zero_grad()
# print(f'==> hidden.grad = {hidden.grad}')
# print(f'==> eta.fc.weight.grad = {eta.fc.weight.grad}')
#print()
#new_params = group['params']
return new_params
step is:
new_params = self._update(grouped_grads)
# p_tot = 0
# for p in params[:]:
# p_tot += p.sum()
# p_tot.backward()
#new_params = params[:]
#st()
# print(f'self._track_higher_grads = {self._track_higher_grads}')
# for group, mapping in zip(self.param_groups, self._group_to_param_list):
# for p, index in zip(group['params'], mapping):
# if self._track_higher_grads:
# new_params[index] = p
# else:
# new_params[index] = p.detach().requires_grad_()
p_tot = 0
for p in new_params:
p_tot += p.sum()
p_tot.backward()
# if self._fmodel is not None:
# self._fmodel.update_params(new_params)
# print()
# st()
# set_attr(self._fmodel, names, val)
# del_attr(self._fmodel, names)
print(f'eta.fc.grad = {eta.fc.weight.grad}')
#return new_params
output:
---------> IN .step(loss)
eta.fc.grad = None
I then thought that returning the value p_tot directly and calling backward outside of step would populate the gradients of eta, because doing that inside of _update works. I tried it, and it still didn't populate the gradients, despite populating them inside _update.
Code:
class TrainableSGD(DifferentiableOptimizer):
def _update(self, grouped_grads, **kwargs):
meta_opt = self.param_groups[0]['meta_opt']
hidden = self.param_groups[0]['hidden']
prev_lr = self.param_groups[0]['prev_lr']
eta = self.param_groups[0]['eta']
# start differentiable & trainable update
zipped = zip(self.param_groups, grouped_grads)
lr = 0.1*eta(prev_lr).view(1)
p_tot = 0
new_params = []
for group_idx, (group, grads) in enumerate(zipped):
for p_idx, (p, g) in enumerate(zip(group['params'], grads)):
if g is None:
continue
#group['params'][p_idx] = _add(p, -group['lr'], g)
p_new = p + lr*g
group['params'][p_idx] = p_new
p_tot += p_new.sum()
#make_dot(p_new.sum()).render('p_new',format='png')
#print()
new_params.append( p_new )
# fake returns
self.param_groups[0]['prev_lr'] = lr
#p_tot.backward()
#print(f'==> hidden.grad = {hidden.grad}')
#print(f'==> eta.fc.weight.grad = {eta.fc.weight.grad}')
#meta_opt.zero_grad()
# print(f'==> hidden.grad = {hidden.grad}')
# print(f'==> eta.fc.weight.grad = {eta.fc.weight.grad}')
#print()
#new_params = group['params']
#return new_params
return p_tot
code of step:
#new_params = self._update(grouped_grads)
p_tot = self._update(grouped_grads)
# p_tot = 0
# for p in params[:]:
# p_tot += p.sum()
# p_tot.backward()
#new_params = params[:]
#st()
# print(f'self._track_higher_grads = {self._track_higher_grads}')
# for group, mapping in zip(self.param_groups, self._group_to_param_list):
# for p, index in zip(group['params'], mapping):
# if self._track_higher_grads:
# new_params[index] = p
# else:
# new_params[index] = p.detach().requires_grad_()
# p_tot = 0
# for p in new_params:
# p_tot += p.sum()
p_tot.backward()
# if self._fmodel is not None:
# self._fmodel.update_params(new_params)
# print()
# st()
# set_attr(self._fmodel, names, val)
# del_attr(self._fmodel, names)
print(f'p_tot = {p_tot}')
print(p_tot)
print(f'eta.fc.grad = {eta.fc.weight.grad}')
#return new_params
As we can see, p_tot does have a grad_fn, which makes this bug really mysterious to me.
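For reference, here is a rough derivation of what the gradient should look like after a single update. It assumes eta is the one-weight linear layer followed by a sigmoid from the full script, that prev_lr is the hidden tensor, and that the gradients g are constants with respect to the meta-parameters. The symbols w (the weight of eta.fc), h (hidden), sigma and g_i are notation introduced only for this sketch:

```latex
\begin{aligned}
\mathrm{lr} &= 0.1\,\sigma(w h), \\
p_{\mathrm{tot}} &= \sum_i \operatorname{sum}\!\left(p_i + \mathrm{lr}\, g_i\right), \\
\frac{\partial p_{\mathrm{tot}}}{\partial w} &= 0.1\,\sigma'(w h)\, h \sum_i \operatorname{sum}(g_i), \\
\frac{\partial p_{\mathrm{tot}}}{\partial h} &= 0.1\,\sigma'(w h)\, w \sum_i \operatorname{sum}(g_i),
\qquad \sigma'(z) = \sigma(z)\bigl(1 - \sigma(z)\bigr).
\end{aligned}
```

Under these assumptions the result is a number, zero only when h (or w) or the summed gradients are exactly zero. A .grad of None, by contrast, means backward never accumulated anything into that particular tensor object.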
Ok, so the last thing that occurred to me was to print the computation graphs inside _update and right outside it (inside of step). I expected the graphs to be different, but they are exactly the same, which puzzles me even more:
Inside our _update: [image: p_tot_inside_update]
Outside our _update (inside step): [image: p_tot_inside_step]
code for reference:
class TrainableSGD(DifferentiableOptimizer):
def _update(self, grouped_grads, **kwargs):
meta_opt = self.param_groups[0]['meta_opt']
hidden = self.param_groups[0]['hidden']
prev_lr = self.param_groups[0]['prev_lr']
eta = self.param_groups[0]['eta']
# start differentiable & trainable update
zipped = zip(self.param_groups, grouped_grads)
lr = 0.1*eta(prev_lr).view(1)
p_tot = 0
new_params = []
for group_idx, (group, grads) in enumerate(zipped):
for p_idx, (p, g) in enumerate(zip(group['params'], grads)):
if g is None:
continue
#group['params'][p_idx] = _add(p, -group['lr'], g)
p_new = p + lr*g
group['params'][p_idx] = p_new
p_tot += p_new.sum()
#make_dot(p_new.sum()).render('p_new',format='png')
#print()
new_params.append( p_new )
# fake returns
self.param_groups[0]['prev_lr'] = lr
#p_tot.backward()
#print(f'==> hidden.grad = {hidden.grad}')
#print(f'==> eta.fc.weight.grad = {eta.fc.weight.grad}')
#meta_opt.zero_grad()
# print(f'==> hidden.grad = {hidden.grad}')
# print(f'==> eta.fc.weight.grad = {eta.fc.weight.grad}')
#print()
#new_params = group['params']
#return new_params
make_dot(p_tot).render('p_tot_inside_update', format='png')
return p_tot
outside:
#new_params = self._update(grouped_grads)
p_tot = self._update(grouped_grads)
make_dot(p_tot).render('p_tot_inside_step', format='png')
I made my own step function and commented out nearly everything, and the gradients for eta are still not populated.
My suspicion is that the error might be here:
grouped_grads = []
for group, mapping in zip(self.param_groups, self._group_to_param_list):
    grads = []
    for i, index in enumerate(mapping):
        #group['params'][i] = params[index].T
        group['params'][i] = params[index]
        grads.append(all_grads[index])
    grouped_grads.append(grads)
because params seems to contain nn.Parameters, and those have caused me issues in the past.
Code for my my_step and my _update:
class TrainableSGD(DifferentiableOptimizer):
def _update(self, grouped_grads, **kwargs):
meta_opt = self.param_groups[0]['meta_opt']
hidden = self.param_groups[0]['hidden']
prev_lr = self.param_groups[0]['prev_lr']
eta = self.param_groups[0]['eta']
# start differentiable & trainable update
zipped = zip(self.param_groups, grouped_grads)
lr = 0.1*eta(prev_lr).view(1)
p_tot = 0
new_params = []
for group_idx, (group, grads) in enumerate(zipped):
for p_idx, (p, g) in enumerate(zip(group['params'], grads)):
if g is None:
continue
#group['params'][p_idx] = _add(p, -group['lr'], g)
p_new = p + lr*g
group['params'][p_idx] = p_new
p_tot += p_new.sum()
#make_dot(p_new.sum()).render('p_new',format='png')
#print()
new_params.append( p_new )
# fake returns
self.param_groups[0]['prev_lr'] = lr
#p_tot.backward()
#print(f'==> hidden.grad = {hidden.grad}')
#print(f'==> eta.fc.weight.grad = {eta.fc.weight.grad}')
#meta_opt.zero_grad()
# print(f'==> hidden.grad = {hidden.grad}')
# print(f'==> eta.fc.weight.grad = {eta.fc.weight.grad}')
#print()
#new_params = group['params']
#return new_params
#make_dot(p_tot).render('p_tot_inside_update', format='png')
return p_tot, new_params
def my_step(
self,
loss,
params = None,
override = None,
grad_callback = None,
eta=None,
**kwargs
):
print('---------> IN MY .step(loss)')
#st()
#eta = eta[0]
# Deal with override
# if override is not None:
# self._apply_override(override)
# if self._fmodel is None or self._fmodel.fast_params is None:
# if params is None:
# raise ValueError(
# "params kwarg must be passed to step if the differentiable "
# "optimizer doesn't have a view on a patched model with "
# "params."
# )
# else:
# params = self._fmodel.fast_params if params is None else params
#params = self._fmodel.fast_params if params is None else params
params = self._fmodel.fast_params
params = list(params)
# This allows us to gracefully deal with cases where params are frozen.
# grad_targets = [
# p if p.requires_grad else torch.tensor([], requires_grad=True)
# for p in params
# ]
grad_targets = params
# all_grads = torch.autograd.grad(
# loss,
# grad_targets,
# create_graph=self._track_higher_grads,
# allow_unused=True # boo
# )
all_grads = torch.autograd.grad(
loss,
grad_targets
)
# if grad_callback is not None:
# all_grads = grad_callback(all_grads)
# elif self._grad_callback is not None:
# all_grads = self._grad_callback(all_grads)
grouped_grads = []
for group, mapping in zip(self.param_groups, self._group_to_param_list):
grads = []
for i, index in enumerate(mapping):
#group['params'][i] = params[index].T
group['params'][i] = params[index]
grads.append(all_grads[index])
grouped_grads.append(grads)
#new_params = self._update(grouped_grads)
p_tot, new_params = self._update(grouped_grads)
#make_dot(p_tot).render('p_tot_inside_step', format='png')
# p_tot = 0
# for p in params[:]:
# p_tot += p.sum()
#p_tot.backward()
#new_params = params[:]
#st()
# print(f'self._track_higher_grads = {self._track_higher_grads}')
# for group, mapping in zip(self.param_groups, self._group_to_param_list):
# for p, index in zip(group['params'], mapping):
# if self._track_higher_grads:
# new_params[index] = p
# else:
# new_params[index] = p.detach().requires_grad_()
# p_tot = 0
# for p in new_params:
# p_tot += p.sum()
#p_tot.backward()
# if self._fmodel is not None:
# self._fmodel.update_params(new_params)
# print()
# st()
# set_attr(self._fmodel, names, val)
# del_attr(self._fmodel, names)
print(f'p_tot = {p_tot}')
print(p_tot)
print(f'+++>>> eta.fc.grad = {eta.fc.weight.grad}')
#return new_params
output:
---------> IN MY .step(loss)
p_tot = -1.2309236526489258
tensor(-1.2309, grad_fn=<AddBackward0>)
+++>>> eta.fc.grad = None
Ok, some progress: I was able to get the gradients populated (no longer None) inside my custom step function by indexing self.param_groups[0][trainable_opt_param] directly AND updating fmodel inside my _update function. My suspicion is that self.param_groups is being deep-copied somewhere without my permission.
Code:
class TrainableSGD(DifferentiableOptimizer):
def _update(self, grouped_grads, **kwargs):
meta_opt = self.param_groups[0]['meta_opt']
hidden = self.param_groups[0]['hidden']
prev_lr = self.param_groups[0]['prev_lr']
eta = self.param_groups[0]['eta']
# start differentiable & trainable update
zipped = zip(self.param_groups, grouped_grads)
lr = 0.1*eta(prev_lr).view(1)
p_tot = 0
new_params = []
for group_idx, (group, grads) in enumerate(zipped):
for p_idx, (p, g) in enumerate(zip(group['params'], grads)):
if g is None:
continue
#group['params'][p_idx] = _add(p, -group['lr'], g)
p_new = p + lr*g
group['params'][p_idx] = p_new
p_tot += p_new.sum()
#make_dot(p_new.sum()).render('p_new',format='png')
#print()
new_params.append( p_new )
# fake returns
self.param_groups[0]['prev_lr'] = lr
#
self._fmodel.update_params(new_params)
#x = torch.randn([4,3,32,32])
#y = self._fmodel(x)
#y.sum().backward()
#p_tot.backward()
#print(f'==> hidden.grad = {hidden.grad}')
#print(f'==> eta.fc.weight.grad = {eta.fc.weight.grad}')
#meta_opt.zero_grad()
# print(f'==> hidden.grad = {hidden.grad}')
# print(f'==> eta.fc.weight.grad = {eta.fc.weight.grad}')
#print()
#new_params = group['params']
#return new_params
#make_dot(p_tot).render('p_tot_inside_update', format='png')
return p_tot, new_params
def my_step(
self,
loss,
params = None,
override = None,
grad_callback = None,
eta=None,
**kwargs
):
print('---------> IN MY .step(loss)')
eta = self.param_groups[0]['eta']
hidden = self.param_groups[0]['hidden']
#st()
#eta = eta[0]
# Deal with override
# if override is not None:
# self._apply_override(override)
# if self._fmodel is None or self._fmodel.fast_params is None:
# if params is None:
# raise ValueError(
# "params kwarg must be passed to step if the differentiable "
# "optimizer doesn't have a view on a patched model with "
# "params."
# )
# else:
# params = self._fmodel.fast_params if params is None else params
#params = self._fmodel.fast_params if params is None else params
params = self._fmodel.fast_params
params = list(params)
# This allows us to gracefully deal with cases where params are frozen.
# grad_targets = [
# p if p.requires_grad else torch.tensor([], requires_grad=True)
# for p in params
# ]
grad_targets = params
# all_grads = torch.autograd.grad(
# loss,
# grad_targets,
# create_graph=self._track_higher_grads,
# allow_unused=True # boo
# )
all_grads = torch.autograd.grad(
loss,
grad_targets
)
# if grad_callback is not None:
# all_grads = grad_callback(all_grads)
# elif self._grad_callback is not None:
# all_grads = self._grad_callback(all_grads)
grouped_grads = []
for group, mapping in zip(self.param_groups, self._group_to_param_list):
grads = []
for i, index in enumerate(mapping):
#group['params'][i] = params[index].T
group['params'][i] = params[index]
grads.append(all_grads[index])
grouped_grads.append(grads)
#new_params = self._update(grouped_grads)
p_tot, new_params = self._update(grouped_grads)
#make_dot(p_tot).render('p_tot_inside_step', format='png')
# p_tot = 0
# for p in params[:]:
# p_tot += p.sum()
#p_tot.backward()
#new_params = params[:]
#st()
# print(f'self._track_higher_grads = {self._track_higher_grads}')
# for group, mapping in zip(self.param_groups, self._group_to_param_list):
# for p, index in zip(group['params'], mapping):
# if self._track_higher_grads:
# new_params[index] = p
# else:
# new_params[index] = p.detach().requires_grad_()
# p_tot = 0
# for p in new_params:
# p_tot += p.sum()
#p_tot.backward()
# if self._fmodel is not None:
# self._fmodel.update_params(new_params)
# print()
# st()
# set_attr(self._fmodel, names, val)
# del_attr(self._fmodel, names)
#self._fmodel.update_params(new_params)
x = torch.randn([4,3,32,32])
y = self._fmodel(x)
y.sum().backward()
#print(f'p_tot = {p_tot}')
#print(p_tot)
print(f'+++>>> hidden.grad = {hidden.grad}')
print(f'+++>>> eta.fc.grad = {eta.fc.weight.grad}')
#return new_params
st()
return
output:
----- DEBUGGING print statements after this line ----
---------> IN MY .step(loss)
+++>>> hidden.grad = tensor([[0.0102]])
+++>>> eta.fc.grad = tensor([[-0.0225]])
Ok, so it seems that I can see the gradients only if I index my params from .param_groups, but my original models have somehow been detached or deep-copied or something... disassociated from the original definition...
nb_outer_steps = 1 # note, in this case it's the same as the number of meta-train steps (but it might not be, depending on how you loop through the val set)
for outer_i, (outer_inputs, outer_targets) in enumerate(testloader, 0):
meta_opt.zero_grad()
if outer_i >= nb_outer_steps:
break
# do inner-training/MAML; minimize innerloop: theta^{T} - eta * Grad L^train(theta^{T}) ~ argmin L^train(theta)
nb_inner_steps = 5
with higher.innerloop_ctx(child_model, inner_opt, copy_initial_weights=False) as (fmodel, diffopt):
#with higher.innerloop_ctx(child_model, inner_opt) as (fmodel, diffopt):
for inner_i, (inner_inputs, inner_targets) in enumerate(trainloader, 0):
meta_opt.zero_grad()
if inner_i >= nb_inner_steps:
break
logits = fmodel(inner_inputs)
print(type(fmodel))
print(fmodel)
inner_loss = criterion(logits, inner_targets)
print(f'--> inner_i = {inner_i}')
print(f'inner_loss^<{inner_i}>: {inner_loss}')
print(f'lr^<{inner_i-1}> = {diffopt.param_groups[0]["prev_lr"]}')
p_tot = sum([ p.sum() for p in fmodel.parameters() ])
#print(f'> p_tot = {p_tot}')
print('\n----- DEBUGGING print statements after this line ----')
#diffopt.step(inner_loss,eta=eta)
#diffopt.my_step(inner_loss, eta=eta)
fmodel = diffopt.my_step(inner_loss)
#step(diffopt, inner_loss,eta=eta)
x = torch.randn([4,3,32,32])
y = fmodel(x)
#y = diffopt._fmodel(x)
y.sum().backward()
#sys.exit()
#new_params = diffopt.step(inner_loss, eta=eta) # changes params P[t+1] using P[t] and loss[t] in a differentiable manner
print(f'lr^<{inner_i}> = {diffopt.param_groups[0]["prev_lr"]}')
#p = list(fmodel.parameters())[1]
#params = list(child_model.parameters())+list(eta.parameters())+list([hidden])
#c_params = list(child_model.parameters())
#params = {'child_model[0]':c_params[0], 'child_model[1]':c_params[1], 'eta_params':eta.fc.weight, 'hidden':hidden }
# make_dot(p.sum(),params=params).render(
# filename='param_sum_inner_loop',
# format='png'
# )
# make_dot(p.sum()).render(
# filename='param_sum_inner_loop_no_names_5',
# format='png'
# )
##
#p_tot_new = sum([ p.sum() for p in new_params ])
#p_tot = sum([ p.sum() for p in fmodel.parameters() ])
#print(f'> p_tot = {p_tot}')
#print(f'same?: {p_tot_new == p_tot}')
#p_tot.backward()
#p_tot.backward()
#p.sum().backward()
eta = diffopt.param_groups[0]['eta']
hidden = diffopt.param_groups[0]['hidden']
print(f'===> hidden.grad = {hidden.grad}')
print(f'===> eta.fc.weight.grad = {eta.fc.weight.grad}')
#print()
st()
works:
----- DEBUGGING print statements after this line ----
---------> IN MY .step(loss)
lr^<0> = tensor([0.0389], grad_fn=<MulBackward0>)
===> hidden.grad = tensor([[0.0011]])
===> eta.fc.weight.grad = tensor([[-0.0016]])
Ok, so the issue is definitely some sort of deep copy. Code inside my inner loop (inside the context manager):
eta2 = diffopt.param_groups[0]['eta']
hidden2 = diffopt.param_groups[0]['hidden']
h = hidden
print(f'hidden is hidden2 = {hidden is hidden2}')
print(f'hidden is h = {hidden is h}')
output:
----- DEBUGGING print statements after this line ----
---------> IN MY .step(loss)
lr^<0> = tensor([0.0450], grad_fn=<MulBackward0>)
hidden is hidden2 = False
hidden is h = True
===> hidden.grad = None
===> eta.fc.weight.grad = None
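To make the identity check above concrete without needing torch or higher, here is a minimal sketch using plain Python objects. FakeTensor and the contents of param_groups are hypothetical stand-ins invented for this illustration; the only real API used is copy.deepcopy. It shows why a value written to the copy never shows up on the object the outer script still holds:

```python
import copy


class FakeTensor:
    """Plain-Python stand-in for a tensor with a .grad slot (hypothetical, not torch)."""

    def __init__(self, value):
        self.value = value
        self.grad = None


# The object the outer training script keeps a reference to.
hidden = FakeTensor(0.5)

# An optimizer-style param group that stores the very same object.
param_groups = [{"params": [], "hidden": hidden}]
assert param_groups[0]["hidden"] is hidden

# Deep-copying the groups clones every object reachable from them.
copied_groups = copy.deepcopy(param_groups)
hidden2 = copied_groups[0]["hidden"]

# A gradient written to the copy is invisible on the original.
hidden2.grad = 0.0011

print(f"hidden is hidden2 = {hidden is hidden2}")
print(f"hidden.grad = {hidden.grad}")
print(f"hidden2.grad = {hidden2.grad}")
```

If the library really does deep-copy the param groups, as suspected here, then re-reading the objects from diffopt.param_groups is what recovers the copies that actually take part in the update.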
Current fix is to only index the optimizer's parameters from diffopt.param_groups AND update the fmodel inside your own custom _update function. Somehow, outside of that, it no longer points to the original parameters...
Ok, indeed that does fix it. So my current solution is:
1) update the parameters of the trainable optimizer inside your own _update function
2) inside the context manager and inner loop, re-assign your optimizer variables each time so that they don't get lost:
eta = diffopt.param_groups[0]['eta']
hidden = diffopt.param_groups[0]['hidden']
@egrefen when you have time it would be nice if you could take a look at this, because I am afraid there might be some subtle thing I have missed. But at the very least the gradients of my learning rate are now being populated.
Current code that seems to work (unsure if there is some subtle bug I might not know about):
import torch
import torch.nn as nn
from torch.optim.optimizer import Optimizer
import higher
from higher.optim import DifferentiableOptimizer
from higher.optim import DifferentiableSGD
import torchvision
import torchvision.transforms as transforms
from torchviz import make_dot
import copy
import itertools
import sys
from collections import OrderedDict
from pdb import set_trace as st
#mini class to add a flatten layer to the ordered dictionary
class Flatten(nn.Module):
def forward(self, input):
'''
Note that input.size(0) is usually the batch size.
So what it does is that given any input with input.size(0) # of batches,
will flatten to be 1 * nb_elements.
'''
batch_size = input.size(0)
out = input.view(batch_size,-1)
return out # (batch_size, *size)
def get_cifar10():
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
shuffle=False, num_workers=2)
return trainloader, testloader
class MySGD(Optimizer):
def __init__(self, params, eta, prev_lr, hidden, meta_opt):
defaults = {'eta':eta, 'prev_lr':prev_lr, 'hidden':hidden, 'meta_opt':meta_opt}
super().__init__(params, defaults)
class TrainableSGD(DifferentiableOptimizer):
def _update(self, grouped_grads, **kwargs):
meta_opt = self.param_groups[0]['meta_opt']
hidden = self.param_groups[0]['hidden']
prev_lr = self.param_groups[0]['prev_lr']
eta = self.param_groups[0]['eta']
# start differentiable & trainable update
zipped = zip(self.param_groups, grouped_grads)
lr = 0.1*eta(prev_lr).view(1)
p_tot = 0
new_params = []
for group_idx, (group, grads) in enumerate(zipped):
for p_idx, (p, g) in enumerate(zip(group['params'], grads)):
if g is None:
continue
#group['params'][p_idx] = _add(p, -group['lr'], g)
p_new = p + lr*g
group['params'][p_idx] = p_new
p_tot += p_new.sum()
#make_dot(p_new.sum()).render('p_new',format='png')
#print()
new_params.append( p_new )
# fake returns
self.param_groups[0]['prev_lr'] = lr
#
self._fmodel.update_params(new_params)
#x = torch.randn([4,3,32,32])
#y = self._fmodel(x)
#y.sum().backward()
#p_tot.backward()
#print(f'==> hidden.grad = {hidden.grad}')
#print(f'==> eta.fc.weight.grad = {eta.fc.weight.grad}')
#meta_opt.zero_grad()
# print(f'==> hidden.grad = {hidden.grad}')
# print(f'==> eta.fc.weight.grad = {eta.fc.weight.grad}')
#print()
#new_params = group['params']
#return new_params
#make_dot(p_tot).render('p_tot_inside_update', format='png')
return p_tot, new_params
def my_step(
self,
loss,
params = None,
override = None,
grad_callback = None,
eta=None,
**kwargs
):
print('---------> IN MY .step(loss)')
eta = self.param_groups[0]['eta']
hidden = self.param_groups[0]['hidden']
#st()
#eta = eta[0]
# Deal with override
# if override is not None:
# self._apply_override(override)
# if self._fmodel is None or self._fmodel.fast_params is None:
# if params is None:
# raise ValueError(
# "params kwarg must be passed to step if the differentiable "
# "optimizer doesn't have a view on a patched model with "
# "params."
# )
# else:
# params = self._fmodel.fast_params if params is None else params
#params = self._fmodel.fast_params if params is None else params
params = self._fmodel.fast_params
params = list(params)
# This allows us to gracefully deal with cases where params are frozen.
# grad_targets = [
# p if p.requires_grad else torch.tensor([], requires_grad=True)
# for p in params
# ]
grad_targets = params
# all_grads = torch.autograd.grad(
# loss,
# grad_targets,
# create_graph=self._track_higher_grads,
# allow_unused=True # boo
# )
all_grads = torch.autograd.grad(
loss,
grad_targets
)
# if grad_callback is not None:
# all_grads = grad_callback(all_grads)
# elif self._grad_callback is not None:
# all_grads = self._grad_callback(all_grads)
grouped_grads = []
for group, mapping in zip(self.param_groups, self._group_to_param_list):
grads = []
for i, index in enumerate(mapping):
#group['params'][i] = params[index].T
group['params'][i] = params[index]
grads.append(all_grads[index])
grouped_grads.append(grads)
#new_params = self._update(grouped_grads)
p_tot, new_params = self._update(grouped_grads)
#make_dot(p_tot).render('p_tot_inside_step', format='png')
# p_tot = 0
# for p in params[:]:
# p_tot += p.sum()
#p_tot.backward()
#new_params = params[:]
#st()
# print(f'self._track_higher_grads = {self._track_higher_grads}')
# for group, mapping in zip(self.param_groups, self._group_to_param_list):
# for p, index in zip(group['params'], mapping):
# if self._track_higher_grads:
# new_params[index] = p
# else:
# new_params[index] = p.detach().requires_grad_()
# p_tot = 0
# for p in new_params:
# p_tot += p.sum()
#p_tot.backward()
# if self._fmodel is not None:
# self._fmodel.update_params(new_params)
# print()
# st()
# set_attr(self._fmodel, names, val)
# del_attr(self._fmodel, names)
#self._fmodel.update_params(new_params)
# x = torch.randn([4,3,32,32])
# y = self._fmodel(x)
# y.sum().backward()
#print(f'p_tot = {p_tot}')
#print(p_tot)
# print(f'+++>>> hidden.grad = {hidden.grad}')
# print(f'+++>>> eta.fc.grad = {eta.fc.weight.grad}')
#return new_params
#st()
return self._fmodel
higher.register_optim(MySGD, TrainableSGD)
def main():
# get dataloaders
trainloader, testloader = get_cifar10()
criterion = nn.CrossEntropyLoss()
child_model = nn.Sequential(OrderedDict([
('conv1', nn.Conv2d(in_channels=3,out_channels=2,kernel_size=5,bias=False)),
('relu1', nn.ReLU()),
('Flatten', Flatten()),
('fc', nn.Linear(in_features=28*28*2,out_features=10,bias=False) )
]))
#child_model.conv1.weight.requires_grad = False
#child_model.fc.weight.requires_grad = False
hidden = torch.randn(size=(1,1),requires_grad=True)
print(f'-> hidden = {hidden}')
eta = nn.Sequential(OrderedDict([
('fc', nn.Linear(1,1,bias=False)),
('sigmoid', nn.Sigmoid())
]))
#meta_params = itertools.chain(child_model.parameters(),eta.parameters(),[hidden])
meta_params = itertools.chain(eta.parameters(),[hidden])
meta_opt = torch.optim.Adam(meta_params, lr=1e-3)
inner_opt = MySGD(child_model.parameters(), eta=eta, prev_lr=hidden, hidden=hidden, meta_opt=meta_opt)
# do meta-training/outer training minimize outerloop: min_{theta} sum_t L^val( theta^{T} - eta* Grad L^train(theta^{T}) )
print()
nb_outer_steps = 1 # note, in this case it's the same as the number of meta-train steps (but it might not be, depending on how you loop through the val set)
for outer_i, (outer_inputs, outer_targets) in enumerate(testloader, 0):
meta_opt.zero_grad()
if outer_i >= nb_outer_steps:
break
# do inner-training/MAML; minimize innerloop: theta^{T} - eta * Grad L^train(theta^{T}) ~ argmin L^train(theta)
nb_inner_steps = 5
with higher.innerloop_ctx(child_model, inner_opt, copy_initial_weights=False) as (fmodel, diffopt):
#with higher.innerloop_ctx(child_model, inner_opt) as (fmodel, diffopt):
for inner_i, (inner_inputs, inner_targets) in enumerate(trainloader, 0):
meta_opt.zero_grad()
if inner_i >= nb_inner_steps:
break
logits = fmodel(inner_inputs)
print(type(fmodel))
print(fmodel)
inner_loss = criterion(logits, inner_targets)
print(f'--> inner_i = {inner_i}')
print(f'inner_loss^<{inner_i}>: {inner_loss}')
print(f'lr^<{inner_i-1}> = {diffopt.param_groups[0]["prev_lr"]}')
p_tot = sum([ p.sum() for p in fmodel.parameters() ])
#print(f'> p_tot = {p_tot}')
print('\n----- DEBUGGING print statements after this line ----')
#diffopt.step(inner_loss,eta=eta)
#diffopt.my_step(inner_loss, eta=eta)
#fmodel = diffopt.my_step(inner_loss)
diffopt.my_step(inner_loss)
#step(diffopt, inner_loss,eta=eta)
#x = torch.randn([4,3,32,32])
#y = fmodel(x)
#y = diffopt._fmodel(x)
#y.sum().backward()
#sys.exit()
#new_params = diffopt.step(inner_loss, eta=eta) # changes params P[t+1] using P[t] and loss[t] in a differentiable manner
print(f'lr^<{inner_i}> = {diffopt.param_groups[0]["prev_lr"]}')
#p = list(fmodel.parameters())[1]
#params = list(child_model.parameters())+list(eta.parameters())+list([hidden])
#c_params = list(child_model.parameters())
#params = {'child_model[0]':c_params[0], 'child_model[1]':c_params[1], 'eta_params':eta.fc.weight, 'hidden':hidden }
# make_dot(p.sum(),params=params).render(
# filename='param_sum_inner_loop',
# format='png'
# )
# make_dot(p.sum()).render(
# filename='param_sum_inner_loop_no_names_5',
# format='png'
# )
##
#p_tot_new = sum([ p.sum() for p in new_params ])
#p_tot = sum([ p.sum() for p in fmodel.parameters() ])
#print(f'> p_tot = {p_tot}')
#print(f'same?: {p_tot_new == p_tot}')
#p_tot.backward()
#p_tot.backward()
#p.sum().backward()
eta = diffopt.param_groups[0]['eta']
hidden = diffopt.param_groups[0]['hidden']
#h = hidden
#print(f'hidden is hidden2 = {hidden is hidden2}')
#print(f'hidden is h = {hidden is h}')
print(f'===> hidden.grad = {hidden.grad}')
print(f'===> eta.fc.weight.grad = {eta.fc.weight.grad}')
#print()
#st()
# compute the meta-loss L^val( theta^{T} - eta* Grad L^train(theta^{T}) )
print()
#p_tot = sum([ p.sum() for p in fmodel.parameters() ])
#p_tot.backward()
#print(f'eta.fc.weight.grad = {eta.fc.weight.grad}')
outer_outputs = fmodel(outer_inputs)
meta_loss = criterion(outer_outputs, outer_targets) # L^val
#meta_loss = meta_loss + inner_loss
#make_dot(meta_loss).render('meta_loss',format='png')
meta_loss.backward()
#grad_of_grads = torch.autograd.grad(outputs=meta_loss, inputs=eta.parameters()) # dmeta_loss/dw0
print(f'----> outer_i = {outer_i}')
print(f'-> outer_loss/meta_loss^<{outer_i}>: {meta_loss}')
#print(f'child_model.fc.weight.grad = {child_model.fc.weight.grad}')
print(f'hidden.grad = {hidden.grad}')
print(f'eta.fc.weight = {eta.fc.weight.grad}')
meta_opt.step() # meta-optimizer step: more or less theta^<t> := theta^<t> - meta_eta * Grad L^val( theta^{T} - eta* Grad L^train(theta^{T}) )
if __name__ == "__main__":
main()
print('---> Done\a')
Commenting out this line https://github.com/facebookresearch/higher/blob/8f0716fb1663218324c02dabdba26b639959cfb6/higher/optim.py#L101
removes the bug where my meta-parameters (learning rate) were not being updated, but it now introduces new issues... why does it have to be a deep copy?
Notice that because things are deep-copied, if those parameters are given to the meta-optimizer, then the meta-optimizer will not update the right version of the optimization NNs (in fact they won't have gradients, since they aren't even pointing to the same object...).
Trying to find a workaround, it seems that re-assigning the param_groups list does the trick... though a less messy solution would be nice.
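The stale-reference problem and the re-assignment workaround can be sketched in plain Python. Everything here is a toy under stated assumptions: ToyParam and ToyOptimizer are hypothetical stand-ins written for this illustration, not torch.optim or higher classes, and the deep copy is simulated by hand:

```python
import copy


class ToyParam:
    """Stand-in for a trainable tensor (hypothetical, not a torch type)."""

    def __init__(self, value):
        self.value = value
        self.grad = None


class ToyOptimizer:
    """Minimal SGD-like optimizer over ToyParam objects (hypothetical)."""

    def __init__(self, params, lr):
        self.param_groups = [{"params": list(params), "lr": lr}]

    def step(self):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue  # nothing was backpropagated into this object
                p.value -= group["lr"] * p.grad


original = ToyParam(1.0)
meta_opt = ToyOptimizer([original], lr=0.5)

# Simulate a library deep-copying the parameter and using only the copy.
working_copy = copy.deepcopy(original)
working_copy.grad = 2.0

# The optimizer still holds `original`, whose grad is None, so nothing moves.
meta_opt.step()
value_before_rebind = working_copy.value

# Workaround: point the optimizer at the objects that actually receive gradients.
meta_opt.param_groups = [{"params": [working_copy], "lr": 0.5}]
meta_opt.step()
value_after_rebind = working_copy.value  # 1.0 - 0.5 * 2.0
```

The first step is a silent no-op rather than an error, which matches how hard this was to spot: nothing crashes, the meta-parameters just never change.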
Thanks for the detailed comments. You've given me a lot to look at so it's probably very helpful. I would encourage you to put stack traces in comments when you see crashes too as this helps debug things.
Just to manage expectations: I'm currently the only person supporting this project internally and have other responsibilities so cannot guarantee I can resolve this immediately (as there's a more pressing issue with second order backprop being broken for higher with pytorch v1.4 and also some memory leak issues). But I'll get to it as soon as I can!
No worries, I understand. I think I got a temporary fix that will work for what I want to do.
I am happy to answer questions if you need help.
I will try to summarize the issue so that it's easier for you to go through this once you decide to fix it (and post my temporary solution, as it could show how I avoided the issue).
That would be super helpful. Thanks for understanding, and for all the detailed comments.
Part of my current solution is to define this function (for updating the meta-optimizer so that it gets the new copies of the parameters that higher deep-copied):
def load_new_params(self, params):
    self.param_groups = []
    param_groups = list(params)
    if len(param_groups) == 0:
        raise ValueError("optimizer got an empty parameter list")
    if not isinstance(param_groups[0], dict):
        param_groups = [{'params': param_groups}]
    for param_group in param_groups:
        self.add_param_group(param_group)
but it's a copy-paste of the original __init__ method... which also has something called self.state, and I am unsure what it's for or whether skipping it breaks my optimizer in some unknown way...
Perhaps I should create a new param_groups list and load it with:
self.__setstate__({'state': state, 'param_groups': param_groups})
related:
Ok, it seems this works for me for now (unless there is a subtle bug from higher or optim that I am not aware of... e.g. how self.state works or something else):
import torch
import torch.nn as nn
from torch.optim.optimizer import Optimizer, required
import higher
from higher.optim import DifferentiableOptimizer
from higher.optim import DifferentiableSGD
import torchvision
import torchvision.transforms as transforms
from torchviz import make_dot
import copy
import itertools
import sys
from collections import OrderedDict
from pdb import set_trace as st
#mini class to add a flatten layer to the ordered dictionary
class Flatten(nn.Module):
def forward(self, input):
'''
Note that input.size(0) is usually the batch size.
So what it does is that given any input with input.size(0) # of batches,
will flatten to be 1 * nb_elements.
'''
batch_size = input.size(0)
out = input.view(batch_size,-1)
return out # (batch_size, *size)
def get_cifar10():
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
shuffle=False, num_workers=2)
return trainloader, testloader
def load_new_params(optimizer, params):
    optimizer.param_groups = []
    param_groups = list(params)
    if len(param_groups) == 0:
        raise ValueError("optimizer got an empty parameter list")
    if not isinstance(param_groups[0], dict):
        param_groups = [{'params': param_groups}]
    for param_group in param_groups:
        optimizer.add_param_group(param_group)

class MySGD(Optimizer):

    def __init__(self, params, trainable_opt_params, trainable_opt_state):
        defaults = {'trainable_opt_params':trainable_opt_params, 'trainable_opt_state':trainable_opt_state}
        super().__init__(params, defaults)
class TrainableSGD(DifferentiableOptimizer):

    def _update(self, grouped_grads, **kwargs):
        prev_lr = self.param_groups[0]['trainable_opt_state']['prev_lr']
        eta = self.param_groups[0]['trainable_opt_params']['eta']
        # start differentiable & trainable update
        zipped = zip(self.param_groups, grouped_grads)
        lr = 0.01*eta(prev_lr).view(1)
        for group_idx, (group, grads) in enumerate(zipped):
            for p_idx, (p, g) in enumerate(zip(group['params'], grads)):
                if g is None:
                    continue
                p_new = p - lr*g
                group['params'][p_idx] = p_new
        # fake returns
        self.param_groups[0]['trainable_opt_state']['prev_lr'] = lr
        # update model
        new_params = self.param_groups[0]['params']
        new_params = self._track_higher_grads_for_new_params(new_params, self._track_higher_grads)
        self._fmodel.update_params(new_params)

    def my_step(
        self,
        loss,
        params = None,
        override = None,
        grad_callback = None,
        eta=None,
        **kwargs
    ):
        params = self._fmodel.fast_params
        params = list(params)
        grad_targets = params
        all_grads = torch.autograd.grad(
            loss,
            grad_targets,
            create_graph=self._track_higher_grads,
            allow_unused=True # boo
        )
        grouped_grads = []
        for group, mapping in zip(self.param_groups, self._group_to_param_list):
            grads = []
            for i, index in enumerate(mapping):
                #group['params'][i] = params[index].T
                group['params'][i] = params[index]
                grads.append(all_grads[index])
            grouped_grads.append(grads)
        self._update(grouped_grads)
        # WARNING DON'T UPDATE PARAMETERS IN STEP
        return self._fmodel

    def _track_higher_grads_for_new_params(self, new_params, track_higher_grads):
        '''
        For the new params, set whether we are tracking higher-order grads for them or detaching them from the computation graph.
        '''
        for group, mapping in zip(self.param_groups, self._group_to_param_list):
            for p, index in zip(group['params'], mapping):
                if track_higher_grads:
                    new_params[index] = p
                else:
                    new_params[index] = p.detach().requires_grad_()
        return new_params
higher.register_optim(MySGD, TrainableSGD)

def main():
    # get dataloaders
    trainloader, testloader = get_cifar10()
    criterion = nn.CrossEntropyLoss()
    child_model = nn.Sequential(OrderedDict([
        ('conv1', nn.Conv2d(in_channels=3,out_channels=2,kernel_size=5,bias=False)),
        ('relu1', nn.ReLU()),
        ('Flatten', Flatten()),
        ('fc', nn.Linear(in_features=28*28*2,out_features=10,bias=False) )
    ]))
    hidden = torch.randn(size=(1,1),requires_grad=True)
    print(f'-> hidden = {hidden}')
    eta = nn.Sequential(OrderedDict([
        ('fc', nn.Linear(1,1,bias=False)),
        ('sigmoid', nn.Sigmoid())
    ]))
    lr = 0.01
    meta_params = []
    meta_params.append( {'params': hidden, 'lr':lr} )
    meta_params.append( {'params': eta.parameters(), 'lr':lr} )
    #meta_opt = torch.optim.SGD(meta_params)
    meta_opt = torch.optim.Adam(meta_params)
    # do meta-training/outer training minimize outerloop: min_{theta} sum_t L^val( theta^{T} - eta* Grad L^train(theta^{T}) )
    nb_outer_steps = 3 # note, in this case it's the same as the number of meta-train steps (but it might not be, depending on how you loop through the val set)
    for outer_i, (outer_inputs, outer_targets) in enumerate(testloader, 0):
        meta_opt.zero_grad()
        if outer_i >= nb_outer_steps:
            break
        # do inner-training: ~ argmin L^train(theta)
        nb_inner_steps = 3
        trainable_opt_params = {'eta':eta, 'hidden':hidden}
        trainable_opt_state = {'prev_lr':hidden}
        inner_opt = MySGD(child_model.parameters(), trainable_opt_params=trainable_opt_params, trainable_opt_state=trainable_opt_state)
        print('==== Inner Loop ====')
        with higher.innerloop_ctx(child_model, inner_opt, copy_initial_weights=False) as (fmodel, diffopt):
            for inner_i, (inner_inputs, inner_targets) in enumerate(trainloader, 0):
                meta_opt.zero_grad()
                if inner_i >= nb_inner_steps:
                    break
                print(f'-> inner_i = {inner_i}')
                print(f'hidden^<{inner_i}> = {hidden}')
                #print(f'eta.fc.weight^<{inner_i}> = {eta.fc.weight}')
                logits = fmodel(inner_inputs)
                inner_loss = criterion(logits, inner_targets)
                print(f'lr^<{inner_i-1}> = {diffopt.param_groups[0]["trainable_opt_state"]["prev_lr"]}')
                diffopt.my_step(inner_loss)
                print(f'lr^<{inner_i}> = {diffopt.param_groups[0]["trainable_opt_state"]["prev_lr"]}')
                eta = diffopt.param_groups[0]['trainable_opt_params']['eta']
                hidden = diffopt.param_groups[0]['trainable_opt_params']['hidden']
                print(f'hidden^<{inner_i}> = {hidden}')
            # compute the meta-loss L^val( theta^{T} - eta* Grad L^train(theta^{T}) )
            outer_outputs = fmodel(outer_inputs)
            meta_loss = criterion(outer_outputs, outer_targets) # L^val
            meta_loss.backward()
            #grad_of_grads = torch.autograd.grad(outputs=meta_loss, inputs=eta.parameters()) # dmeta_loss/dw0
            print('\n---- Outer loop print statements ----')
            print(f'----> outer_i = {outer_i}')
            print(f'-> outer_loss/meta_loss^<{outer_i}>: {meta_loss}')
            #print(f'child_model.fc.weight.grad = {child_model.fc.weight.grad}')
            print(f'hidden.grad = {hidden.grad}')
            assert hidden.grad is not None
            print(f'eta.fc.weight.grad = {eta.fc.weight.grad}')
            print(f'> hidden^<{outer_i-1}> = {hidden}') # before update
            print(f'> eta.fc.weight^<{outer_i-1}> = {eta.fc.weight.T}')
            # param_groups = meta_opt.param_groups
            # print(f'hidden == param_groups[0][params][0] : {hidden == param_groups[0]["params"][0]}')
            # print(f'hidden is param_groups[0][params][0] : {hidden is param_groups[0]["params"][0]}')
            # print(f'mutate meta_opt')
            #param_groups[0]['params'] = [hidden]
            new_meta_params = []
            new_meta_params.append( {'params': hidden, 'lr':lr} )
            new_meta_params.append( {'params': eta.parameters(), 'lr':lr} )
            # print(f'new_meta_params = {new_meta_params}')
            # print(f'param_groups = {param_groups}')
            load_new_params(meta_opt, new_meta_params)
            # print(f'meta_opt.param_groups = {meta_opt.param_groups}')
            # print(meta_opt.param_groups is param_groups)
            # print(param_groups[0]['params'][0] is meta_opt.param_groups[0]['params'][0])
            # print(hidden is param_groups[0]['params'][0])
            # print(hidden is meta_opt.param_groups[0]['params'][0])
            # param_groups[1]['params'] = eta.fc.weight
            # print(f'hidden == param_groups[0][params][0] : {hidden == param_groups[0]["params"][0]}')
            # print(f'hidden is param_groups[0][params][0] : {hidden is param_groups[0]["params"][0]}')
            meta_opt.step() # meta-optimizer step: more or less theta^<t> := theta^<t> - meta_eta * Grad L^val( theta^{T} - eta* Grad L^train(theta^{T}) )
            print(f'> hidden^<{outer_i}> = {meta_opt.param_groups[0]["params"][0]}') # after update
            print(f'> eta.fc.weight^<{outer_i-1}> = {eta.fc.weight.T}')
            # print(f'hidden == param_groups[0][params][0] : {hidden == param_groups[0]["params"][0]}')
            # print(f'hidden is param_groups[0][params] : {hidden is param_groups[0]["params"][0]}')
            #print(f'> eta.fc.weight^<{outer_i}> = {meta_opt.param_groups[1]["params"].T}')
            print()

if __name__ == "__main__":
    main()
    print('---> Done\a')
OK, I nearly got it to work. I removed the deep copy and the context manager, but now it complains that I'm trying to call backward twice on the same computation graph...
import torch
import torch.nn as nn
from torch.optim.optimizer import Optimizer, required
import higher
from higher.optim import DifferentiableOptimizer
from higher.optim import DifferentiableSGD
import torchvision
import torchvision.transforms as transforms
from torchviz import make_dot
import copy
import itertools
import sys
from collections import OrderedDict
from pdb import set_trace as st
# mini class to add a flatten layer to the ordered dictionary
class Flatten(nn.Module):
    def forward(self, input):
        '''
        Note that input.size(0) is usually the batch size.
        So what it does is that given any input with input.size(0) # of batches,
        will flatten to be 1 * nb_elements.
        '''
        batch_size = input.size(0)
        out = input.view(batch_size,-1)
        return out # (batch_size, *size)

def get_cifar10():
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                              shuffle=True, num_workers=2)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                             shuffle=False, num_workers=2)
    return trainloader, testloader

def load_new_params(optimizer, params):
    optimizer.param_groups = []
    param_groups = list(params)
    if len(param_groups) == 0:
        raise ValueError("optimizer got an empty parameter list")
    if not isinstance(param_groups[0], dict):
        param_groups = [{'params': param_groups}]
    for param_group in param_groups:
        optimizer.add_param_group(param_group)

def reload_param_groups(opt, params):
    if isinstance(params, torch.Tensor):
        raise TypeError("params argument given to the optimizer should be "
                        "an iterable of Tensors or dicts, but got " +
                        torch.typename(params))
    # replace params
    params = list(params)
    if isinstance(params[0], dict):
        raise ValueError(f'The hacked higher version does not support proper pytorch grouped params yet.')
    opt.param_groups[0]['params'] = params
    # opt.param_groups = []
    # param_groups = list(params)
    # if len(param_groups) == 0:
    #     raise ValueError("optimizer got an empty parameter list")
    # if not isinstance(param_groups[0], dict):
    #     param_groups = [{'params': param_groups}]
    # for param_group in param_groups:
    #     opt.add_param_group(param_group)

class MySGD(Optimizer):
    def __init__(self, params, trainable_opt_params, trainable_opt_state):
        defaults = {'trainable_opt_params':trainable_opt_params, 'trainable_opt_state':trainable_opt_state}
        super().__init__(params, defaults)

class TrainableSGD(DifferentiableOptimizer):
    def _update(self, grouped_grads, **kwargs):
        prev_lr = self.param_groups[0]['trainable_opt_state']['prev_lr']
        eta = self.param_groups[0]['trainable_opt_params']['eta']
        # start differentiable & trainable update
        zipped = zip(self.param_groups, grouped_grads)
        lr = 0.01*eta(prev_lr).view(1)
        for group_idx, (group, grads) in enumerate(zipped):
            for p_idx, (p, g) in enumerate(zip(group['params'], grads)):
                if g is None:
                    continue
                p_new = p - lr*g
                group['params'][p_idx] = p_new
        # fake returns
        self.param_groups[0]['trainable_opt_state']['prev_lr'] = lr
        # update model
        # new_params = self.param_groups[0]['params']
        # new_params = self._track_higher_grads_for_new_params(new_params, self._track_higher_grads)
        # self._fmodel.update_params(new_params)
# def my_step2(
# self,
# loss,
# params = None,
# override = None,
# grad_callback = None,
# eta=None,
# **kwargs
# ):
# # Deal with override
# if override is not None:
# self._apply_override(override)
# if self._fmodel is None or self._fmodel.fast_params is None:
# if params is None:
# raise ValueError(
# "params kwarg must be passed to step if the differentiable "
# "optimizer doesn't have a view on a patched model with "
# "params."
# )
# else:
# params = self._fmodel.fast_params if params is None else params
# params = list(params)
# # This allows us to gracefully deal with cases where params are frozen.
# grad_targets = [
# p if p.requires_grad else torch.tensor([], requires_grad=True)
# for p in params
# ]
# all_grads = torch.autograd.grad(
# loss,
# grad_targets,
# create_graph=self._track_higher_grads,
# allow_unused=True # boo
# )
# if grad_callback is not None:
# all_grads = grad_callback(all_grads)
# elif self._grad_callback is not None:
# all_grads = self._grad_callback(all_grads)
# grouped_grads = []
# for group, mapping in zip(self.param_groups, self._group_to_param_list):
# grads = []
# for i, index in enumerate(mapping):
# group['params'][i] = params[index]
# grads.append(all_grads[index])
# grouped_grads.append(grads)
# self._update(grouped_grads)
# # ---> WARNING DON'T UPDATE PARAMETERS IN STEP <---
# # the code bellow is now done inside of your _update function
# # new_params = params[:]
# # for group, mapping in zip(self.param_groups, self._group_to_param_list):
# # for p, index in zip(group['params'], mapping):
# # if self._track_higher_grads:
# # new_params[index] = p
# # else:
# # new_params[index] = p.detach().requires_grad_()
# # if self._fmodel is not None:
# # self._fmodel.update_params(new_params)
# return self._fmodel
# def _track_higher_grads_for_new_params(self, new_params, track_higher_grads):
# '''
# For the new params, set if we are tracking higher order grads for them or detaching them for the computation graph.
# '''
# for group, mapping in zip(self.param_groups, self._group_to_param_list):
# for p, index in zip(group['params'], mapping):
# if track_higher_grads:
# new_params[index] = p
# else:
# new_params[index] = p.detach().requires_grad_()
# return new_params
higher.register_optim(MySGD, TrainableSGD)

def main():
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    # get dataloaders
    trainloader, testloader = get_cifar10()
    criterion = nn.CrossEntropyLoss()
    # get trainable opt params
    hidden = torch.randn(size=(1,1),requires_grad=True)
    print(f'-> hidden = {hidden}')
    eta = nn.Sequential(OrderedDict([
        ('fc', nn.Linear(1,1,bias=False)),
        ('sigmoid', nn.Sigmoid())
    ]))
    lr = 0.01
    meta_params = []
    meta_params.append( {'params': hidden, 'lr':lr} )
    meta_params.append( {'params': eta.parameters(), 'lr':lr} )
    # get meta optimizer
    #meta_opt = torch.optim.SGD(meta_params)
    meta_opt = torch.optim.Adam(meta_params)
    #
    trainable_opt_params = {'eta':eta, 'hidden':hidden}
    trainable_opt_state = {'prev_lr':hidden}
    #inner_opt = MySGD(eta.parameters(), trainable_opt_params=trainable_opt_params, trainable_opt_state=trainable_opt_state)
    # diffopt = higher.optim.get_diff_optim(
    #     inner_opt,
    #     eta.parameters(), # for this hack it can be anything
    #     fmodel=None, # None
    #     device=device,
    #     override=None, # None default
    #     track_higher_grads=True # True default
    # )
    # do meta-training/ outerloop argmin L^val(theta)
    nb_outer_steps = 2 # note, in this case it's the same as the number of meta-train steps (but it might not be, depending on how you loop through the val set)
    for outer_i, (outer_inputs, outer_targets) in enumerate(testloader, 0):
        meta_opt.zero_grad()
        if outer_i >= nb_outer_steps:
            break
        # sample child_model
        child_model = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(in_channels=3,out_channels=2,kernel_size=5,bias=False)),
            ('relu1', nn.ReLU()),
            ('Flatten', Flatten()),
            ('fc', nn.Linear(in_features=28*28*2,out_features=10,bias=False) )
        ]))
        # do inner-training: ~ argmin L^train(psi)
        nb_inner_steps = 3
        print('==== Inner Loop ====')
        fmodel = higher.patch.monkeypatch(
            child_model,
            device,
            copy_initial_weights=True # True default
        )
        inner_opt = MySGD(child_model.parameters(), trainable_opt_params=trainable_opt_params, trainable_opt_state=trainable_opt_state)
        diffopt = higher.optim.get_diff_optim(
            inner_opt,
            child_model.parameters(), # for this hack it can be anything
            fmodel=fmodel, # None
            device=device,
            override=None, # None default
            track_higher_grads=True # True default
        )
        for inner_i, (inner_inputs, inner_targets) in enumerate(trainloader, 0):
            if inner_i >= nb_inner_steps:
                break
            print(f'-> outer_i = {outer_i}')
            print(f'-> inner_i = {inner_i}')
            print(f'hidden^<{inner_i}> = {hidden}')
            logits = fmodel(inner_inputs)
            inner_loss = criterion(logits, inner_targets)
            print(f'lr^<{inner_i-1}> = {diffopt.param_groups[0]["trainable_opt_state"]["prev_lr"]}')
            #child_model_params = [{'params':child_model.parameters()}]
            child_model_params = child_model.parameters()
            reload_param_groups(diffopt, child_model_params)
            diffopt._fmodel = fmodel
            diffopt.step(inner_loss)
            print(f'lr^<{inner_i}> = {diffopt.param_groups[0]["trainable_opt_state"]["prev_lr"]}')
            print(f'hidden^<{inner_i}> = {hidden}')
        # compute the meta-loss L^val( theta^{T} - eta* Grad L^train(theta^{T}) )
        outer_outputs = fmodel(outer_inputs)
        meta_loss = criterion(outer_outputs, outer_targets) # L^val
        #grad_of_grads = torch.autograd.grad(outputs=meta_loss, inputs=eta.parameters()) # dmeta_loss/dw0
        print('\n---- Outer loop print statements ----')
        print(f'----> outer_i = {outer_i}')
        print(f'-> outer_loss/meta_loss^<{outer_i}>: {meta_loss}')
        #print(f'child_model.fc.weight.grad = {child_model.fc.weight.grad}')
        meta_loss.backward()
        print(f'hidden.grad = {hidden.grad}')
        assert hidden.grad is not None
        print(f'eta.fc.weight.grad = {eta.fc.weight.grad}')
        print(f'> hidden^<{outer_i-1}> = {hidden}') # before update
        print(f'> eta.fc.weight^<{outer_i-1}> = {eta.fc.weight.T}')
        meta_opt.step() # meta-optimizer step: more or less theta^<t> := theta^<t> - meta_eta * Grad L^val( theta^{T} - eta* Grad L^train(theta^{T}) )
        print(f'>> hidden^<{outer_i}> = {meta_opt.param_groups[0]["params"][0]}') # after update
        print(f'>> eta.fc.weight^<{outer_i-1}> = {eta.fc.weight.T}')
        print()

if __name__ == "__main__":
    main()
    print('---> Done\a')
error:
Traceback (most recent call last):
File "trainaible_step_no_deep_copy.py", line 305, in <module>
main()
File "trainaible_step_no_deep_copy.py", line 293, in main
meta_loss.backward()
File "/Users/rene/miniconda3/envs/automl-meta-learning/lib/python3.7/site-packages/torch/tensor.py", line 195, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "/Users/rene/miniconda3/envs/automl-meta-learning/lib/python3.7/site-packages/torch/autograd/__init__.py", line 99, in backward
allow_unreachable=True) # allow_unreachable flag
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
I tried deleting the output node meta_loss
with:
del meta_loss
because I was told that removed the computation graph but it did not.
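For reference, here is a minimal sketch of one common way to trigger this error, which may or may not be the exact cause in the script above. A non-leaf tensor (similar to prev_lr, which gets overwritten with the output of eta) is carried from one iteration into the next, so the second backward pass walks into the part of the graph that the first one already freed. The names w, state and run are hypothetical. Deleting the loss does not help in this situation, because the carried tensor still references the old graph.

```python
import torch

w = torch.tensor(2.0, requires_grad=True)  # a trainable leaf tensor

def run(reset_state):
    """Run two backward passes; return the messages of any RuntimeErrors."""
    state = w
    errors = []
    for _ in range(2):
        if reset_state:
            state = w  # start each iteration from the leaf again
        # `state` becomes a non-leaf tensor whose history includes the old `state`.
        state = 0.5 * torch.sigmoid(state * w)
        loss = state ** 2
        try:
            loss.backward()  # frees the saved buffers of the graph it traverses
        except RuntimeError as e:
            errors.append(str(e))
    return errors

print(len(run(reset_state=False)))  # 1: the second backward hits the freed graph
print(len(run(reset_state=True)))   # 0: every iteration builds a fresh graph
```

Resetting the carried state to the leaf at the start of each outer iteration (or detaching it, if gradients through the old history are not needed) avoids the error in this toy case.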
So the issue is that this line of code of higher
self.param_groups = _copy.deepcopy(other.param_groups)
breaks the trainable step size I am trying to build.
I tried uncommenting it before but my code was still breaking.
After a lot of exploration, it seems that the code only works when I re-instantiate the inner optimizer and the differentiable optimizer before every inner loop (I think...)
import torch
import torch.nn as nn
from torch.optim.optimizer import Optimizer, required
import higher
from higher.optim import DifferentiableOptimizer
from higher.optim import DifferentiableSGD
import torchvision
import torchvision.transforms as transforms
from torchviz import make_dot
import copy
import itertools
import sys
from collections import OrderedDict
from pdb import set_trace as st
# mini class to add a flatten layer to the ordered dictionary
class Flatten(nn.Module):
    def forward(self, input):
        '''
        Note that input.size(0) is usually the batch size.
        So what it does is that given any input with input.size(0) # of batches,
        will flatten to be 1 * nb_elements.
        '''
        batch_size = input.size(0)
        out = input.view(batch_size,-1)
        return out # (batch_size, *size)

def get_cifar10():
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                              shuffle=True, num_workers=2)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                             shuffle=False, num_workers=2)
    return trainloader, testloader

class MySGD(Optimizer):
    def __init__(self, params, trainable_opt_params, trainable_opt_state):
        defaults = {'trainable_opt_params':trainable_opt_params, 'trainable_opt_state':trainable_opt_state}
        super().__init__(params, defaults)

class TrainableSGD(DifferentiableOptimizer):
    def _update(self, grouped_grads, **kwargs):
        prev_lr = self.param_groups[0]['trainable_opt_state']['prev_lr']
        eta = self.param_groups[0]['trainable_opt_params']['eta']
        # start differentiable & trainable update
        zipped = zip(self.param_groups, grouped_grads)
        lr = 0.01*eta(prev_lr).view(1)
        for group_idx, (group, grads) in enumerate(zipped):
            for p_idx, (p, g) in enumerate(zip(group['params'], grads)):
                if g is None:
                    continue
                p_new = p - lr*g
                group['params'][p_idx] = p_new
        # fake returns
        self.param_groups[0]['trainable_opt_state']['prev_lr'] = lr

higher.register_optim(MySGD, TrainableSGD)

def main():
    # get dataloaders
    trainloader, testloader = get_cifar10()
    criterion = nn.CrossEntropyLoss()
    hidden = torch.randn(size=(1,1),requires_grad=True)
    print(f'-> hidden = {hidden}')
    eta = nn.Sequential(OrderedDict([
        ('fc', nn.Linear(1,1,bias=False)),
        ('sigmoid', nn.Sigmoid())
    ]))
    lr = 0.01
    meta_params = []
    meta_params.append( {'params': hidden, 'lr':lr} )
    meta_params.append( {'params': eta.parameters(), 'lr':lr} )
    #meta_opt = torch.optim.SGD(meta_params)
    meta_opt = torch.optim.Adam(meta_params)
    # do meta-training/outer training minimize outerloop: min_{theta} sum_t L^val( theta^{T} - eta* Grad L^train(theta^{T}) )
    nb_outer_steps = 5 # note, in this case it's the same as the number of meta-train steps (but it might not be, depending on how you loop through the val set)
    for outer_i, (outer_inputs, outer_targets) in enumerate(testloader, 0):
        meta_opt.zero_grad()
        if outer_i >= nb_outer_steps:
            break
        #
        child_model = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(in_channels=3,out_channels=2,kernel_size=5,bias=False)),
            ('relu1', nn.ReLU()),
            ('Flatten', Flatten()),
            ('fc', nn.Linear(in_features=28*28*2,out_features=10,bias=False) )
        ]))
        # do inner-training: ~ argmin L^train(theta)
        nb_inner_steps = 3
        trainable_opt_params = {'eta':eta, 'hidden':hidden}
        trainable_opt_state = {'prev_lr':hidden}
        child_model_params = [{'params':child_model.parameters()}]
        inner_opt = MySGD(child_model_params, trainable_opt_params=trainable_opt_params, trainable_opt_state=trainable_opt_state)
        print('==== Inner Loop ====')
        with higher.innerloop_ctx(child_model, inner_opt, copy_initial_weights=False) as (fmodel, diffopt):
            for inner_i, (inner_inputs, inner_targets) in enumerate(trainloader, 0):
                if inner_i >= nb_inner_steps:
                    break
                print(f'-> outer_i = {outer_i}')
                print(f'-> inner_i = {inner_i}')
                print(f'hidden^<{inner_i}> = {hidden}')
                logits = fmodel(inner_inputs)
                inner_loss = criterion(logits, inner_targets)
                print(f'lr^<{inner_i-1}> = {diffopt.param_groups[0]["trainable_opt_state"]["prev_lr"]}')
                diffopt.step(inner_loss)
                print(f'lr^<{inner_i}> = {diffopt.param_groups[0]["trainable_opt_state"]["prev_lr"]}')
                print(f'hidden^<{inner_i}> = {hidden}')
            # compute the meta-loss L^val( theta^{T} - eta* Grad L^train(theta^{T}) )
            outer_outputs = fmodel(outer_inputs)
            meta_loss = criterion(outer_outputs, outer_targets) # L^val
            meta_loss.backward()
            #grad_of_grads = torch.autograd.grad(outputs=meta_loss, inputs=eta.parameters()) # dmeta_loss/dw0
            print('\n---- Outer loop print statements ----')
            print(f'----> outer_i = {outer_i}')
            print(f'-> outer_loss/meta_loss^<{outer_i}>: {meta_loss}')
            #print(f'child_model.fc.weight.grad = {child_model.fc.weight.grad}')
            print(f'hidden.grad = {hidden.grad}')
            assert hidden.grad is not None
            assert eta.fc.weight is not None
            print(f'eta.fc.weight.grad = {eta.fc.weight.grad}')
            print(f'> hidden^<{outer_i-1}> = {hidden}') # before update
            print(f'> eta.fc.weight^<{outer_i-1}> = {eta.fc.weight.T}')
            meta_opt.step() # meta-optimizer step: more or less theta^<t> := theta^<t> - meta_eta * Grad L^val( theta^{T} - eta* Grad L^train(theta^{T}) )
            print(f'>> hidden^<{outer_i}> = {meta_opt.param_groups[0]["params"][0]}') # after update
            print(f'>> eta.fc.weight^<{outer_i-1}> = {eta.fc.weight.T}')
            print()

if __name__ == "__main__":
    main()
    print('---> Done\a')
My main feedback from looking at your code is that you are not using the differentiable optimizers as intended. The way you should be doing it is:
1. Define TrainableSGD (imagine it is just this class with frozen parameters), which subclasses torch.optim.Optimizer.
2. Define a DifferentiableOptimizer version of this class.
3. Read about the override kwarg here: https://higher.readthedocs.io/en/latest/optim.html
4. Use the override kwarg to pass your trainable parameters when constructing the differentiable optimizer (you can do this when using innerloop_ctx).

I'm sorry if this answer seems like "you're using it wrong", but you're using it wrong. Please take a close look at @denisyarats' linked source code for a working example.
If this mode of usage does not fit your needs, that's another matter. If that's the case, please explain what you can't do with this way of doing things that you'd want to do. A significantly simpler minimal example than what's been provided so far might be helpful to this end.
In the meantime, closing this issue as it's not clear there's a bug underlying it.
@egrefen thanks for taking a look at my discussion.
So my issue can be resolved if this line of code is removed:
self.param_groups = _copy.deepcopy(other.param_groups)
and becomes:
self.param_groups = other.param_groups
Is there a reason why things have to be deep copied? Can higher function without it?
That line of code is important, as we want to safely branch off the state of the optimizer as used in the outer loop and return to it (or not touch it in the first place, which is what we do with the copy here) at the end of the unrolled inner loop.
Use override, please.
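To make the trade-off around the deep copy concrete, here is a minimal sketch in plain Python (no higher or PyTorch involved) of what a deep copy does to nested optimizer state. The Param class and the dictionary layout are hypothetical stand-ins for param_groups, chosen only to mirror the trainable_opt_state key used earlier in this thread; this is not higher's actual code.

```python
import copy

class Param:
    """Hypothetical stand-in for a tensor held in optimizer state."""
    def __init__(self, value):
        self.value = value

shared_lr = Param(0.01)
outer_groups = [{"params": [Param(1.0)],
                 "trainable_opt_state": {"prev_lr": shared_lr}}]

# Deep copy: the inner loop gets its own branch of the state.
branched = copy.deepcopy(outer_groups)
branched_lr = branched[0]["trainable_opt_state"]["prev_lr"]
branched_lr.value = 99.0  # an inner-loop mutation

# Plain assignment: both names point at the very same objects.
aliased = outer_groups
aliased_lr = aliased[0]["trainable_opt_state"]["prev_lr"]

print(branched_lr is shared_lr)  # False: the copy holds a different object
print(shared_lr.value)           # 0.01: the outer state was not touched
print(aliased_lr is shared_lr)   # True: without a copy, state is shared
```

The copy protects the outer optimizer state from inner-loop mutation, but it also means that an object placed in the optimizer's defaults is no longer the same object inside the differentiable optimizer, which is consistent with the behavior reported above.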
Override is a kwarg for differentiable optims (at creation, or step time, and you can also use it with the context manager) which allows you to use arbitrary tensors instead of values held in the optimizer state. For example, you could override the learning rate with a tensor which requires grad, which would allow you to unroll your loops, take gradient of the meta-loss with regard to the learning rate, and update this tensor.
See https://higher.readthedocs.io/en/latest/optim.html for details, https://github.com/facebookresearch/higher/issues/32#issuecomment-594466772 for a similar explanation, and https://github.com/denisyarats/densenet_cifar10 for an example.
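The mechanism described here (unroll the inner updates, then differentiate the meta-loss with respect to the learning rate) can be sketched in plain PyTorch without higher. This is only an illustration of the idea, not how higher implements override; the toy data, the mse helper and all variable names are hypothetical.

```python
import torch

torch.manual_seed(0)

# Toy regression data (hypothetical, for illustration only).
x_train, y_train = torch.randn(8, 3), torch.randn(8, 1)
x_val, y_val = torch.randn(8, 3), torch.randn(8, 1)

# Inner-loop weights and a learnable learning rate (the meta-parameter).
w0 = torch.zeros(3, 1, requires_grad=True)
lr = torch.tensor(0.1, requires_grad=True)
meta_opt = torch.optim.SGD([lr], lr=0.01)

def mse(w, x, y):
    return ((x @ w - y) ** 2).mean()

meta_opt.zero_grad()
w = w0
for _ in range(3):
    # create_graph=True keeps the gradient itself differentiable,
    # so the unrolled updates stay connected to `lr`.
    (g,) = torch.autograd.grad(mse(w, x_train, y_train), w, create_graph=True)
    w = w - lr * g  # out-of-place update: `w` is now a function of `lr`

meta_loss = mse(w, x_val, y_val)  # validation loss after the unrolled steps
meta_loss.backward()              # fills lr.grad with d(meta_loss)/d(lr)
lr_grad = lr.grad.clone()
meta_opt.step()                   # update the learning rate itself
print(f"d(meta_loss)/d(lr) = {lr_grad.item():.4f}")
```

With higher, the linked documentation describes passing such a tensor through the override kwarg instead of writing the update by hand.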
On Fri, Mar 6, 2020 at 5:26 AM brando90 notifications@github.com wrote:
Use override, please.
My apologies if this is a dense question, what does that mean?
Override is a kwarg for differentiable optims (at creation, or step time, and you can also use it with the context manager) which allows you to use arbitrary tensors instead of values held in the optimizer state. For example, you could override the learning rate with a tensor which requires grad, which would allow you to unroll your loops, take gradient of the meta-loss with regard to the learning rate, and update this tensor. See https://higher.readthedocs.io/en/latest/optim.html for details, #32 (comment) for a similar explanation, and https://github.com/denisyarats/densenet_cifar10 for an example. …
This is not what I want. I am not trying to train the learning rate. I'm trying to have the inner optimizer be parametrized, for example the way it's used in the meta-LSTM meta-learner: https://openreview.net/pdf?id=rJY0-Kcll
I will provide a minimal example that makes it easier to help me.
That line of code is important, as we want to safely branch off the state of the optimizer as used in the outer loop and return to it (or not touch it in the first place, which is what we do with the copy here) at the end of the unrolled inner loop.
can you explain what that means?
I suggest a full example that implements the optimizer, but a trainable step size could be a good example too...
https://discuss.pytorch.org/t/implement-a-meta-trainable-step-size/70396