juliuskunze / jaxnet

Concise deep learning for JAX

Memory usage is too high #11

Closed juliuskunze closed 4 years ago

juliuskunze commented 4 years ago

Including a Conv layer in the MNIST example

import time

import jax.numpy as np
import numpy.random as npr
from jax.random import PRNGKey

from jaxnet import Sequential, parametrized, Dense, relu, logsoftmax, optimizers, Conv, flatten

def _one_hot(x, k, dtype=np.float32):
    """Create a one-hot encoding of x of size k."""
    return np.array(x[:, None] == np.arange(k), dtype)

def mnist():
    import tensorflow_datasets as tfds
    dataset = tfds.load("mnist:1.0.0")
    images = lambda d: np.reshape(np.float32(d['image']) / 256, (-1, 28, 28, 1))
    labels = lambda d: _one_hot(d['label'], 10)
    train = next(tfds.as_numpy(dataset['train'].shuffle(50000).batch(50000)))
    test = next(tfds.as_numpy(dataset['test'].batch(10000)))
    return images(train), labels(train), images(test), labels(test)

predict = Sequential(
    Conv(32, (5, 5)), relu, flatten,  # the added Conv layer
    Dense(500), relu,
    Dense(10), logsoftmax)

results in an out-of-memory error on a GPU Colab during apply_from (init_parameters works fine).
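For context, apply_from is the step that reuses the parameters trained for one module to evaluate another module built on the same predict network. The following is a rough sketch of the surrounding training/evaluation code, written from memory of the jaxnet API at the time; the loss/accuracy definitions, optimizer choice, and batch sizes are illustrative, not taken from the actual example:

@parametrized
def loss(inputs, targets):
    # mean negative log-likelihood of the one-hot targets under the predicted log-probabilities
    return -np.mean(predict(inputs) * targets)

@parametrized
def accuracy(inputs, targets):
    return np.mean(np.argmax(predict(inputs), axis=1) == np.argmax(targets, axis=1))

train_images, train_labels, test_images, test_labels = mnist()

# Parameter initialization on a small batch works without issues:
params = loss.init_parameters(train_images[:16], train_labels[:16], key=PRNGKey(0))

opt = optimizers.Adam()
state = opt.init(params)
for _ in range(10):
    state = opt.update(loss.apply, state, train_images[:128], train_labels[:128], jit=True)

# Evaluating on the full 10k test images via apply_from is where the OOM was observed:
test_acc = accuracy.apply_from({loss: opt.get_parameters(state)},
                               test_images, test_labels, jit=True)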

juliuskunze commented 4 years ago

After refactoring a lot of the JAXnet core and updating to the new JAX version, this is no longer an issue.

Additionally, TensorFlow was allocating 90% of GPU memory for data loading, leaving only 10% for JAX. This was fixed in https://github.com/JuliusKunze/jaxnet/commit/4b99db6e66512ff0186062cd7dbe0d6bf8a35dbf. The OOM was not reproducible even before this commit, though.
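For reference, the usual way to achieve this is to hide the GPU from TensorFlow before the dataset is loaded, so that only JAX allocates device memory. A minimal sketch of that technique follows (the linked commit may implement the fix differently):

import tensorflow as tf
import tensorflow_datasets as tfds

# Keep TensorFlow on the CPU so it does not preallocate GPU memory that JAX needs.
# This must run before the first tfds.load call touches the GPU.
tf.config.experimental.set_visible_devices([], 'GPU')

dataset = tfds.load("mnist:1.0.0")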