Closed: kgarg8 closed this issue 4 years ago
It might not be clear what the `inner_loop` function does. Basically, at the start `params` and `meta_model.parameters()` are two identical tensors, but then `params` is updated by gradient descent (see the following line):

```python
params_history.append(optim(params_history[-1], hparams, create_graph=create_graph))
```

`params_history` holds the inner iterates, and only `params_history[0]` is equal to `hparams = meta_model.parameters()`. Thus, after the first inner iteration, `0.5 * self.reg_param * self.bias_reg_f(hparams, params)` is not zero.
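A minimal numeric sketch of this point, using a hypothetical scalar-toy inner loss and a stand-in `bias_reg` for `self.bias_reg_f` (names and values here are illustrative, not the library's code):

```python
import numpy as np

# hparams are the meta-parameters; the inner iterates start as an identical copy
hparams = np.array([1.0, -2.0])
params_history = [hparams.copy()]          # params_history[0] == hparams

def inner_grad(p):
    # gradient of a toy inner loss 0.5 * ||p||^2, i.e. grad = p (illustrative only)
    return p

lr = 0.1
# one inner gradient-descent step (the role played by `optim` in the snippet above)
params_history.append(params_history[-1] - lr * inner_grad(params_history[-1]))

def bias_reg(hp, p, reg_param=1.0):
    # 0.5 * reg_param * ||p - hp||^2, the iMAML-style proximal regularizer
    return 0.5 * reg_param * np.sum((p - hp) ** 2)

print(bias_reg(hparams, params_history[0]))  # 0.0 before any update
print(bias_reg(hparams, params_history[1]))  # > 0 after the first step
```

So the regularizer is zero only at iterate 0; as soon as `params` takes a gradient step away from `hparams`, the term becomes positive.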
Is it clearer now?
That part is clear now. However, I don't find a one-to-one correspondence to the original iMAML algorithm, for example with the implementation here. Can you please clarify how this hypergradient equation is used in iMAML? Thanks
Our framework computes approximate hypergradients of bilevel problems whose inner problem is defined by a fixed-point map (you can find more details in our paper). The implementation that you linked does exactly this using the fixed-point method.
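To make the fixed-point viewpoint concrete, here is a scalar sketch (all losses and constants are hypothetical, chosen so the exact hypergradient is known in closed form): the inner loss is L(w, lam) = 0.5*(w - lam)**2 + 0.5*c*w**2, the fixed-point map is one gradient-descent step on L, and the outer loss is E(w) = 0.5*(w - t)**2. The adjoint is obtained by iterating the transposed map, which is the fixed-point method in a nutshell:

```python
# Hypothetical scalar bilevel problem where the hypergradient is known exactly.
c, t, eta = 0.5, 3.0, 0.4
lam = 1.0

# Inner solve: iterate the gradient-descent map Phi(w, lam) = w - eta*((1+c)*w - lam).
w = 0.0
for _ in range(200):
    w = w - eta * ((1 + c) * w - lam)
# Closed-form fixed point: w* = lam / (1 + c).

dPhi_dw = 1 - eta * (1 + c)      # partial of the map w.r.t. w (contraction: |.| < 1)
dPhi_dlam = eta                  # partial of the map w.r.t. lam
grad_E = w - t                   # outer gradient evaluated at the fixed point

# Fixed-point iteration for the adjoint: v <- dPhi_dw * v + grad_E.
v = 0.0
for _ in range(200):
    v = dPhi_dw * v + grad_E
hypergrad = dPhi_dlam * v

exact = (w - t) / (1 + c)        # analytic d E(w*(lam)) / d lam, since w* = lam/(1+c)
print(hypergrad, exact)          # the two values agree
```

The same structure carries over to the vector case, where the scalar products become Jacobian-vector products computed with autograd.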
The single-task iMAML hypergradient is recovered when the fixed-point map is the one-step gradient-descent map with a constant step size of 1 applied to the iMAML single-task regularized training loss, and the outer loss is the single-task validation loss (see here). This single-task hypergradient can be approximated in several ways; iMAML uses the conjugate gradient method, which is also the default in our implementation. Since the validation loss that we define is already divided by the meta batch size, we simply accumulate the single-task hypergradients in the `.grad` tensor of the meta-parameters and use Adam for the meta-parameter updates.
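For intuition on the conjugate-gradient route: for a gradient-descent map on a (locally) quadratic loss with Hessian H and step eta, the system matrix I - dPhi/dw equals eta*H, which is symmetric positive definite, so CG can solve for the adjoint vector directly. A self-contained sketch with a made-up Hessian (not the library's code):

```python
import numpy as np

# Hypothetical quadratic inner loss with Hessian H; for the GD map with step eta,
# I - dPhi/dw = eta * H, which is symmetric positive definite, so CG applies.
rng = np.random.default_rng(0)
M = rng.standard_normal((4, 4))
H = M @ M.T + 4 * np.eye(4)        # symmetric positive definite Hessian
eta = 0.1
A = eta * H                        # (I - dPhi/dw) for the gradient-descent map
b = rng.standard_normal(4)         # outer gradient grad_w E at the fixed point

def conjugate_gradient(A, b, iters=50, tol=1e-10):
    # Textbook CG for symmetric positive definite A; solves A x = b.
    x = np.zeros_like(b)
    r = b - A @ x
    p = r.copy()
    rs = r @ r
    for _ in range(iters):
        Ap = A @ p
        alpha = rs / (p @ Ap)
        x += alpha * p
        r -= alpha * Ap
        rs_new = r @ r
        if np.sqrt(rs_new) < tol:
            break
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x

v = conjugate_gradient(A, b)
print(np.allclose(A @ v, b, atol=1e-8))   # CG solves the linear system
```

In practice the matrix-vector products `A @ p` are replaced by Hessian-vector products via autograd, so H is never materialized.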
Hope that this further clarifies the code.
Hi,
I was going through the `imaml.py` example code that you have given. I didn't understand the benefit of passing `meta_model.parameters()` and `params` as arguments to the `inner_loop` function here and here. Because they hold identical data, the regularization part in `train_loss_f` seems redundant, since it would always be zero:

```python
0.5 * self.reg_param * self.bias_reg_f(hparams, params)  # -> 0 always
```
Can you please explain?