facebookresearch / higher

higher is a pytorch library allowing users to obtain higher order gradients over losses spanning training loops rather than individual training steps.
Apache License 2.0

Copy diffopt state to original optimizer #93

Open brjathu opened 3 years ago

brjathu commented 3 years ago

Hi, Thanks for making this awesome library.

I am working on a problem which requires copying the values/states of fmodel and diffopt back to the original model/optimizer.

I can do this for the model by simply copying the state_dict: model_phi.load_state_dict(fmodel.state_dict()).

However, I can't do this for the optimizer. Do you have any suggestions on how to do this?
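
For reference, a simplified sketch of the setup (loss_fn, inputs, and targets are placeholders for the actual training code):

import higher

with higher.innerloop_ctx(model_phi, opt) as (fmodel, diffopt):
    out = fmodel(inputs)
    diffopt.step(loss_fn(out, targets))

# Copying the model state back works:
model_phi.load_state_dict(fmodel.state_dict())
# ...but there is no analogous state_dict on diffopt to pass to opt.load_state_dict().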

egrefen commented 3 years ago

I'm sure we can work something out. Can you write a minimum "working" (minus the state copy) example?

murrman95 commented 3 years ago

If I may, here is my own example solution.

The following function ensures that data structures containing tensors only contain leaf nodes. This is necessary because copy.deepcopy() on tensors currently only works with leaf nodes, and torch.optim.Optimizer.load_state_dict (which deep-copies the incoming state) will run into problems without it.

import torch
from collections import defaultdict


def recursive_detach(x):
    # Recursively detach any tensors found in (possibly nested) dicts and lists,
    # so that the resulting structure contains only leaf tensors.
    if isinstance(x, torch.Tensor):
        return x.detach()
    elif isinstance(x, dict):
        return {k: recursive_detach(v) for k, v in x.items()}
    elif isinstance(x, list):
        return [recursive_detach(v) for v in x]
    else:
        return x
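
To illustrate why the detach matters (state and exp_avg below are just placeholder values): a non-leaf tensor inside a state dict cannot be deep-copied, while its detached counterpart can.

import copy

state = {'step': 1, 'exp_avg': torch.zeros(3, requires_grad=True) * 1.0}  # 'exp_avg' is a non-leaf tensor
detached = recursive_detach(state)
copy.deepcopy(detached)  # works
# copy.deepcopy(state) would raise: only leaf Tensors support the deepcopy protocol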

Now, this function repackages the diffopt's state in the format that torch.optim.Optimizer.load_state_dict expects:

def diffopt_state_dict(diffopt):
    # Map each parameter tensor to an integer index, exactly as
    # torch.optim.Optimizer.state_dict() does.
    param_mappings = {}
    start_index = 0

    def pack_group(group):
        nonlocal start_index
        packed = {k: v for k, v in group.items() if k != 'params'}
        param_mappings.update({id(p): i for i, p in enumerate(group['params'], start_index)
                               if id(p) not in param_mappings})
        packed['params'] = [param_mappings[id(p)] for p in group['params']]
        start_index += len(packed['params'])
        return packed

    param_groups = [pack_group(g) for g in diffopt.param_groups]

    # Collect the per-parameter state kept by the differentiable optimizer,
    # keyed by the parameter tensor itself.
    res = defaultdict(dict)
    for group_idx, group in enumerate(diffopt.param_groups):
        for p_idx, p in enumerate(group['params']):
            res[p] = {k: v for k, v in diffopt.state[group_idx][p_idx].items()}

    # Re-key the state by parameter index and detach every tensor, so that
    # torch.optim.Optimizer.load_state_dict can deep-copy it.
    packed_state = {(param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): recursive_detach(v)
                    for k, v in res.items()}
    return {
        'state': packed_state,
        'param_groups': param_groups,
    }

And now, here is how this works with the higher.innerloop_ctx context manager:

with higher.innerloop_ctx(model, opt) as (fmodel, diffopt):
    # Do some training here
    out = fmodel(inputs)

    inner_loss = loss_fn(out, targets)
    diffopt.step(inner_loss)

# Copy the differentiable optimizer's state back into the original optimizer.
opt.load_state_dict(diffopt_state_dict(diffopt))

I was testing the library by using innerloop_ctx without any meta-learning, and performance was poor compared to normal training with torch.optim.Adam, since the optimizer state was being reset at every step. With this trick, I can train models normally, as if I were using torch.optim.Adam directly.
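
A rough sketch of that pattern, reusing diffopt_state_dict from above (model, opt, loss_fn, and loader stand in for the actual training code):

for inputs, targets in loader:
    with higher.innerloop_ctx(model, opt) as (fmodel, diffopt):
        loss = loss_fn(fmodel(inputs), targets)
        diffopt.step(loss)
    # Carry both the fast weights and the optimizer state over to the next
    # iteration, instead of letting innerloop_ctx start from scratch each time.
    model.load_state_dict(fmodel.state_dict())
    opt.load_state_dict(diffopt_state_dict(diffopt))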

If you don't mind, @egrefen, could I add this to the library? It would just be a state_dict() method on diffopt's base class.
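
A purely hypothetical sketch of what that might look like (not higher's actual code; the rest of the class body is elided):

class DifferentiableOptimizer:
    ...

    def state_dict(self):
        # Return this optimizer's state in torch.optim.Optimizer's state_dict
        # format, reusing the repackaging logic above.
        return diffopt_state_dict(self)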