keras-team / keras-hub

Pretrained model hub for Keras 3
Apache License 2.0

Enable XLA Compatibility for Pretraining BERT with Keras NLP on TensorFlow GPU #1661

Closed: jacob-talroo closed this issue 2 months ago

jacob-talroo commented 4 months ago

Describe the bug
Training BERT with KerasNLP is significantly slower because keras.layers.Embedding is not XLA-compatible by default on the TensorFlow GPU backend. This is similar to an issue reported for Keras at keras-team/keras#19809.

To Reproduce
You can reproduce this issue by following the steps in this Colab notebook: Link to Notebook

Expected behavior
I expect BERT training with KerasNLP on TensorFlow GPU with XLA to be optimized for performance, comparable to native TensorFlow implementations.

Additional context
The lack of XLA compatibility reduces training speed and efficiency on GPU, which is crucial for scaling model training and for practical use in production environments.

See also:

Would you like to help us fix it?
Yes. I am willing to help resolve this issue by testing and suggesting implementations that ensure XLA compatibility.

mattdangerw commented 3 months ago

Ideally, I think we want to solve this at the Keras level rather than in KerasNLP.

I played around with always using the one-hot approach under a distribution strategy: https://github.com/keras-team/keras/compare/master...mattdangerw:embedding-fix

I think this could work, but I am not sure we would want to do it while XLA is off by default on the TF backend. So the first step might be to look at enabling XLA with tf.distribute.
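For context, the "one-hot approach" to an embedding lookup replaces the usual gather (indexing rows of the embedding table by token ID) with a dense matrix multiply: a one-hot encoding of the IDs times the embedding table. Because each one-hot row contains a single 1, the product selects exactly the same rows. The snippet is a minimal NumPy illustration of that equivalence under these assumptions, not the code in the linked branch; the function names gather_lookup and one_hot_lookup are hypothetical.

```python
import numpy as np


def gather_lookup(table: np.ndarray, ids: np.ndarray) -> np.ndarray:
    """Standard embedding lookup: index rows of the table directly (a gather)."""
    return table[ids]


def one_hot_lookup(table: np.ndarray, ids: np.ndarray) -> np.ndarray:
    """Hypothetical one-hot variant of the same lookup, written as a matmul.

    table: (vocab_size, dim) embedding matrix.
    ids:   integer array of any shape (...,).
    Returns an array of shape (..., dim).
    """
    vocab_size = table.shape[0]
    # Compare every id against 0..vocab_size-1 to build a (..., vocab_size) one-hot tensor.
    one_hot = (ids[..., None] == np.arange(vocab_size)).astype(table.dtype)
    # Contract over the vocab axis; each one-hot row picks out exactly one table row.
    return one_hot @ table


rng = np.random.default_rng(0)
table = rng.standard_normal((10, 4)).astype(np.float32)
ids = np.array([[1, 3, 3], [0, 9, 2]])

# Both lookups return a (2, 3, 4) array with identical values.
print(np.allclose(gather_lookup(table, ids), one_hot_lookup(table, ids)))
```

The trade-off is cost: the intermediate one-hot tensor has shape (..., vocab_size), so compute and memory grow with vocabulary size (about 30k tokens for the original English BERT WordPiece vocabulary), whereas a gather touches only the selected rows. In TensorFlow, the gradient of a gather is also usually a sparse tf.IndexedSlices, while the one-hot form produces a dense gradient.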

mattdangerw commented 3 months ago

Is there a reason training on the JAX backend doesn't work for your use case? It is likely faster, and everything is XLA-compatible, since there is no other option on JAX.

jacob-talroo commented 2 months ago

We have switched to the JAX backend. If there are no plans to address the TensorFlow-backend performance regression between Keras 2 and Keras 3, we can close this one out.

github-actions[bot] commented 2 months ago

This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.