juliuskunze / jaxnet

Concise deep learning for JAX
Apache License 2.0

There is a memory leak #7

Closed neon5d closed 5 years ago

neon5d commented 5 years ago

I tried mnist_classifier.py, which corresponds to JAX's mnist_classifier.py example. After I increased num_epochs to 1000, JAXnet's mnist_classifier.py failed with an out-of-memory error, whereas JAX's mnist_classifier.py finished without any errors.

juliuskunze commented 5 years ago

Hi @neon5d, thank you for reporting this! Sorry for the delay.

I just ran the example for 1000 epochs on Colab, and it worked fine. Did you use the exact code from the examples folder? Also, what version of jax/jaxlib were you using?
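For reference, one way to collect the version information requested above is sketched below. It only uses the standard library's `importlib.metadata`; the helper name `installed_versions` is hypothetical and not part of JAX or JAXnet, and it assumes the packages were installed under the distribution names `jax` and `jaxlib`.

```python
from importlib import metadata


def installed_versions(packages=("jax", "jaxlib")):
    """Return a {package: version} mapping; None marks a package that is not installed."""
    versions = {}
    for name in packages:
        try:
            # Reads the version from the installed distribution's metadata.
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

Pasting the printed lines into the issue makes it easier to tell whether a difference in behavior comes from the installed versions rather than from the example code.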

juliuskunze commented 5 years ago

Closing this since a similar issue (https://github.com/JuliusKunze/jaxnet/issues/11) is now fixed.

@neon5d Please don't hesitate to reopen this if you still have the issue with the new version of JAXnet!