google / objax

MNIST tutorial bigger model does not converge #173

Closed david-berthelot closed 3 years ago

david-berthelot commented 3 years ago

When I reran the tutorial, the bigger model did not converge:

loss 79121310.0
loss 6728657000.0
loss -21779100000.0
loss 379075130000.0
loss 327561700000.0
loss -4389404000000.0
loss 17473062000000.0
loss -8496424000000.0
loss 24158176000000.0
loss 47543010000000.0
model accuracy 0.0982

https://github.com/google/objax/blob/master/examples/tutorials/mnist-tutorial.ipynb
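A log like the one above can be flagged automatically. The following is only a sketch, not part of the tutorial: the helper names `parse_losses` and `looks_divergent` and the `blowup` threshold are hypothetical, and it assumes the training loss is a cross-entropy, which should never be negative.

```python
import math
import re

# Matches lines such as "loss 0.6504083" or "loss -21779100000.0".
# Assumption: the log prints the word "loss" followed by a plain decimal number.
LOSS_RE = re.compile(r"loss\s+(-?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?)")


def parse_losses(log_text):
    """Extract every loss value from the training log, in order."""
    return [float(match) for match in LOSS_RE.findall(log_text)]


def looks_divergent(losses, blowup=1e3):
    """Heuristic check: non-finite, negative, or rapidly growing loss.

    `blowup` is an arbitrary illustrative threshold, not a tuned value.
    """
    if not losses:
        return False
    reference = max(abs(losses[0]), 1.0)
    for value in losses:
        # A cross-entropy loss cannot be negative, so a negative value
        # (or NaN/inf) points to numerical trouble.
        if not math.isfinite(value) or value < 0:
            return True
        if abs(value) > blowup * reference:
            return True
    return False


if __name__ == "__main__":
    bad_log = "loss 79121310.0\nloss 6728657000.0\nloss -21779100000.0"
    print(looks_divergent(parse_losses(bad_log)))
```

The final accuracy of 0.0982 is close to the 0.1 expected from random guessing over the ten MNIST classes, which is consistent with a model that never trained.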

aterzis-google commented 3 years ago

I am not able to reproduce this. Here is what I get:

loss 0.6504083
loss 0.21322423
loss 0.16727123
loss 0.1530549
loss 0.13344146
loss 0.12683934
loss 0.11996455
loss 0.11438958
loss 0.10440791
loss 0.10162807
model accuracy 0.9919
Small weight ratio on layer (Sequential)0.w 0.057499997
Small weight ratio on layer (Sequential)2.w 0.54701173
Small weight ratio on layer (Sequential)5.w 0.73606443
Small weight ratio on layer (Sequential)7.w 0.7734863
Small weight ratio on layer (Sequential)10.w 0.0828125

david-berthelot commented 3 years ago

I just upgraded jax etc., and I no longer see the issue. Closing.
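For anyone who hits the same symptom, it may help to record the installed package versions before and after upgrading. The thread does not say which versions were involved. The sketch below uses only the standard library; the package list (`jax`, `jaxlib`, `objax`) is an assumption about what "jax etc." covers, and `installed_versions` is a hypothetical helper name.

```python
from importlib import metadata


def installed_versions(packages=("jax", "jaxlib", "objax")):
    """Return a {package: version or None} mapping for the given names."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # The package is not installed in this environment.
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

Running this in both the failing and the working environment gives a concrete version pair to include in a report.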