Closed by juliuskunze 4 years ago
After refactoring much of the JAXnet core and updating to a newer JAX version, this is no longer an issue.
Additionally, TensorFlow was allocating 90% of GPU memory for data loading, leaving only 10% to JAX. This was fixed in https://github.com/JuliusKunze/jaxnet/commit/4b99db6e66512ff0186062cd7dbe0d6bf8a35dbf. OOM was not reproducible even before this commit though.
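For anyone hitting the same split of GPU memory between TensorFlow data loading and JAX, here is a minimal sketch of one common workaround. It is not necessarily what the commit above does. It relies on the environment variables `TF_FORCE_GPU_ALLOW_GROWTH` (TensorFlow) and `XLA_PYTHON_CLIENT_PREALLOCATE` (JAX); the helper name `limit_framework_gpu_memory` is made up for illustration.

```python
import os


def limit_framework_gpu_memory(env=os.environ):
    """Hypothetical helper: ask TensorFlow and JAX not to reserve GPU memory up front.

    Must be called before `tensorflow` or `jax` is imported, because both
    read these variables when they initialize the GPU.
    """
    # TensorFlow: grow GPU allocations on demand instead of claiming
    # most of the device at startup.
    env.setdefault("TF_FORCE_GPU_ALLOW_GROWTH", "true")
    # JAX: allocate GPU memory as needed instead of preallocating a large pool.
    env.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")
    return {
        name: env[name]
        for name in ("TF_FORCE_GPU_ALLOW_GROWTH", "XLA_PYTHON_CLIENT_PREALLOCATE")
    }


# Apply to the current process before importing tensorflow / jax.
settings = limit_framework_gpu_memory()
```

`setdefault` keeps any value already set in the environment, so an explicit setting in the Colab notebook still wins. If TensorFlow is only used for data loading, hiding the GPU from it entirely may be the simpler option.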
Including a `Conv` in the mnist example results in an out-of-memory error on a Colab GPU during `apply_from` (`init_parameters` is fine).