tanghl1994 opened this issue 3 years ago
I also noticed this. I tried setting `create_graph=True` in the inner-loop update, but with that change the Omniglot example no longer seems to converge:
```
step: 0 training acc: [0.19833333 0.26666667 0.44708333 0.48916667 0.49041667 0.4925 ] loss: 1.762630581855774
Test acc: [0.199 0.2664 0.3909 0.4255 0.43 0.4312 0.432 0.4326 0.433 0.4338 0.434 ]
step: 50 training acc: [0.19 0.37708333 0.44958333 0.47208333 0.4775 0.48 ] loss: 1.7477717399597168
step: 100 training acc: [0.19708333 0.3825 0.47833333 0.49083333 0.49708333 0.49791667] loss: 1.6532983779907227
step: 150 training acc: [0.19208333 0.42416667 0.51958333 0.53291667 0.53583333 0.53625 ] loss: 1.5089343786239624
step: 200 training acc: [0.19958333 0.40291667 0.49875 0.50916667 0.50916667 0.50916667] loss: 1.5887075662612915
step: 250 training acc: [0.20791667 0.40416667 0.50458333 0.51625 0.51958333 0.52083333] loss: 1.6059422492980957
step: 300 training acc: [0.19916667 0.40708333 0.49083333 0.51333333 0.51291667 0.5125 ] loss: 1.595080018043518
step: 350 training acc: [0.20375 0.425 0.49541667 0.5125 0.51416667 0.5125 ] loss: 1.563471794128418
step: 400 training acc: [0.17916667 0.4075 0.48083333 0.48541667 0.48625 0.48916667] loss: 1.6046133041381836
step: 450 training acc: [0.21833333 0.42666667 0.47583333 0.49875 0.49958333 0.50041667] loss: 1.5523645877838135
step: 500 training acc: [0.21416667 0.43875 0.51541667 0.52333333 0.52583333 0.52875 ] loss: 1.5126579999923706
Test acc: [0.202 0.4236 0.48 0.491 0.4927 0.4934 0.4937 0.4941 0.4941 0.4941 0.4944]
step: 550 training acc: [0.21208333 0.45916667 0.5175 0.52416667 0.52208333 0.52333333] loss: 1.5570530891418457
step: 600 training acc: [0.20625 0.45833333 0.52875 0.54208333 0.54208333 0.54291667] loss: 1.4171380996704102
step: 650 training acc: [0.18583333 0.41583333 0.47875 0.48541667 0.48458333 0.48583333] loss: 1.655295491218567
step: 700 training acc: [0.18875 0.44875 0.49875 0.50458333 0.505 0.50541667] loss: 1.553985357284546
step: 750 training acc: [0.17875 0.45708333 0.52416667 0.53083333 0.52541667 0.52708333] loss: 1.5234225988388062
step: 800 training acc: [0.20333333 0.4775 0.49333333 0.50291667 0.50541667 0.50666667] loss: 1.5525965690612793
step: 850 training acc: [0.19291667 0.44916667 0.5175 0.52125 0.52166667 0.5225 ] loss: 1.5418503284454346
step: 900 training acc: [0.21458333 0.44583333 0.50583333 0.51458333 0.51541667 0.51666667] loss: 1.5525331497192383
step: 950 training acc: [0.19708333 0.41625 0.46333333 0.47291667 0.47541667 0.47666667] loss: 1.6665573120117188
step: 1000 training acc: [0.195 0.44083333 0.50083333 0.51583333 0.51958333 0.52208333] loss: 1.539425253868103
Test acc: [0.202 0.4248 0.464 0.4707 0.4714 0.4722 0.4727 0.473 0.4734 0.4734 0.4736]
step: 1050 training acc: [0.19375 0.43208333 0.48416667 0.48833333 0.49 0.49 ] loss: 1.6806050539016724
step: 1100 training acc: [0.18833333 0.45666667 0.50041667 0.50333333 0.50208333 0.50166667] loss: 1.574751853942871
step: 1150 training acc: [0.18791667 0.44083333 0.48125 0.49041667 0.49083333 0.49208333] loss: 1.6919797658920288
step: 1200 training acc: [0.18166667 0.43541667 0.47833333 0.49458333 0.49625 0.49708333] loss: 1.628588318824768
step: 1250 training acc: [0.21 0.43916667 0.47333333 0.4775 0.4775 0.47708333] loss: 1.7136387825012207
step: 1300 training acc: [0.20458333 0.45958333 0.5025 0.50833333 0.50666667 0.50625 ] loss: 1.6481508016586304
step: 1350 training acc: [0.1875 0.45166667 0.46916667 0.48166667 0.48375 0.48541667] loss: 1.6440492868423462
step: 1400 training acc: [0.18916667 0.43875 0.47 0.47666667 0.47875 0.47875 ] loss: 1.7161211967468262
step: 1450 training acc: [0.20833333 0.43416667 0.46958333 0.47166667 0.47125 0.47208333] loss: 1.7655547857284546
step: 1500 training acc: [0.21583333 0.45041667 0.49458333 0.49666667 0.49583333 0.4975 ] loss: 1.6131330728530884
Test acc: [0.1959 0.4194 0.4534 0.458 0.4595 0.4602 0.4607 0.461 0.4614 0.4617 0.462 ]
```
Looking at the inner-loop update, the gradients are computed without `create_graph=True`. Doesn't that mean the second-order (Hessian) term is essentially never applied?
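For concreteness, here is a minimal sketch of the distinction I'm asking about. This is not the repo's exact code: the `inner_update` helper and the `net(x, fast_weights)` interface that takes an explicit list of fast weights are assumptions for illustration; only the `create_graph` flag is the point.

```python
import torch
import torch.nn.functional as F

def inner_update(net, x_spt, y_spt, fast_weights, inner_lr, second_order=True):
    """One inner-loop (task-adaptation) SGD step on the support set."""
    logits = net(x_spt, fast_weights)      # forward pass with the current fast weights
    loss = F.cross_entropy(logits, y_spt)

    # create_graph=True keeps the graph of this gradient computation, so the
    # outer (meta) backward pass can differentiate through the update and the
    # Hessian-vector terms of full MAML are included.
    # With create_graph=False the fast weights are effectively treated as
    # constants w.r.t. the meta-parameters, which reduces the method to
    # first-order MAML.
    grads = torch.autograd.grad(loss, fast_weights, create_graph=second_order)

    # Gradient-descent step producing the new fast weights.
    return [w - inner_lr * g for w, g in zip(fast_weights, grads)]
```

The log above is from a run with the second-order version in this sense; whether the first-order approximation is the intended behaviour here is what I'd like to confirm.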