keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

SparseCategoricalCrossentropy consumes heavy GPU RAM with over 10K categories #20005

Closed TrentaIcedCoffee closed 3 months ago

TrentaIcedCoffee commented 3 months ago
import tensorflow as tf

# Logits for a batch of 16 sequences of length 1024 over 50K classes.
x = tf.random.normal([16, 1024, 50_000])
# Integer labels in [0, 50_000); note that maxval is exclusive in tf.random.uniform.
y = tf.random.uniform([16, 1024], minval=0, maxval=50_000, dtype=tf.int32)

# Each of these sparse categorical cross-entropy calls takes ~15 GB of GPU RAM.
tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)(y, x)
tf.keras.losses.sparse_categorical_crossentropy(y, x)

Why? And what loss_fn would you suggest when working with a large number of categories, which is typical in a language model?

TrentaIcedCoffee commented 3 months ago

16 × 1024 × 50K × 4 bytes (float32) gives ~3 GB, so I assume the usage is reasonable once temporary variables are accounted for? Ah, I guess buying more GPUs is what we do for now.
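A minimal sketch of that arithmetic, assuming float32 tensors (this is just the back-of-the-envelope estimate above, not a measurement):

batch, seq_len, num_classes = 16, 1024, 50_000
bytes_per_float32 = 4

# Size of a single [16, 1024, 50_000] float32 tensor, e.g. the logits.
logits_bytes = batch * seq_len * num_classes * bytes_per_float32
print(f"{logits_bytes / 1024**3:.2f} GiB")  # ~3.05 GiB for the logits alone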

mattdangerw commented 3 months ago

I believe most implementations will inflate the labels via a one-hot encoding before computing the loss against the logits. So that's ~3 GB each for x and the one-hot-encoded y, before any temporary allocations. So yeah, this could very well be reasonable.
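A minimal sketch of the kind of expansion meant here, just to illustrate the memory cost (this is an assumption about the internals, not the actual Keras code path):

import tensorflow as tf

y = tf.random.uniform([16, 1024], minval=0, maxval=50_000, dtype=tf.int32)
# The one-hot labels have the same shape as the logits, [16, 1024, 50_000] in
# float32, i.e. roughly another 3 GB on top of the logits themselves.
y_one_hot = tf.one_hot(y, depth=50_000)
print(y_one_hot.shape)  # (16, 1024, 50000)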

Also, be aware that by default TensorFlow grabs all GPU memory up front, which can make profiling via, say, NVIDIA tooling difficult. https://www.tensorflow.org/guide/gpu
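For example, one way to make reported usage reflect actual allocations is to enable memory growth before any GPU tensors are created, as described in the guide above:

import tensorflow as tf

# Must run before any GPU tensors are allocated.
for gpu in tf.config.list_physical_devices("GPU"):
    tf.config.experimental.set_memory_growth(gpu, True)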

As far as I can tell, on the TensorFlow backend we just delegate to a tf op here. So if there is a bug here, it should probably be opened on the TensorFlow GitHub.

https://github.com/keras-team/keras/blob/902f9da309fbf6318b20c9fe33d53b5ab4938644/keras/src/backend/tensorflow/nn.py#L674-L677

mattdangerw commented 3 months ago

I will close this for now, but if you are seeing evidence that GPU usage is much higher for Keras in particular than for tf.nn.sparse_softmax_cross_entropy_with_logits, please go ahead and reopen!
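For reference, a minimal sketch of such a comparison, assuming the same shapes as the original report (tf.nn.sparse_softmax_cross_entropy_with_logits takes integer labels directly, so there is no one-hot expansion on the user side):

import tensorflow as tf

x = tf.random.normal([16, 1024, 50_000])
y = tf.random.uniform([16, 1024], minval=0, maxval=50_000, dtype=tf.int32)

# Raw TF op: per-token losses with shape [16, 1024].
raw = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=x)

# Keras loss for comparison (reduces to a scalar by default).
keras_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)(y, x)

print(tf.reduce_mean(raw).numpy(), keras_loss.numpy())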
