AntreasAntoniou / HowToTrainYourMAMLPytorch

The original code for the paper "How to train your MAML" along with a replication of the original "Model Agnostic Meta Learning" (MAML) paper in Pytorch.
https://arxiv.org/abs/1810.09502

Saved model consistency #11

Closed namsan96 closed 5 years ago

namsan96 commented 5 years ago

Hello,

I found that the performance of a reloaded model differs from the one printed during training.

Specifically, after training a model using the same code and configuration from your repo, I ran a test by simply changing maml_system.run_experiment() in train_maml_system.py to maml_system.evaluated_test_set_using_the_best_model(5), with the same configuration file.

However, the performance reported by this code differs from the one computed and printed during training. Moreover, it prints arbitrary performance levels on every run (e.g. 90%, 5%, 30%, ...).

Do you have any idea what might be causing this?

Thank you!
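(For context: accuracies that vary wildly across repeated evaluations of the same weights are often a sign that the evaluation episodes themselves are sampled without a fixed seed. A hypothetical sketch, not the repo's actual code, of pinning episode sampling so repeated test runs see identical tasks:)

```python
import random

def sample_test_episodes(num_episodes, num_classes, way=20, seed=0):
    """Sample reproducible few-shot episodes: with a fixed seed,
    the same class subsets are drawn on every run."""
    rng = random.Random(seed)
    return [rng.sample(range(num_classes), way) for _ in range(num_episodes)]

# Two independent runs with the same seed pick identical episodes,
# so any remaining accuracy difference must come from the model side.
a = sample_test_episodes(5, 1623, seed=0)
b = sample_test_episodes(5, 1623, seed=0)
assert a == b
```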

AntreasAntoniou commented 5 years ago

This should work. I can confirm that recently I did something very similar to what you did and it was successful. What does the code output in the test performance file?

namsan96 commented 5 years ago

Thanks for your response.

I reproduced the same problem across several experiments; here is one for the MAML++ Omniglot 20-way 1-shot setting (omniglot_maml++_omniglot_20_way_1_shot_maml_0.json).

The top 5 epochs and their accuracies printed at the end of training were [78 84 90 71 96] [0.97347973 0.97271959 0.97255068 0.97238176 0.97212838], and the final ensemble accuracy was {'test_accuracy_std': 0.17303913385957453, 'test_accuracy_mean': 0.9691028225806452}.

When I run the test, all states seem to be loaded correctly (list of parameters, training histories, ...), but the test accuracy is arbitrary every time.

Here are several test results I've got:

test_accuracy_mean,test_accuracy_std
0.8239415322580645,0.38086990387571623
test_accuracy_std,test_accuracy_mean
0.2655746637361477,0.07636088709677419
test_accuracy_mean,test_accuracy_std
0.9506048387096774,0.21669166880945281

I think this can be related to #9.

AntreasAntoniou commented 5 years ago

The latest commit should resolve this problem. I tested it multiple times locally and ensured that reloading weights and testing them produces consistent outputs.
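(A quick way to check save/load consistency in isolation is to compare the outputs of the original and the reloaded model on a fixed input. A minimal PyTorch sketch, assuming a toy model rather than the repo's MAML system:)

```python
import torch
import torch.nn as nn

# Toy stand-in for the trained model; any nn.Module works the same way.
model = nn.Linear(4, 2)
torch.save(model.state_dict(), "ckpt.pt")

# Reload into a freshly constructed module and switch to eval mode,
# so e.g. dropout/batch-norm behave deterministically.
reloaded = nn.Linear(4, 2)
reloaded.load_state_dict(torch.load("ckpt.pt"))
model.eval()
reloaded.eval()

# Identical weights must give identical outputs on the same input.
x = torch.randn(1, 4)
with torch.no_grad():
    assert torch.allclose(model(x), reloaded(x))
```

If this check passes but test accuracy still fluctuates, the inconsistency is in the evaluation pipeline (task sampling, data ordering) rather than in checkpointing.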

AntreasAntoniou commented 5 years ago

I assume this has been fixed.