facebookresearch / higher

higher is a PyTorch library allowing users to obtain higher-order gradients over losses spanning training loops rather than individual training steps.
Apache License 2.0

Add option to differentiable optimizers to treat buffers as constant #53

Open · creiser opened this issue 4 years ago

creiser commented 4 years ago

I am using a step size of 1 for my inner loop. Nevertheless, I want to use adaptive optimizers in the inner loop and carry the buffers across outer-loop iterations. In the context of e.g. MAML that would not make sense, but for other bi-level optimization problems it is useful. Mathematically, this means treating the buffers (e.g. the momentum vector) as constants. It can be implemented easily by writing a detached copy into the buffers. Ideally, one should have exact control over when the computational graph of the buffers is cut off.
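
Here is a minimal sketch of what I mean, in plain PyTorch (not higher's API; the loss, shapes, and hyperparameters are placeholders). Each outer iteration performs one differentiable inner SGD-with-momentum step, and only a detached copy of the buffer is carried over, so the graph never spans more than one inner step:

import torch

param = torch.randn(3, requires_grad=True)
buf = torch.zeros(3)  # momentum buffer, carried across outer iterations
lr, momentum = 0.1, 0.9

for outer_step in range(5):
    inner_loss = (param ** 2).sum()  # stand-in for the real inner loss
    grad, = torch.autograd.grad(inner_loss, param, create_graph=True)
    differentiable_buf = buf * momentum + grad  # part of this step's graph
    updated = param - lr * differentiable_buf   # differentiable inner step
    # ... compute the outer loss on `updated` and backprop through it here ...
    buf = differentiable_buf.detach()           # stored copy: graph is cut
    param = updated.detach().requires_grad_()   # "permanent" parameter update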

egrefen commented 4 years ago

Hello @creiser. Thanks for your suggestion. I'm afraid I'm being a little slow to understand precisely which buffers you refer to here. Could you please:

  1. give me a toy/concrete example, and also
  2. specify what condition you would like to hold at the end of the loop (or a diffopt.step call) and ideally also what actually happens?
  3. a proposal for how this feature would be accessible to the user is welcome (e.g. keyword arg when defining the loop? model? context? when using diffopt.step? all of the above?)

I'll try to get on it as soon as I have time.

creiser commented 4 years ago

Hello @egrefen,

Thanks for the quick reply.

> I'm afraid I'm being a little slow to understand precisely which buffers you refer to here.

By buffers I mean, for example, the momentum vector that needs to be stored for SGD with momentum; other optimizers have other such buffers or states. Currently you store a "differentiable" version of each of these buffers. That makes sense for the episodic training you need in MAML, where you start from a certain initialization and then "simulate" the inner-loop steps, but other bi-level algorithms update the parameters "permanently" during the inner loop instead of only "simulating" steps.
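
For reference, this toy snippet shows where that state lives in stock PyTorch once momentum is enabled (not higher code):

import torch

p = torch.nn.Parameter(torch.randn(3))
opt = torch.optim.SGD([p], lr=0.1, momentum=0.9)
p.sum().backward()
opt.step()
print(opt.state[p]["momentum_buffer"])  # the momentum vector for p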

> give me a toy/concrete example, and also

Below follows my code. As you can see, each buffer exists in two versions: a differentiable one and a non-differentiable one. The differentiable version is needed for the computational graph of the current inner-loop update, while the non-differentiable version is maintained to be treated as a constant in the inner-loop updates of subsequent outer-loop iterations.

The difference between my code and yours is that you also write the differentiable version into the buffer. This makes the computational graph grow with the number of outer-loop steps unless you reset the inner-loop optimizer, i.e. detach the buffers/states, or do what you are doing right now: destroy the inner-loop optimizer's buffers after each outer-loop iteration.

class DifferentiableSGD:
    """SGD whose update stays differentiable, while the stored momentum
    buffer is kept as a detached (constant) copy."""

    def __init__(self, parameters, lr, momentum=0, weight_decay=0, nesterov=False):
        self.parameters = parameters
        self.lr = lr
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.nesterov = nesterov
        if self.momentum != 0:
            self.momentum_buffer = [None for _ in parameters]

    def step(self, grads):
        updated_params = []
        for param_idx, (param, grad) in enumerate(zip(self.parameters, grads)):
            if self.weight_decay != 0:
                grad = grad + param * self.weight_decay
            if self.momentum != 0:
                if self.momentum_buffer[param_idx] is None:
                    # First step: store a detached copy as the constant buffer.
                    self.momentum_buffer[param_idx] = grad.detach().clone()
                    differentiable_buf = grad
                else:
                    buf = self.momentum_buffer[param_idx]
                    # The differentiable version enters the current update's
                    # graph; only its data is written back, without history.
                    differentiable_buf = buf * self.momentum + grad
                    buf.data = differentiable_buf
                if self.nesterov:
                    grad = grad + self.momentum * differentiable_buf
                else:
                    grad = differentiable_buf
            updated_params.append(param - self.lr * grad)
        return updated_params
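
And here is a sketch of how I use it in a bi-level loop. The losses and the outer variable meta_scale are placeholders; the point is that the update is "permanent" and the graph stays one inner step deep:

import torch

meta_scale = torch.nn.Parameter(torch.tensor(1.0))  # outer variable
inner_param = torch.randn(3, requires_grad=True)    # inner variable
inner_opt = DifferentiableSGD([inner_param], lr=0.1, momentum=0.9)
outer_opt = torch.optim.SGD([meta_scale], lr=0.01)

for outer_step in range(5):
    inner_loss = (meta_scale * inner_param ** 2).sum()
    grads = torch.autograd.grad(inner_loss, [inner_param], create_graph=True)
    updated, = inner_opt.step(grads)

    outer_loss = (updated ** 2).sum()  # placeholder outer objective
    outer_opt.zero_grad()
    outer_loss.backward()              # reaches meta_scale via the inner step
    outer_opt.step()

    # make the update permanent; buffers stay as detached constants
    inner_param = updated.detach().requires_grad_()
    inner_opt.parameters = [inner_param]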

> specify what condition you would like to hold at the end of the loop (or a diffopt.step call) and ideally also what actually happens?

There needs to be a detached version of the buffers/states in the inner-loop optimizers, which can be used for the inner-loop updates of subsequent outer-loop iterations.
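
Concretely, after every inner-loop step the stored buffers must carry no autograd history, even though a differentiable version of the same values was used in the update (illustrated with the momentum_buffer list from my class above):

for buf in inner_opt.momentum_buffer:
    assert buf is None or buf.grad_fn is None  # stored copy is graph-free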

> a proposal for how this feature would be accessible to the user is welcome (e.g. keyword arg when defining the loop? model? context? when using diffopt.step? all of the above?)

Will think about it.