keras-team / keras


Enhancement Request: XLA Compatibility for `keras.layers.Embedding` #19809

Open jacob-talroo opened 1 month ago

jacob-talroo commented 1 month ago

Description: While using `keras.layers.Embedding` with the TensorFlow backend in Keras 3, I encountered issues with XLA optimization during multi-GPU training, leading to an `InvalidArgumentError`. The problem arises because `keras.layers.Embedding` does not natively support XLA, which significantly hurts performance, particularly when training models like BERT.

The TensorFlow Keras (tf.keras) implementation provides an undocumented parameter, `use_one_hot_matmul`, on the embedding layer that enables XLA compatibility by using matrix multiplication instead of the default lookup. However, this parameter is not exposed or documented in Keras 3, leading to suboptimal training performance on GPU.
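
For reference, the one-hot matmul path is mathematically equivalent to the default lookup; it simply expresses the gather as a dense matrix product, which XLA compiles without issue. A minimal sketch of the equivalence (toy shapes chosen purely for illustration):

    import tensorflow as tf

    vocab_size, embed_dim = 8, 4
    ids = tf.constant([1, 3, 5])
    weights = tf.random.normal([vocab_size, embed_dim])

    # Default path: gather-based lookup (what the Embedding layer does today).
    lookup_out = tf.nn.embedding_lookup(weights, ids)

    # One-hot matmul path: a dense product that selects the same rows.
    one_hot = tf.one_hot(ids, depth=vocab_size, dtype=weights.dtype)
    matmul_out = tf.matmul(one_hot, weights)

    tf.debugging.assert_near(lookup_out, matmul_out)  # both paths agree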

Error Message:

InvalidArgumentError                      Traceback (most recent call last)

<ipython-input-43-4d6071dbb55e> in <cell line: 1>()
----> 1 model.fit(
      2   dataset,
      3   steps_per_epoch=100,
      4   epochs=2,
      5   verbose=1

1 frames

/usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py in error_handler(*args, **kwargs)
    120             # To get the full stack trace, call:
    121             # `keras.config.disable_traceback_filtering()`
--> 122             raise e.with_traceback(filtered_tb) from None
    123         finally:
    124             del filtered_tb

/usr/local/lib/python3.10/dist-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     51   try:
     52     ctx.ensure_initialized()
---> 53     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
     54                                         inputs, attrs, num_outputs)
     55   except core._NotOkStatusException as e:

InvalidArgumentError: Graph execution error:

Detected at node adam/CollectiveGatherV2 defined at (most recent call last):
<stack traces unavailable>
Detected at node adam/CollectiveGatherV2 defined at (most recent call last):
<stack traces unavailable>
Detected unsupported operations when trying to compile graph __inference_one_step_on_data_2036[] on XLA_GPU_JIT: CollectiveGatherV2 (No registered 'CollectiveGatherV2' OpKernel for XLA_GPU_JIT devices compatible with node {{node adam/CollectiveGatherV2}}){{node adam/CollectiveGatherV2}}
The op is created at: 
File "lib/python3.10/threading.py", line 973, in _bootstrap
File "lib/python3.10/threading.py", line 1016, in _bootstrap_inner
File "local/lib/python3.10/dist-packages/keras/src/backend/tensorflow/trainer.py", line 104, in one_step_on_data
File "local/lib/python3.10/dist-packages/keras/src/backend/tensorflow/trainer.py", line 69, in train_step
File "local/lib/python3.10/dist-packages/keras/src/optimizers/base_optimizer.py", line 282, in apply_gradients
File "local/lib/python3.10/dist-packages/keras/src/optimizers/base_optimizer.py", line 351, in apply
File "local/lib/python3.10/dist-packages/keras/src/optimizers/base_optimizer.py", line 405, in _backend_apply_gradients
File "local/lib/python3.10/dist-packages/keras/src/backend/tensorflow/optimizer.py", line 119, in _backend_update_step
File "local/lib/python3.10/dist-packages/keras/src/backend/tensorflow/optimizer.py", line 129, in _distributed_tf_update_step
File "local/lib/python3.10/dist-packages/keras/src/backend/tensorflow/optimizer.py", line 160, in _all_reduce_sum_gradients
    tf2xla conversion failed while converting __inference_one_step_on_data_2036[]. Run with TF_DUMP_GRAPH_PREFIX=/path/to/dump/dir and --vmodule=xla_compiler=2 to obtain a dump of the compiled functions.
     [[StatefulPartitionedCall]] [Op:__inference_one_step_on_iterator_2095]

Expected Behavior: Make `keras.layers.Embedding` XLA-compatible by default when XLA is enabled, improving training efficiency on GPU backends.

Steps to Reproduce: Please refer to the provided Colab notebook for detailed steps to reproduce this issue: Notebook Link
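
In rough terms, the failing configuration is a multi-GPU `tf.distribute.MirroredStrategy` combined with `jit_compile=True` and a model containing an `Embedding` layer. The model, vocabulary size, and shapes below are placeholders, not the actual BERT setup from the notebook:

    import numpy as np
    import tensorflow as tf
    import keras

    strategy = tf.distribute.MirroredStrategy()  # assumes two or more visible GPUs
    with strategy.scope():
        model = keras.Sequential([
            keras.layers.Embedding(input_dim=30522, output_dim=128),
            keras.layers.GlobalAveragePooling1D(),
            keras.layers.Dense(2, activation="softmax"),
        ])
        model.compile(
            optimizer="adam",
            loss="sparse_categorical_crossentropy",
            jit_compile=True,  # XLA; the CollectiveGatherV2 error surfaces here
        )

    x = np.random.randint(0, 30522, size=(256, 64))
    y = np.random.randint(0, 2, size=(256,))
    dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(32).repeat()

    model.fit(dataset, steps_per_epoch=100, epochs=2, verbose=1)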

Environment:

Suggested Solution: Integrate the `use_one_hot_matmul` parameter into the main Keras codebase and document its usage. It could be activated conditionally based on whether XLA (or possibly a GPU) is enabled, allowing for seamless performance improvements.
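
For illustration, if the flag were exposed, usage could look like the following (the parameter does not exist in Keras 3 today; the name simply mirrors the tf.keras 2.15 argument):

    embedding = keras.layers.Embedding(
        input_dim=30522,          # e.g. a BERT WordPiece vocabulary
        output_dim=768,
        use_one_hot_matmul=True,  # proposed flag, mirroring tf.keras 2.15
    )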

Impact: This enhancement will significantly improve training speed and efficiency, particularly for complex models like BERT, which are critical in many machine learning applications.

Details: The previous implementation is available at https://github.com/keras-team/keras/blob/v2.15.0/keras/layers/core/embedding.py#L33-L306:

        elif self._use_one_hot_matmul:
            # Note that we change the dtype of the one_hot to be same as the
            # weight tensor, since the input data are usually ints, and weights
            # are floats. The nn.embedding_lookup support ids as ints, but
            # the one_hot matmul need both inputs and weights to be same dtype.
            one_hot_data = tf.one_hot(
                inputs, depth=self.input_dim, dtype=self.dtype
            )
            out = tf.matmul(one_hot_data, self.embeddings)
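
Until something like this is exposed, one possible user-level workaround in Keras 3 is to subclass `Embedding` and swap the lookup for the one-hot matmul, mirroring the branch quoted above. A rough sketch (the class name is made up, and it ignores details such as `mask_zero`, sparse inputs, and LoRA):

    import keras
    from keras import ops


    class OneHotMatmulEmbedding(keras.layers.Embedding):
        """Embedding variant that uses a one-hot matmul instead of a gather."""

        def call(self, inputs):
            inputs = ops.cast(inputs, "int32")
            # Build the one-hot tensor in the weights' dtype so both matmul
            # operands agree, as in the tf.keras 2.15 branch quoted above.
            one_hot_data = ops.one_hot(
                inputs, num_classes=self.input_dim, dtype=self.embeddings.dtype
            )
            return ops.matmul(one_hot_data, self.embeddings)

Because it uses `keras.ops`, the sketch is backend-agnostic rather than TensorFlow-specific.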

See also:

fchollet commented 4 weeks ago

What's the performance like with JAX? Please try it.

jacob-talroo commented 4 weeks ago

Thanks for the suggestion. Here is our situation. Previously, we were running BERT pretraining with TensorFlow 2.11 and TensorFlow Models (TFM) with XLA on multiple GPUs. This took about an hour.

We are migrating to Keras 3 and testing:

The training run is modified:

jacob-talroo commented 3 weeks ago

FYI - we removed the extra projections from the heads of the NLP trunks, and JAX is now about the same performance as TF 2.