Open jacob-talroo opened 1 month ago
What's the performance like with JAX? Please try it.
Thanks for the suggestion. Here is our situation. Previously, we were running a BERT pretrain with TensorFlow 2.11 and TensorFlow Models (TFM) with XLA on multiple GPUs. This took about an hour.
We are migrating to Keras 3 and testing:
The training code is modified:
`add_loss`. Now, we specify the loss in `compile()`.
`add_metric()`, so we moved these to `compile()` as well.
FYI - we removed the extra projections from the heads of the NLP trunks, and now JAX has about the same performance as TF 2.
Description: While using `keras.layers.Embedding` with the TensorFlow backend in Keras 3, I encountered issues with XLA optimization during multi-GPU training, leading to `InvalidArgumentError`. This problem arises because `keras.layers.Embedding` does not natively support XLA, affecting performance significantly, particularly when training models like BERT.
The TensorFlow Keras implementation provides an undocumented parameter `use_one_hot_matmul` within the embedding layer that enables XLA compatibility by using matrix multiplication instead of the default lookup method. However, this is not exposed or documented in Keras, leading to suboptimal training performance on GPU.
Error Message:
Expected Behavior: Enable `keras.layers.Embedding` to be XLA compatible by default when XLA is enabled, improving the training efficiency on GPU backends.
Steps to Reproduce: Please refer to the provided Colab notebook here for detailed steps to reproduce this issue: Notebook Link
Environment:
Suggested Solution: Integrate the `use_one_hot_matmul` parameter into the main Keras codebase and document its usage. This should be activated conditionally based on whether XLA (or, possibly, the GPU) is enabled, allowing for seamless performance improvements.
Impact: This enhancement will significantly improve training speed and efficiency, particularly for complex models like BERT, which are critical in many machine learning applications.
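For readers unfamiliar with the technique, here is a minimal NumPy sketch of the idea behind `use_one_hot_matmul`: an embedding lookup can be written either as a row gather or as a one-hot matrix multiplied by the embedding table, and both give the same result. This is only an illustration of the math, not the Keras implementation; the function names, table size, and ids are hypothetical.

```python
import numpy as np


def embedding_lookup(table, ids):
    """Default path: gather rows of the embedding table by integer id."""
    return table[ids]


def embedding_one_hot_matmul(table, ids):
    """Alternative path: one-hot encode the ids, then multiply by the table."""
    vocab_size = table.shape[0]
    # Indexing the identity matrix by id yields one-hot rows: (..., vocab_size).
    one_hot = np.eye(vocab_size, dtype=table.dtype)[ids]
    # Each one-hot row selects exactly one table row: (..., embed_dim).
    return one_hot @ table


# Hypothetical sizes: 8-token vocabulary, 4-dimensional embeddings.
rng = np.random.default_rng(0)
table = rng.normal(size=(8, 4)).astype(np.float32)
ids = np.array([[1, 5, 5], [0, 7, 2]])  # batch of 2 sequences, length 3

gathered = embedding_lookup(table, ids)
multiplied = embedding_one_hot_matmul(table, ids)

assert gathered.shape == multiplied.shape == (2, 3, 4)
assert np.allclose(gathered, multiplied)
```

The matmul form does work proportional to the vocabulary size for every token, so it trades extra arithmetic for an operation that compilers and accelerators handle well. That trade-off is presumably why the request above is to enable it conditionally rather than always.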
Details: The previous implementation is available at: https://github.com/keras-team/keras/blob/v2.15.0/keras/layers/core/embedding.py#L33-L306 :
See also: