AntreasAntoniou / HowToTrainYourMAMLPytorch

The original code for the paper "How to train your MAML" along with a replication of the original "Model Agnostic Meta Learning" (MAML) paper in Pytorch.
https://arxiv.org/abs/1810.09502

Run Meta test #27

Open qianjiangcn opened 4 years ago

qianjiangcn commented 4 years ago

Hi, thank you for the code.

The provided scripts seem to be for training only. Is meta test also included in the scripts, or should I just set "evaluate on test set only" to true in the json file? Also, does "training phase" mean the epoch?

Thank you!

ruihuili commented 4 years ago

Hi @qianjiangcn,

I had the same question, but from reading the code my understanding is that meta test is indeed included in the scripts. Specifically, the def forward() function in few_shot_learning_system.py is where the meta test happens.

In fact, def forward() is the core method shared by the training, validation, and testing phases. It implements "a forward outer loop pass on the batch of tasks" (as per the author's comment in the function), iterating over a number of tasks. For each task, the function performs num_steps inner-loop updates on the support set (at each of those steps it may or may not apply the MAML++ 'multi-step loss optimisation', depending on the parameter settings), and then optimises on the target set for that task.
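To make the structure concrete, here is a minimal pure-Python sketch of that inner-loop-then-target-loss pattern. This is NOT the repo's code: the "model" is a single scalar parameter with a quadratic loss (so gradients are closed-form), and names like `num_steps`, `support_set`, `target_set`, and `use_multi_step_loss` mirror the discussion above rather than the actual API.

```python
def loss_and_grad(theta, data):
    """Mean squared error of scalar theta against data points, plus its gradient."""
    loss = sum((theta - x) ** 2 for x in data) / len(data)
    grad = sum(2 * (theta - x) for x in data) / len(data)
    return loss, grad


def forward_one_task(theta_init, support_set, target_set,
                     num_steps=5, inner_lr=0.1, use_multi_step_loss=False):
    """Inner loop on the support set, then evaluate on the target set.

    When use_multi_step_loss is True (the MAML++ 'multi-step loss optimisation'
    used during training), the task loss is a weighted sum of target losses
    taken after *every* inner step; otherwise only the loss after the final
    step counts, as in validation/testing.
    """
    theta = theta_init
    per_step_target_losses = []
    for _ in range(num_steps):
        _, g = loss_and_grad(theta, support_set)
        theta = theta - inner_lr * g                      # one inner-loop update
        step_loss, _ = loss_and_grad(theta, target_set)   # target loss after this step
        per_step_target_losses.append(step_loss)
    if use_multi_step_loss:
        # uniform weights for simplicity; MAML++ anneals these weights over epochs
        w = 1.0 / num_steps
        task_loss = sum(w * l for l in per_step_target_losses)
    else:
        task_loss = per_step_target_losses[-1]
    return task_loss, theta
```

In the real code the outer loop would sum `task_loss` over the batch of tasks and backpropagate through the inner-loop updates; the sketch only shows the per-task control flow.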

The training_phase argument is one of the flags that determine whether the 'multi-step loss optimisation' should happen; for both validation and testing it is set to false. This is, however, different from the training argument to def net_forward() (a call into the VGG classifier's forward()), which is always set to true in all three calls to def net_forward() inside def forward(). This seems complicated, but essentially, for meta-testing: after num_steps updates on the support set, the loss is computed on the target set, and the accuracy on the target set is used as the performance metric for a given model (among the top_n_models).
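The meta-test selection described above can be sketched as follows. This is a toy illustration, not the repo's implementation: `evaluate_model` is a hypothetical stand-in for a call into def forward() with training_phase set to false, returning target-set accuracy for one task.

```python
def meta_test(models, test_tasks, evaluate_model):
    """Evaluate each saved model on the test tasks; pick the best by mean accuracy.

    models: the top_n saved model checkpoints
    test_tasks: a list of meta-test tasks (support + target sets)
    evaluate_model(model, task) -> target-set accuracy for that task
    """
    mean_accs = []
    for model in models:
        accs = [evaluate_model(model, task) for task in test_tasks]
        mean_accs.append(sum(accs) / len(accs))
    best = max(range(len(models)), key=lambda i: mean_accs[i])
    return best, mean_accs
```

The point is only that meta-test reuses the same forward pass as validation, with the target-set accuracy serving as the model-selection metric.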

I could be wrong, so if there is any misinterpretation of the code in the above, it would be greatly appreciated if future readers of this thread or the author could kindly point it out.