Closed jacob-talroo closed 2 months ago
I think we probably want to solve this at the Keras level, not in KerasNLP, ideally.
I played around with always using the one-hot approach under a distribution strategy: https://github.com/keras-team/keras/compare/master...mattdangerw:embedding-fix
I think this could work, but I am not sure we would want to do it while XLA is off by default on the TF backend. So the first step might be to look at enabling XLA with tf.distribute.
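The one-hot approach linked above swaps the dynamic gather that makes the default embedding lookup awkward for XLA with a dense matmul. A minimal NumPy sketch of the equivalence (shapes and names here are illustrative, not the actual Keras implementation):

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, embed_dim = 10, 4
table = rng.standard_normal((vocab_size, embed_dim))
ids = np.array([2, 7, 7, 0])

# Default lookup: a gather along the vocab axis.
gathered = table[ids]

# One-hot alternative: encode ids as (batch, vocab) one-hot rows,
# then do a single matmul against the table -- an op XLA handles well.
one_hot = np.eye(vocab_size)[ids]
via_matmul = one_hot @ table

assert np.allclose(gathered, via_matmul)
```

The trade-off is the (batch, vocab_size) one-hot intermediate, which is why one would presumably only want this path under a distribution strategy where the gather-based lookup is the bottleneck.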
Is there a reason training on the JAX backend doesn't work for your use case? It is likely faster, and everything is XLA compatible, as there is no other option on JAX.
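For reference, Keras 3 selects its backend from the KERAS_BACKEND environment variable at import time, so the switch has to happen before the training process imports Keras (the script name below is hypothetical):

```shell
# Must be set before `import keras` runs anywhere in the process.
export KERAS_BACKEND=jax
# Then launch training, e.g.:
#   python train_bert.py
echo "$KERAS_BACKEND"
```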
We have switched to the JAX backend. If there is no desire to close the performance gap between Keras 2 and Keras 3 on the TF backend, we can close this one out.
This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.
Describe the bug
Training BERT using KerasNLP is significantly slower because keras.layers.Embedding is not XLA compatible by default on TensorFlow GPU. This is similar to an issue reported for Keras at keras-team/keras#19809.

To Reproduce
You can reproduce this issue by following the steps in this Colab Notebook: Link to Notebook
Expected behavior
I expect BERT training using KerasNLP on TensorFlow GPU with XLA to perform comparably to native TensorFlow implementations.
Additional context
The lack of XLA compatibility hurts training speed and efficiency on GPU, which is crucial for training scalability and practical use in production environments.
See also:
Would you like to help us fix it?
Yes, I am willing to contribute to resolving this issue by testing and suggesting implementations that ensure XLA compatibility.