dragen1860 / MAML-Pytorch

Elegant PyTorch implementation of the paper Model-Agnostic Meta-Learning (MAML)

Does the Hessian really get computed? #59

Open · tanghl1994 opened this issue 3 years ago

tanghl1994 commented 3 years ago

I noticed that in the inner-loop update, the gradients are computed without create_graph=True. Doesn't this mean the Hessian is never actually computed, i.e. that the update is effectively first-order?
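To make the question concrete, here is a minimal sketch (a toy loss and learning rate of my own, not the repo's code) of how create_graph decides whether the second-order term reaches the meta-gradient:

```python
import torch

# Toy setup: a single scalar parameter and a twice-differentiable loss.
theta = torch.tensor(1.0, requires_grad=True)
inner_lr = 0.1

def loss_fn(w):
    return w ** 4 / 4.0  # loss'(w) = w**3, loss''(w) = 3 * w**2

# Second-order MAML: keep the graph through the inner gradient, so the
# meta-gradient is loss'(theta') * (1 - inner_lr * loss''(theta)),
# which contains the Hessian.
g = torch.autograd.grad(loss_fn(theta), theta, create_graph=True)[0]
theta_prime = theta - inner_lr * g
meta_grad_2nd = torch.autograd.grad(loss_fn(theta_prime), theta)[0]

# Without create_graph=True the inner gradient is a constant w.r.t. theta,
# so d(theta')/d(theta) is the identity and the Hessian term drops out.
g = torch.autograd.grad(loss_fn(theta), theta)[0]
theta_prime = theta - inner_lr * g
meta_grad_1st = torch.autograd.grad(loss_fn(theta_prime), theta)[0]

print(meta_grad_2nd.item(), meta_grad_1st.item())  # ~0.5103 vs 0.729
```

If the inner loop omits create_graph=True, the meta-update therefore reduces to the first-order MAML (FOMAML) approximation.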

zou3519 commented 3 years ago

I also noticed this. I tried setting create_graph=True in the inner-loop update (roughly the change sketched after the logs below), and with that change the Omniglot example no longer seems to converge:

step: 0         training acc: [0.19833333 0.26666667 0.44708333 0.48916667 0.49041667 0.4925    ]       loss: 1.762630581855774
Test acc: [0.199  0.2664 0.3909 0.4255 0.43   0.4312 0.432  0.4326 0.433  0.4338 0.434 ]
step: 50        training acc: [0.19       0.37708333 0.44958333 0.47208333 0.4775     0.48      ]       loss: 1.7477717399597168
step: 100       training acc: [0.19708333 0.3825     0.47833333 0.49083333 0.49708333 0.49791667]       loss: 1.6532983779907227
step: 150       training acc: [0.19208333 0.42416667 0.51958333 0.53291667 0.53583333 0.53625   ]       loss: 1.5089343786239624
step: 200       training acc: [0.19958333 0.40291667 0.49875    0.50916667 0.50916667 0.50916667]       loss: 1.5887075662612915
step: 250       training acc: [0.20791667 0.40416667 0.50458333 0.51625    0.51958333 0.52083333]       loss: 1.6059422492980957
step: 300       training acc: [0.19916667 0.40708333 0.49083333 0.51333333 0.51291667 0.5125    ]       loss: 1.595080018043518
step: 350       training acc: [0.20375    0.425      0.49541667 0.5125     0.51416667 0.5125    ]       loss: 1.563471794128418
step: 400       training acc: [0.17916667 0.4075     0.48083333 0.48541667 0.48625    0.48916667]       loss: 1.6046133041381836
step: 450       training acc: [0.21833333 0.42666667 0.47583333 0.49875    0.49958333 0.50041667]       loss: 1.5523645877838135
step: 500       training acc: [0.21416667 0.43875    0.51541667 0.52333333 0.52583333 0.52875   ]       loss: 1.5126579999923706
Test acc: [0.202  0.4236 0.48   0.491  0.4927 0.4934 0.4937 0.4941 0.4941 0.4941 0.4944]
step: 550       training acc: [0.21208333 0.45916667 0.5175     0.52416667 0.52208333 0.52333333]       loss: 1.5570530891418457
step: 600       training acc: [0.20625    0.45833333 0.52875    0.54208333 0.54208333 0.54291667]       loss: 1.4171380996704102
step: 650       training acc: [0.18583333 0.41583333 0.47875    0.48541667 0.48458333 0.48583333]       loss: 1.655295491218567
step: 700       training acc: [0.18875    0.44875    0.49875    0.50458333 0.505      0.50541667]       loss: 1.553985357284546
step: 750       training acc: [0.17875    0.45708333 0.52416667 0.53083333 0.52541667 0.52708333]       loss: 1.5234225988388062
step: 800       training acc: [0.20333333 0.4775     0.49333333 0.50291667 0.50541667 0.50666667]       loss: 1.5525965690612793
step: 850       training acc: [0.19291667 0.44916667 0.5175     0.52125    0.52166667 0.5225    ]       loss: 1.5418503284454346
step: 900       training acc: [0.21458333 0.44583333 0.50583333 0.51458333 0.51541667 0.51666667]       loss: 1.5525331497192383
step: 950       training acc: [0.19708333 0.41625    0.46333333 0.47291667 0.47541667 0.47666667]       loss: 1.6665573120117188
step: 1000      training acc: [0.195      0.44083333 0.50083333 0.51583333 0.51958333 0.52208333]       loss: 1.539425253868103
Test acc: [0.202  0.4248 0.464  0.4707 0.4714 0.4722 0.4727 0.473  0.4734 0.4734 0.4736]
step: 1050      training acc: [0.19375    0.43208333 0.48416667 0.48833333 0.49       0.49      ]       loss: 1.6806050539016724
step: 1100      training acc: [0.18833333 0.45666667 0.50041667 0.50333333 0.50208333 0.50166667]       loss: 1.574751853942871
step: 1150      training acc: [0.18791667 0.44083333 0.48125    0.49041667 0.49083333 0.49208333]       loss: 1.6919797658920288
step: 1200      training acc: [0.18166667 0.43541667 0.47833333 0.49458333 0.49625    0.49708333]       loss: 1.628588318824768
step: 1250      training acc: [0.21       0.43916667 0.47333333 0.4775     0.4775     0.47708333]       loss: 1.7136387825012207
step: 1300      training acc: [0.20458333 0.45958333 0.5025     0.50833333 0.50666667 0.50625   ]       loss: 1.6481508016586304
step: 1350      training acc: [0.1875     0.45166667 0.46916667 0.48166667 0.48375    0.48541667]       loss: 1.6440492868423462
step: 1400      training acc: [0.18916667 0.43875    0.47       0.47666667 0.47875    0.47875   ]       loss: 1.7161211967468262
step: 1450      training acc: [0.20833333 0.43416667 0.46958333 0.47166667 0.47125    0.47208333]       loss: 1.7655547857284546
step: 1500      training acc: [0.21583333 0.45041667 0.49458333 0.49666667 0.49583333 0.4975    ]       loss: 1.6131330728530884
Test acc: [0.1959 0.4194 0.4534 0.458  0.4595 0.4602 0.4607 0.461  0.4614 0.4617 0.462 ]
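For reference, the change is roughly this (a sketch in the style of this repo's inner loop; the names fast_weights and update_lr are illustrative, not necessarily the exact ones in meta.py):

```python
import torch

def inner_step(loss, fast_weights, update_lr, second_order):
    # Gradient of the support-set loss w.r.t. the current fast weights.
    # With create_graph=True this gradient itself stays in the autograd
    # graph, so the outer (meta) backward pass can differentiate through
    # it; that is where the Hessian-vector product enters.
    grads = torch.autograd.grad(loss, fast_weights, create_graph=second_order)
    # Functional SGD step: build new tensors instead of mutating in place,
    # keeping the path from the meta-parameters to the adapted weights
    # differentiable.
    return [w - update_lr * g for w, g in zip(fast_weights, grads)]
```

With second_order=False this matches the current first-order behavior the issue describes; second_order=True is the variant whose training log is pasted above.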