david-berthelot closed this issue 3 years ago
I am not able to reproduce this. Here is what I get:
loss 0.6504083
loss 0.21322423
loss 0.16727123
loss 0.1530549
loss 0.13344146
loss 0.12683934
loss 0.11996455
loss 0.11438958
loss 0.10440791
loss 0.10162807
model accuracy 0.9919
Small weight ratio on layer (Sequential)0.w 0.057499997
Small weight ratio on layer (Sequential)2.w 0.54701173
Small weight ratio on layer (Sequential)5.w 0.73606443
Small weight ratio on layer (Sequential)7.w 0.7734863
Small weight ratio on layer (Sequential)10.w 0.0828125
I just upgraded jax etc. and I can no longer see the issue. Closing.
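For anyone comparing environments before and after an upgrade, here is a minimal sketch for printing the installed package versions. The package list (jax, jaxlib, objax) is an assumption based on the libraries mentioned in this thread, and the helper name `installed_versions` is hypothetical; adjust both to your setup.

```python
from importlib import metadata


def installed_versions(packages=("jax", "jaxlib", "objax")):
    """Return a dict mapping each package name to its installed version.

    Packages that are not installed map to None. The default package
    list is an assumption, not something taken from the tutorial.
    """
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    # Print one line per package so the output can be pasted into a report.
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

Including this output in a report makes it easier to tell whether a difference in convergence comes from the library versions or from the notebook itself.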
When I tried rerunning the tutorial notebook, the bigger model did not converge:
https://github.com/google/objax/blob/master/examples/tutorials/mnist-tutorial.ipynb