brjathu opened this issue 3 years ago
I'm sure we can work something out. Can you write a minimum "working" (minus the state copy) example?
If I may contribute my own example solution.
The following helper ensures that data structures containing tensors only hold leaf nodes. This is necessary because copy.deepcopy() currently only works on leaf tensors, and torch.optim.Optimizer.load_state_dict will run into problems without it.
import torch

def recursive_detach(x):
    # Detach every tensor found in a (possibly nested) dict/list structure,
    # leaving non-tensor values untouched.
    if isinstance(x, torch.Tensor):
        return x.detach()
    elif isinstance(x, dict):
        return {k: recursive_detach(v) for k, v in x.items()}
    elif isinstance(x, list):
        return [recursive_detach(v) for v in x]
    else:
        return x
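For illustration, here is a small usage sketch of recursive_detach on a nested structure resembling optimizer state; the key names ('exp_avg', 'step', 'history') and shapes are made up for the example.

# Hypothetical nested state: a dict holding tensors, a scalar, and a list.
state = {'exp_avg': torch.zeros(3, requires_grad=True),
         'step': 1,
         'history': [torch.ones(3, requires_grad=True)]}

detached = recursive_detach(state)
assert not detached['exp_avg'].requires_grad      # tensors come back detached
assert not detached['history'][0].requires_grad   # lists are handled recursively
assert detached['step'] == 1                      # non-tensors pass through untouched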
Next, this function repackages the diffopt's state into the format that torch.optim.Optimizer expects.
from collections import defaultdict

def diffopt_state_dict(diffopt):
    # Mirror torch.optim.Optimizer.state_dict(): map each parameter to a
    # contiguous index and pack the per-parameter state under that index.
    param_mappings = {}
    start_index = 0

    def pack_group(group):
        nonlocal start_index
        packed = {k: v for k, v in group.items() if k != 'params'}
        param_mappings.update({id(p): i for i, p in enumerate(group['params'], start_index)
                               if id(p) not in param_mappings})
        packed['params'] = [param_mappings[id(p)] for p in group['params']]
        start_index += len(packed['params'])
        return packed

    res = defaultdict(dict)
    param_groups = [pack_group(g) for g in diffopt.param_groups]
    for group_idx, group in enumerate(diffopt.param_groups):
        for p_idx, p in enumerate(group['params']):
            res[p] = {k: v for k, v in diffopt.state[group_idx][p_idx].items()}
    packed_state = {(param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): recursive_detach(v)
                    for k, v in res.items()}
    return {
        'state': packed_state,
        'param_groups': param_groups,
    }
And here is how this works with the higher.innerloop_ctx context manager:
with higher.innerloop_ctx(model, opt) as (fmodel, diffopt):
    # Do some training here
    out = fmodel(inputs)
    inner_loss = loss_fn(out, targets)
    diffopt.step(inner_loss)
    # Copy the differentiable optimizer's state back into the original optimizer.
    opt.load_state_dict(diffopt_state_dict(diffopt))
I was testing the library by using innerloop_ctx without any meta-learning, and performance was poor compared to normal training with torch.optim.Adam because the optimizer state was being reset at every step. With this trick, I can train models normally, as if I were using torch.optim.Adam directly.
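For concreteness, here is a minimal sketch of that kind of plain (non-meta) training loop; model, loss_fn, and loader are assumed to exist, and only the state-copying pattern at the end of each step is the point.

import torch
import higher

opt = torch.optim.Adam(model.parameters(), lr=1e-3)

for inputs, targets in loader:
    with higher.innerloop_ctx(model, opt) as (fmodel, diffopt):
        out = fmodel(inputs)
        loss = loss_fn(out, targets)
        diffopt.step(loss)
        # Carry the updated weights and optimizer state over to the next
        # iteration instead of letting them be discarded with the context.
        model.load_state_dict(fmodel.state_dict())
        opt.load_state_dict(diffopt_state_dict(diffopt))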
If you don't mind, @egrefen, could I add this to the library? It would just be a state_dict() method on the diffopt base class.
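Purely as a sketch of what that could look like from the caller's side (this is not an existing higher API, just a wrapper around the diffopt_state_dict function above):

import types

def attach_state_dict(diffopt):
    # Bind diffopt_state_dict as a state_dict() method on this diffopt instance.
    diffopt.state_dict = types.MethodType(diffopt_state_dict, diffopt)
    return diffopt

# Usage inside the innerloop_ctx block:
#   opt.load_state_dict(attach_state_dict(diffopt).state_dict())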
Hi, thanks for making this awesome library.
I am working on a problem that requires copying the values/state of fmodel and diffopt back into the original model/optimizer.
I can do this for the model by simply copying the state_dict.
model_phi.load_state_dict(fmodel.state_dict())
However, I can't do this for the optimizer. Do you have any suggestions on how to do this?