AntreasAntoniou / HowToTrainYourMAMLPytorch

The original code for the paper "How to train your MAML" along with a replication of the original "Model Agnostic Meta Learning" (MAML) paper in Pytorch.
https://arxiv.org/abs/1810.09502

Getting a MAML++ instance. #36

Open brando90 opened 3 years ago

brando90 commented 3 years ago

I was wondering if it is possible to get a MAML++ instance that is differentiable. Example code I have in mind:

```python
meta_learner = MAMLpp(hyperparams)  # <---- THIS
for i in range(100000):
    train(meta_learner)
```

Is that possible?

E.g., it's simple to get a MAML instance with the `higher` library by making a normal SGD optimizer differentiable:

```python
inner_opt = NonDiffMAML(self.base_model.parameters(), lr=self.lr_inner)  # an ordinary (non-differentiable) optimizer for higher to wrap
...
for t in range(meta_batch_size):
    spt_x_t, spt_y_t, qry_x_t, qry_y_t = spt_x[t], spt_y[t], qry_x[t], qry_y[t]
    # Inner Loop Adaptation
    with higher.innerloop_ctx(self.base_model, inner_opt, copy_initial_weights=self.args.copy_initial_weights,
                              track_higher_grads=self.args.track_higher_grads) as (fmodel, diffopt):
        for i_inner in range(self.args.nb_inner_train_steps):
            fmodel.train()

            # base/child model forward pass on the support set
            spt_logits_t = fmodel(spt_x_t)
            inner_loss = self.args.criterion(spt_logits_t, spt_y_t)
            # inner_train_err = calc_error(mdl=fmodel, X=S_x, Y=S_y)  # for more advanced learners like meta-lstm

            # differentiable inner-opt update
            diffopt.step(inner_loss)
```
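
Concretely, I imagine something like the sketch below. Names such as `MAMLppSketch`, `n_inner_steps`, and `msl_weights` are placeholders of mine, not this repo's API: it reuses the same `higher.innerloop_ctx` pattern and adds two MAML++ ingredients, per-step learnable inner learning rates (LSLR, via `higher`'s per-param-group `lr` override) and a multi-step loss (MSL) over the query set. A rough sketch under those assumptions, not the official implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import higher


class MAMLppSketch(nn.Module):
    """Hypothetical wrapper (not this repo's API): a differentiable MAML++-style
    object that can be called directly inside an ordinary training loop."""

    def __init__(self, base_model: nn.Module, n_inner_steps: int = 5, init_inner_lr: float = 0.01):
        super().__init__()
        self.base_model = base_model
        self.n_inner_steps = n_inner_steps
        # LSLR: one learnable inner lr per inner step per parameter tensor.
        n_param_tensors = len(list(base_model.parameters()))
        self.inner_lrs = nn.Parameter(torch.full((n_inner_steps, n_param_tensors), init_inner_lr))
        # MSL: weights for the per-step query losses (MAML++ anneals these; uniform here).
        self.register_buffer("msl_weights", torch.full((n_inner_steps,), 1.0 / n_inner_steps))

    def forward(self, spt_x, spt_y, qry_x, qry_y):
        """Adapt on one task's support set and return the multi-step query loss."""
        # One param group per parameter tensor, so higher's per-group lr override
        # acts as a per-layer learning rate; the lr=1.0 here is just a placeholder.
        inner_opt = torch.optim.SGD(
            [{'params': [p]} for p in self.base_model.parameters()], lr=1.0)
        meta_loss = 0.0
        with higher.innerloop_ctx(self.base_model, inner_opt,
                                  copy_initial_weights=False,
                                  track_higher_grads=True) as (fmodel, diffopt):
            for step in range(self.n_inner_steps):
                spt_logits = fmodel(spt_x)
                inner_loss = F.cross_entropy(spt_logits, spt_y)
                # LSLR: override each param group's lr with this step's learnable
                # lrs, so the meta-gradient also flows into them.
                diffopt.step(inner_loss, override={'lr': list(self.inner_lrs[step])})
                # MSL: accumulate the query loss after every inner step.
                qry_logits = fmodel(qry_x)
                meta_loss = meta_loss + self.msl_weights[step] * F.cross_entropy(qry_logits, qry_y)
        return meta_loss
```

With something like that, the outer loop would reduce to the kind of call I had in mind (again, hypothetical names):

```python
# meta_learner = MAMLppSketch(base_model)            # <---- the object I'm after
# meta_opt = torch.optim.Adam(meta_learner.parameters(), lr=1e-3)
# for spt_x, spt_y, qry_x, qry_y in task_loader:     # one task per iteration, for brevity
#     meta_opt.zero_grad()
#     meta_learner(spt_x, spt_y, qry_x, qry_y).backward()
#     meta_opt.step()
```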