prolearner / hypertorch

MIT License

imaml.py issue regarding inner_loop solver #2

Closed kgarg8 closed 4 years ago

kgarg8 commented 4 years ago

Hi,

I was going through the imaml.py example code that you provide. I didn't understand the benefit of passing meta_model.parameters() and params as arguments to the inner_loop function here and here, since they hold the same data. In fact, this seems to make the regularization term in train_loss_f redundant, because 0.5 * self.reg_param * self.bias_reg_f(hparams, params) would then always evaluate to zero. Can you please explain?

prolearner commented 4 years ago

It might not be clear what the inner_loop function does. At the start, params and meta_model.parameters() hold identical values, but params is then updated by gradient descent (see the following line).

params_history.append(optim(params_history[-1], hparams, create_graph=create_graph))

params_history holds the inner iterates and only params_history[0] is equal to hparams = meta_model.parameters().

Thus, after the first inner iteration, 0.5 * self.reg_param * self.bias_reg_f(hparams, params) is no longer zero.
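To make this concrete, here is a minimal NumPy sketch of the same pattern (not the repository's actual code; the toy task loss, reg_param value, and learning rate are illustrative): hparams stays fixed while the inner iterates in params_history drift away from it, so the bias regularizer starts at zero and becomes positive.

```python
import numpy as np

REG_PARAM = 0.1  # illustrative value for the regularization strength

def train_loss(params, hparams):
    # Toy single-task loss with optimum at 3, plus the iMAML-style
    # bias regularizer 0.5 * reg_param * ||params - hparams||^2.
    data_loss = 0.5 * np.sum((params - 3.0) ** 2)
    bias_reg = 0.5 * REG_PARAM * np.sum((params - hparams) ** 2)
    return data_loss + bias_reg

def grad_train_loss(params, hparams):
    # Gradient of the loss above with respect to params.
    return (params - 3.0) + REG_PARAM * (params - hparams)

hparams = np.zeros(2)              # meta-parameters, fixed in the inner loop
params_history = [hparams.copy()]  # params_history[0] equals hparams
lr = 0.5
for _ in range(10):
    p = params_history[-1]
    params_history.append(p - lr * grad_train_loss(p, hparams))

bias_reg_start = 0.5 * REG_PARAM * np.sum((params_history[0] - hparams) ** 2)
bias_reg_after = 0.5 * REG_PARAM * np.sum((params_history[1] - hparams) ** 2)
print(bias_reg_start, bias_reg_after)  # zero at the start, positive afterwards
```

In hypertorch the updates additionally keep the computational graph (create_graph=True) so that the hypergradient can later differentiate through, or around, these iterates.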

Is it clearer now?

kgarg8 commented 4 years ago

That part is clear now. However, I don't see a one-to-one correspondence with the original iMAML algorithm, for example the implementation here. Can you please clarify how this hypergradient equation is used in iMAML? Thanks

prolearner commented 4 years ago

Our framework computes approximate hypergradients for bilevel problems in which the inner problem is defined by a fixed-point map (you can find more details in our paper). The implementation that you linked does exactly this using the fixed-point method.

The single-task iMAML hypergradient is recovered when the fixed-point map is the one-step gradient descent map with a constant step size of 1, applied to iMAML's regularized single-task training loss, and the outer loss is the single-task validation loss (see here). This single-task hypergradient can be approximated in several ways; iMAML uses the conjugate gradient method, which is also the default in our implementation. Since the validation loss that we define is already divided by the meta batch size, we simply accumulate the single-task hypergradients in the .grad tensor of the meta-parameters and use Adam for the meta-parameter updates.
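As a hedged illustration of that conjugate-gradient approximation, here is a NumPy sketch on a quadratic inner problem, where the implicit hypergradient has a closed form we can check against. All names (A, b, c, reg_param, theta) are illustrative, not taken from the repository; in hypertorch the Hessian-vector products are of course computed with autograd rather than an explicit matrix.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(5, 3))
b = rng.normal(size=5)
c = rng.normal(size=3)       # target of the toy "validation" loss
reg_param = 2.0              # iMAML's lambda
theta = rng.normal(size=3)   # meta-parameters

# Inner problem: w*(theta) = argmin_w 0.5||Aw - b||^2 + 0.5*lambda*||w - theta||^2.
# Its Hessian is H = A^T A + lambda * I, and w* solves H w = A^T b + lambda*theta.
H = A.T @ A + reg_param * np.eye(3)
w_star = np.linalg.solve(H, A.T @ b + reg_param * theta)

# Outer (validation) loss 0.5||w - c||^2; its gradient at w* is:
g = w_star - c

def cg(mat, rhs, iters=50, tol=1e-10):
    # Plain conjugate gradient for the SPD system mat @ x = rhs.
    x = np.zeros_like(rhs)
    r = rhs - mat @ x
    p = r.copy()
    for _ in range(iters):
        rr = r @ r
        if rr < tol:
            break
        Ap = mat @ p
        alpha = rr / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        p = r + (r @ r / rr) * p
    return x

# For this quadratic, dw*/dtheta = lambda * H^{-1}, so the hypergradient is
# lambda * H^{-1} g; CG approximates the linear solve without inverting H.
hypergrad = reg_param * cg(H, g)
exact = reg_param * np.linalg.solve(H, g)
print(np.allclose(hypergrad, exact))  # → True
```

The same structure appears in the general case: CG only needs Hessian-vector products with the inner-loss Hessian, which autograd provides without ever materializing the matrix.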

Hope that this further clarifies the code.