neon5d closed this 5 years ago
Hi @neon5d, thank you for reporting this! Sorry for the delay.
I just ran the example for 1000 epochs on Colab, and it worked fine. Did you use the exact code from the examples folder? Also, which versions of jax/jaxlib were you using?
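If it helps, here is a minimal sketch for collecting the installed versions before reporting. It assumes Python 3.8+ (for `importlib.metadata`) and the PyPI distribution names `jax`, `jaxlib` and `jaxnet`. The helper name `installed_versions` is just illustrative. Packages that aren't installed are reported as `None` instead of raising.

```python
from importlib.metadata import version, PackageNotFoundError


def installed_versions(packages=("jax", "jaxlib", "jaxnet")):
    """Return a dict mapping each distribution name to its version string,
    or None if that distribution is not installed."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            # Not installed in this environment.
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver if ver is not None else 'not installed'}")
```

Pasting the printed lines into the issue makes it easier to check whether the problem is tied to a specific jax/jaxlib release.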
Closing this since a similar issue (https://github.com/JuliusKunze/jaxnet/issues/11) is now fixed.
@neon5d Please don't hesitate to reopen if you still have this issue with the new version of JAXnet!
I tried JAXnet's mnist_classifier.py, which corresponds to JAX's mnist_classifier.py, and increased num_epochs to 1000. With that setting, JAXnet's mnist_classifier.py failed with an out-of-memory error, while JAX's mnist_classifier.py finished without any errors.