Closed robinsongh381 closed 4 years ago
Hi, thanks for your interest in our work. In MAML, there are two gradient descent steps during training and one gradient descent step during testing.
Training: (i) on the train batch at line 150 to update the fast net, (ii) on the valid batch at line 170 to update the meta net.
Test: since our meta net is optimized to adapt quickly to new tasks (personas), at test time we need to do gradient descent for fast adaptation (a.k.a. few-shot learning), which is done at line 190.
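For intuition, the two training updates plus the test-time adaptation can be sketched as a first-order MAML loop. This is a minimal illustrative assumption, not the repo's actual code: the names `maml_step` and `fast_net` and the toy MSE task are hypothetical.

```python
import copy
import torch
import torch.nn as nn

def maml_step(meta_net, task_batches, inner_lr=0.1, meta_lr=0.01):
    """One meta-update over a batch of tasks (first-order MAML sketch)."""
    meta_opt = torch.optim.SGD(meta_net.parameters(), lr=meta_lr)
    meta_opt.zero_grad()
    for train_batch, valid_batch in task_batches:
        # (i) inner step: clone the meta weights into a fast net and
        #     adapt it on the task's train batch (learning rate alpha)
        fast_net = copy.deepcopy(meta_net)
        x, y = train_batch
        loss = nn.functional.mse_loss(fast_net(x), y)
        grads = torch.autograd.grad(loss, list(fast_net.parameters()))
        with torch.no_grad():
            for p, g in zip(fast_net.parameters(), grads):
                p -= inner_lr * g
        # (ii) outer step: valid-batch gradients of the adapted fast net
        #     are accumulated onto the meta net (first-order approximation,
        #     learning rate beta applied by meta_opt)
        xv, yv = valid_batch
        v_loss = nn.functional.mse_loss(fast_net(xv), yv)
        v_grads = torch.autograd.grad(v_loss, list(fast_net.parameters()))
        for p, g in zip(meta_net.parameters(), v_grads):
            p.grad = g if p.grad is None else p.grad + g
    meta_opt.step()

# At test time, the same inner step is run once more on the new persona's
# support dialogues before evaluating -- that is the "extra" update.
```

The design point is that the meta net is never updated directly on a train batch; it only receives the valid-batch gradients of the adapted fast nets.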
Thanks for reading our code, and feel free to let us know if you have any questions.
Didn't expect such a quick reply, but thanks for that!
In your paper, there is the following statement in Section 3, Experimental Setting:
(PAML) a meta-trained model as in Eq. (5), where we test each set D_pi ∈ D_test by selecting one dialogue and training with all the others. To elaborate, suppose we are testing U_t ∈ D_pi; then we first fine-tune using all the dialogues in D_pi\U_t and then test on U_t. This process is repeated for all the dialogues in D_pi.
Is this what you mean by
Since our meta net is optimized to adapt quickly to new tasks (personas), at test time we need to do gradient descent for fast adaptation (a.k.a. few-shot learning), which is done at line 190.
If so, I was wondering what the impact on the final model performance would be if I excluded the Meta-Evaluation part from PAML.py.
Finally, I don't fully understand the purpose of the reset (`meta_net.load_state_dict({name: weights_original[name] for name in weights_original})`) during training (line 156) and testing (line 194).
Thank you
Question: If so, I was wondering what the impact on the final model performance would be if I excluded the Meta-Evaluation part from PAML.py. Answer: The MAML-trained model is an initialization that can easily adapt to a given task (persona). If we don't do the fast adaptation during testing, the model will not be personalized.
Question: Finally, I don't fully understand the purpose of the reset (`meta_net.load_state_dict({name: weights_original[name] for name in weights_original})`) during training (line 156) and testing (line 194). Answer: `weights_original` holds the weights of the meta model, which is optimized only by the second-order gradients (here we use the first-order approximation). The original MAML paper, https://arxiv.org/pdf/1703.03400.pdf, describes the concept in detail.
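The snapshot/restore pattern behind that reset might look like the following minimal sketch (a hypothetical stand-alone example, not the repo's code): the meta weights are snapshotted before a task's inner updates, and restored afterwards so the next task starts from the same meta initialization.

```python
import torch
import torch.nn as nn

meta_net = nn.Linear(3, 1)

# Snapshot the meta weights; clone so later in-place updates don't alias them.
weights_original = {name: w.clone() for name, w in meta_net.state_dict().items()}

# ... inner-loop (fast) updates mutate meta_net's parameters here;
#     we simulate that with an arbitrary in-place perturbation ...
with torch.no_grad():
    for p in meta_net.parameters():
        p.add_(1.0)

# Reset: restore the meta initialization before processing the next task.
meta_net.load_state_dict({name: weights_original[name] for name in weights_original})
```

Without the reset, the fast updates from one task would leak into the starting point of the next one.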
Ok Thanks !
If you do not mind, I would like to ask a few more questions.
Q1 What is the difference between the fast net and the meta net? How is each named in the paper?
Q2 In the original paper, there are two updates (one with learning rate alpha and the other with learning rate beta), which I think correspond to lines 150 and 170, respectively. I am still unsure about the update during test (line 190). Testing is "normally" for printing acc, ppl, loss, etc., rather than for taking extra training / weight-update steps. Also, the model has already been trained on various personas in lines 150 and 170. I would very much appreciate it if you could give a further explanation.
Thanks
@zlinao Any further kind reply would be extremely helpful
We would like to refer you to the MAML paper, https://arxiv.org/pdf/1703.03400.pdf, for a better understanding of why the extra update is needed at test time. You can consider learning a new persona as learning a new task.
Hello, thanks for the great work!
In MAML.py, the Meta-Evaluation starts at line 172. My question is whether this part has any impact on the performance of the model.

I think the answer is YES, because at line 190 the function `do_learning_fix_step` calls `model.train_one_batch` (line 81), which computes the gradient of the loss and updates the parameters via `optimizer.step()`.

However, the gradient and weight updates have already been performed (i) for the train batch at line 150 and (ii) for the valid batch at line 170, and hence I believe there should be no extra gradient or weight update.
Therefore I think the answer to my question should be NO, implying that there should be no further gradient updates during Meta-Evaluation. If your intention for Meta-Evaluation was to print the eval loss and ppl and also to save the model, then I think line 190 is inappropriate and needs to be fixed, e.g., by passing a `train=False` flag to the `train_one_batch` function, at least during Meta-Evaluation.

Thanks in advance
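For illustration, the `train=False` pattern suggested here could look like the following minimal sketch. This is a hypothetical stand-in, not the repo's actual `train_one_batch`: when `train=False`, the forward pass runs under `no_grad` and the optimizer is never stepped, so evaluation cannot mutate the weights.

```python
import torch
import torch.nn as nn

def train_one_batch(model, optimizer, batch, train=True):
    """Compute the batch loss; update weights only when train=True."""
    x, y = batch
    if train:
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()       # weights are updated only in training mode
    else:
        with torch.no_grad():  # evaluation: no gradients, no weight update
            loss = nn.functional.mse_loss(model(x), y)
    return loss.item()
```

Note that in MAML-style evaluation the test-time update is intentional (fast adaptation on the support dialogues), so such a flag would only apply where a pure read-only evaluation is wanted.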