huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

TF2 DeBERTaV2 runs super slow on TPUs #18239

Closed · WissamAntoun closed this 2 years ago

WissamAntoun commented 2 years ago

System Info

latest version of transformers, Colab TPU, tensorflow 2

Who can help?

@kamalkraj @Rocketknight1 @BigBird01

Information

Tasks

Reproduction

It's currently hard to share the code and access to the Google bucket, but I believe any TF2 DeBERTaV2 code running on TPUs will have this issue.

Expected behavior

I've been trying to train a DeBERTa v3 model on GPUs and TPUs. I got it to work on multi-node and multi-GPU setups using the NVIDIA Deep Learning Examples libraries (https://github.com/NVIDIA/DeepLearningExamples/blob/master/TensorFlow2/LanguageModeling/). I basically used the training setup and loop from the BERT code, the dataset utils from the ELECTRA code, and the model from Hugging Face Transformers, with some changes in order to share embeddings.

On 6x A40 45 GB GPUs I get around 1370 sentences per second during training (which is lower than what NVIDIA gets for ELECTRA, but it's fine).

OK, now the problem: on TPU I get 20 sentences per second.

I traced the issue back to the tf.gather function here https://github.com/huggingface/transformers/blob/main/src/transformers/models/deberta_v2/modeling_tf_deberta_v2.py#L525

I ran TPU profiling and this is the output: [profiler screenshot]

GatherV2 takes most of the time: [profiler screenshot]

Zoomed-in pictures of the fast ops: [profiler screenshot]

Also, I'm not sure if this is TPU-specific, since on GPUs the training is ~30% slower compared to regular ELECTRA.

Rocketknight1 commented 2 years ago

Hi @WissamAntoun, this is an interesting issue! I honestly have no idea what the cause could be, but the fact that it highlights that function is interesting. The reason is that the DeBERTa code was ported from PyTorch, and so we wrote our own implementation of take_along_axis because TF didn't have one. One thing to try would be to edit the code to use tf.experimental.numpy.take_along_axis instead of that function. If that doesn't work then we might have to see if we can do things in a different, more performant way.
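For reference, a minimal sketch of that swap (illustrative only, assuming the helper keeps the same `(x, indices)` signature and gathers along the last axis; the toy shapes are made up, this is not the exact patch):

```python
import tensorflow as tf
import tensorflow.experimental.numpy as tnp


def take_along_axis(x, indices):
    # x: [batch, seq, dim], indices: [batch, seq, positions] -> result: [batch, seq, positions]
    # tnp.take_along_axis mirrors np.take_along_axis and gathers along the chosen axis
    return tnp.take_along_axis(x, indices, axis=-1)


# quick shape check with toy tensors
x = tf.random.uniform([2, 8, 16])
idx = tf.random.uniform([2, 8, 4], minval=0, maxval=16, dtype=tf.int32)
print(take_along_axis(x, idx).shape)  # (2, 8, 4)
```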

Also, just in case XLA compilation is the issue, have you tried using jit_compile=True in compile() when running DeBERTa on GPU? If that also causes performance degradation then the problem is caused by XLA and not TPUs, and we can investigate from there.
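For concreteness, a hedged sketch of that flag, with a toy Keras model standing in for the real DeBERTa training setup (the model and data here are placeholders):

```python
import tensorflow as tf

# jit_compile=True asks Keras to compile the whole train step with XLA
model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(8,))])
model.compile(optimizer="adam", loss="mse", jit_compile=True)
model.fit(tf.random.normal((64, 8)), tf.random.normal((64, 4)), epochs=1, verbose=0)
```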

Rocketknight1 commented 2 years ago

Also cc @sanchit-gandhi because I'm not a TPU expert - don't worry about investigating this deeply, but if anything comes to mind when you read it, let me know!

WissamAntoun commented 2 years ago

@Rocketknight1 I read all the discussions you had with Kamal about torch.gather and take_along_axis.

On GPUs I already enabled XLA via tf.config.optimizer.set_jit and via TF_XLA_FLAGS="--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit" but I was reading that this isn't the optimal way to do it, so I'm now trying the jit_compile=True and will report back.
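(For reference, a sketch of the two "global" ways of enabling XLA mentioned above; the per-model `jit_compile=True` route is what gets tried next. The script name is taken from the logs below.)

```python
import tensorflow as tf

# enable XLA auto-clustering for the whole session from Python
tf.config.optimizer.set_jit(True)

# or, equivalently, from the shell before launching training:
#   TF_XLA_FLAGS="--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit" python run_pretraining.py
```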

Also I just finished testing tf.experimental.numpy.take_along_axis, on GPUs it improved performance by ~10% yet on TPUs I still have the same issue. I will also test the jit_compile on TPUs but I don't think it will solve anything.

Thanks a lot for the replies and for the effort you put into converting the PyTorch code to TF.

WissamAntoun commented 2 years ago

Running the training with jit_compile=True on GPU revealed a new bug, so it is now an XLA/JIT issue, not a TPU one.

Log dump:

```md
2022-07-21 23:36:18.107830: W tensorflow/core/framework/op_kernel.cc:1745] OP_REQUIRES failed at bcast_ops.cc:50 : INVALID_ARGUMENT:
Input 0 to node `pretraining_model/tf_deberta_v2_for_masked_lm/deberta/encoder/layer_._0/attention/self/BroadcastArgs` with op BroadcastArgs must be a compile-time constant.
XLA compilation requires that operator arguments that represent shapes or dimensions be evaluated to concrete values at compile time. This error means that a shape or dimension argument could not be evaluated at compile time, usually because the value of the argument depends on a parameter to the computation, on a variable, or on a stateful operation such as a random number generator.

Stack trace for op definition:
  File "run_pretraining.py", line 204: config = main(start_time)
  File "run_pretraining.py", line 184, in main: trained_model = run_customized_training_loop(
  File "/workspaces/nv-deberta-tf2/electra/model_training_utils.py", line 675, in run_customized_training_loop: train_steps_strategy(
  File "/workspaces/nv-deberta-tf2/electra/model_training_utils.py", line 407, in train_steps_strategy: if num_grad_accumulates != 1:
  File "/workspaces/nv-deberta-tf2/electra/model_training_utils.py", line 408, in train_steps_strategy: for step_idx in tf.range(steps * num_grad_accumulates):
  File "/workspaces/nv-deberta-tf2/electra/model_training_utils.py", line 410, in train_steps_strategy: strategy.run(_forward, args=(next(iterator),))
  File "/workspaces/nv-deberta-tf2/electra/model_training_utils.py", line 324, in _forward: loss, model_outputs = model(inputs, is_training=True)
  File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 64, in error_handler: return fn(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py", line 1096, in __call__: outputs = call_fn(inputs, *args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/keras/utils/traceback_utils.py", line 92, in error_handler: return fn(*args, **kwargs)
  File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 2491, in call: if config.uniform_generator:
  File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 2496, in call: mlm_output = self._get_masked_lm_output(
  File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 2541, in _get_masked_lm_output: if self._config.uniform_generator:
  File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 2550, in _get_masked_lm_output: outputs = generator(
  [keras error_handler / __call__ wrapper frames, as above]
  File "/workspaces/nv-deberta-tf2/electra/modeling_tf_utils.py", line 1872, in run_call_with_unpacked_inputs
  File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 1880, in call: outputs = self.deberta(
  [keras wrapper frames]
  File "/workspaces/nv-deberta-tf2/electra/modeling_tf_utils.py", line 1872, in run_call_with_unpacked_inputs
  File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 1617, in call: encoder_outputs = self.encoder(
  [keras wrapper frames]
  File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 527, in call: for i, layer_module in enumerate(self.layer):
  File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 532, in call: layer_outputs = layer_module(
  [keras wrapper frames]
  File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 317, in call: attention_outputs = self.attention(
  [keras wrapper frames]
  File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 226, in call: self_outputs = self.self(
  [keras wrapper frames]
  File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 876, in call: if self.relative_attention:
  File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 878, in call: rel_att = self.disentangled_att_bias(
  File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 991, in disentangled_att_bias: if "c2p" in self.pos_att_type:
  File "/workspaces/nv-deberta-tf2/electra/modeling_tf_deberta_v2.py", line 1012, in disentangled_att_bias: c2p_att = tnp.take_along_axis(

2022-07-21 23:36:18.184105: W tensorflow/core/framework/op_kernel.cc:1745] OP_REQUIRES failed at xla_ops.cc:248 : INVALID_ARGUMENT: Input 0 to node `pretraining_model/tf_deberta_v2_for_masked_lm/deberta/encoder/layer_._0/attention/self/BroadcastArgs` with op BroadcastArgs must be a compile-time constant. (same error message and op-definition stack trace as above)
  [[{{node pretraining_model/tf_deberta_v2_for_masked_lm/deberta/encoder/layer_._0/attention/self/BroadcastArgs}}]]

Traceback (most recent call last):
  File "run_pretraining.py", line 204: config = main(start_time)
  File "run_pretraining.py", line 184, in main: trained_model = run_customized_training_loop(
  File "/workspaces/nv-deberta-tf2/electra/model_training_utils.py", line 675, in run_customized_training_loop: train_steps_strategy(
  File "/usr/local/lib/python3.8/dist-packages/tensorflow/python/util/traceback_utils.py", line 153, in error_handler: raise e.with_traceback(filtered_tb) from None
  File "/usr/local/lib/python3.8/dist-packages/tensorflow/python/eager/execute.py", line 54, in quick_execute: tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
tensorflow.python.framework.errors_impl.InvalidArgumentError: Graph execution error:
Input 0 to node `pretraining_model/tf_deberta_v2_for_masked_lm/deberta/encoder/layer_._0/attention/self/BroadcastArgs` with op BroadcastArgs must be a compile-time constant. (same error message and op-definition stack trace as above)
  [[{{node pretraining_model/tf_deberta_v2_for_masked_lm/deberta/encoder/layer_._0/attention/self/BroadcastArgs}}]]
  [[while/body/_1/while/StatefulPartitionedCall]]
  [Op:__inference_train_steps_strategy_177980]
```

Rocketknight1 commented 2 years ago

@WissamAntoun Confirmed reproduction of the issue here. Our TF DeBERTa implementation seems to have issues with XLA - I'm investigating now.

Rocketknight1 commented 2 years ago

@WissamAntoun We have a potential fix - I've confirmed that I can compile microsoft/deberta-v3-small with XLA on my local machine. Can you try installing this branch and let me know if this fixes the problem for you? You can use `pip install git+https://github.com/huggingface/transformers.git@deberta-xla-fixes`.
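A minimal way to reproduce that check locally (an illustrative sketch, not the branch's own test: it builds the model from its config with random weights, which is my assumption of a quick compile check, and forces XLA compilation of one forward pass):

```python
import tensorflow as tf
from transformers import DebertaV2Config, TFDebertaV2Model

config = DebertaV2Config.from_pretrained("microsoft/deberta-v3-small")
model = TFDebertaV2Model(config)  # random weights are enough for a compile check

@tf.function(jit_compile=True)
def forward(input_ids):
    return model(input_ids=input_ids).last_hidden_state

dummy = tf.ones((2, 16), dtype=tf.int32)
print(forward(dummy).shape)  # (2, 16, hidden_size) once XLA compilation succeeds
```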

WissamAntoun commented 2 years ago

I confirm it works on GPUs with XLA, and I got a ~20% speedup. I'm still testing on TPUs now and will let you know ASAP.

WissamAntoun commented 2 years ago

Weirdly enough, the TPUs didn't seem to care about the changes 😅 even after we removed all the if branches.

Rocketknight1 commented 2 years ago

Hmm. Can you check that you don't get the slowdown if you switch to another model, like BERT or ELECTRA, while keeping all of the other code the same (especially data loading)? I know the profiling indicates that GatherV2 is the problem, but I'm a little suspicious!

WissamAntoun commented 2 years ago

I tried disabling relative_attention in DeBERTa, which makes the model a regular BERT, and the performance improved 40x 😅
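(A sketch of that ablation, assuming it is done via the config flag rather than a code edit; illustrative, not the actual training script:)

```python
from transformers import DebertaV2Config, TFDebertaV2ForMaskedLM

config = DebertaV2Config.from_pretrained("microsoft/deberta-v3-small")
config.relative_attention = False  # skips the disentangled-attention gathers entirely
model = TFDebertaV2ForMaskedLM(config)
```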

Rocketknight1 commented 2 years ago

@WissamAntoun So the issue really is in that gather! That's extremely interesting - with the simplified code, it's just a single call to tf.gather, but perhaps the batch_dims argument is not handled elegantly on TPU, or XLA converts it in a way that doesn't run well on TPU.

Is it possible that some kind of memory spill is occurring? Can you try lowering your batch size and increasing steps_per_execution?

If that isn't it, then I have no idea - maybe there's some way to rewrite the gather, but I don't really know what to try!
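(A sketch of the steps_per_execution suggestion, with a toy model standing in for the real one; larger values fuse more train steps into each compiled TPU call and cut host round-trips:)

```python
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(4,))])
model.compile(optimizer="adam", loss="mse", steps_per_execution=64)
```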

WissamAntoun commented 2 years ago

@Rocketknight1 I tried your suggestions without any success, sadly!

Then I tried replacing the whole take_along_axis function with tf.gather(..., batch_dims=2), which is equivalent according to the test below. GPU still runs fine; TPU still has the same issue 😔.

I also ran out of ideas to try, now I'm just waiting for the TPU gods 😅

Test code:

```python
# %%
import tensorflow as tf

# %%
x_shape = [32, 128, 512]
indices_shape = [32, 128, 128]
x = tf.random.uniform(shape=x_shape)
indices = tf.random.uniform(shape=indices_shape, minval=1, maxval=128, dtype=tf.int32)

# %%
flat_x = tf.reshape(x, (-1, x_shape[-1]))
print(flat_x.shape)  # (4096, 512)
flat_indices = tf.reshape(indices, (-1, indices_shape[-1]))
print(flat_indices.shape)  # (4096, 128)

# %%
gathered = tf.gather(
    params=flat_x, indices=flat_indices, batch_dims=1, validate_indices=None
)
print(gathered.shape)  # (4096, 128)
gathered_reshaped = tf.reshape(gathered, indices.shape)
print(gathered_reshaped.shape)  # (32, 128, 128)

# %%
gathered2 = tf.gather(params=x, indices=indices, batch_dims=2, validate_indices=None)
print(gathered2.shape)  # (32, 128, 128)

# %%
tf.assert_equal(gathered2, gathered_reshaped)  # passes
```

Rocketknight1 commented 2 years ago

I'm clueless in that case - @patrickvonplaten @sanchit-gandhi do you have any idea why a gather or take_along_axis op which is performant on GPU and compiles with XLA would become a huge bottleneck on TPU?

sanchit-gandhi commented 2 years ago

In our JAX BLOOM experiments, we saw significant performance improvements by changing how we indexed. Swapping scatter ops for one-hot broadcasts, we obtained 3-4x speed-ups in practice. The logic is largely lifted from T5X: https://github.com/google-research/t5x/blob/63d9addf628c6d8c547a407a32095fcb527bb20b/t5x/examples/scalable_t5/layers.py#L280-L284

I wonder if applying similar logic here and swapping the gather op for one-hot indexing might help?

WissamAntoun commented 2 years ago

Do you mean something like the BERT one-hot embeddings? https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/on_device_embedding.py#L79

sanchit-gandhi commented 2 years ago

Simply modifying the bottleneck function https://github.com/huggingface/transformers/blob/f4e172716b91b477ce3cddc9a253094b7121a4b8/src/transformers/models/deberta_v2/modeling_tf_deberta_v2.py#L525 to use one-hot encodings as opposed to a gather op. The example you've linked looks like the right idea! Worth a try IMO!

WissamAntoun commented 2 years ago

I tried this, although I'm not sure if it's the best implementation:

```python
def take_along_axis(x, indices):
    one_hot_indices = tf.one_hot(indices, depth=x.shape[-1], dtype=x.dtype)  # [B, S, P, D] => [B, 128, 128, 512]

    # [B, S, P, D] . [B, S, D, 1] = [B, S, P, 1]
    gathered = tf.squeeze(tf.matmul(one_hot_indices, tf.expand_dims(x, axis=-1)), axis=-1)
    return gathered
```

It improved the speed from 20 seq/s to 110 seq/s. For reference, regular ELECTRA/BERT got ~800 seq/s.

Now it's the reshape and squeeze operations that are "wasting" time:

[profiler screenshot]

WissamAntoun commented 2 years ago

@sanchit-gandhi is there a better implementation than mine, without expand_dims or squeeze, since these are unfavorable operations on TPUs?

sanchit-gandhi commented 2 years ago

Nice! A 5x speed up is a good start. If we can get another 5x we'll be in business. Thanks for linking the Tensorboard profile! Super helpful in identifying bottlenecks like these 🙏

Interesting to see the expand_dims and squeeze are now accruing large amounts of runtime. I'm not a TF user (it's mainly JAX on TPU for me!), so I'm not up to speed with implementation details, but my impression from the profile is that the shapes are unfavourable for XLA. Perhaps you could have a play around and see whether changing the tensor shapes / choice of TF ops has any effect? It's been the case for me in the past that using tensors of different shape can give big speed-ups. Is there a repo you could reference for XLA-optimised TF code? For JAX, we usually look to the T5X repo when deciding on tensor shapes and trying out 'hacks' like these: https://github.com/google-research/t5x/tree/main/t5x

cc @Rocketknight1 who's more up to speed in the TF sphere!

sanchit-gandhi commented 2 years ago

Hey @WissamAntoun! Any luck with this? Maybe also worth trying https://www.tensorflow.org/api_docs/python/tf/experimental/numpy/take_along_axis

WissamAntoun commented 2 years ago

Hey @sanchit-gandhi, I have already tried the tf.experimental.numpy function, with no improvement at all compared to gather with batch_dims=2.

I also tried going up to a sequence length of 512: I got the exact same speedup, but it is still much slower than expected (around 20 seq/s at sequence length 512). I also changed batch sizes, with no effect at all.

sanchit-gandhi commented 2 years ago

Okay, probably worth sticking with the one-hot encoding hack then, it seems the most promising! I'm not a TF user so can't comment on the exact implementation changes you could make with the expand_dims or squeeze ops. Perhaps @gante could take a look here with his experience using TF and XLA?

gante commented 2 years ago

> Now it's the reshape and squeeze operations that are "wasting" time

Interesting -- I spent some time with TPU profiling on a different application (TF text generation with a myriad of models), and found that those two operations were part of the bottleneck (along with XLA's dynamic_update_slice). They accounted for 50-70% of the execution time. Do you know if it is also a bottleneck for FLAX, @sanchit-gandhi (e.g. the cache updates here)?

sanchit-gandhi commented 2 years ago

For JAX BLOOM we couldn't even compile the 176B parameter model with the naive implementation of concatenate_to_cache, let alone benchmark which operations consumed the bulk of the execution time! We swapped it for this more efficient implementation (with one-hot encodings etc.): https://github.com/huggingface/bloom-jax-inference/blob/2a04aa519d262729d54adef3d19d63879f81ea89/bloom_inference/modeling_bloom/modeling_bloom.py#L119 Coincidentally, we've just run the JAX profiler for this implementation and are going through the trace with some of the Google JAX guys later today. Will report back on how performance fares!

WissamAntoun commented 2 years ago
```python
def take_along_axis(x, indices):
    one_hot_indices = tf.one_hot(indices, depth=x.shape[-1], dtype=x.dtype)  # [B, S, P, D] => [B, 128, 128, 512]

    # [B, S, P, D] . [B, S, D, 1] = [B, S, P, 1]
    gathered = tf.squeeze(tf.matmul(one_hot_indices, tf.expand_dims(x, axis=-1)), axis=-1)
    return gathered
```

@gante Do you think the one-hot trick can be done without the expand_dims and squeeze? Maybe then we can just dodge the whole problem.

gante commented 2 years ago

@sanchit-gandhi that's interesting! I'd be interested in knowing the pro tips for XLA (which should also apply to TF)

@WissamAntoun Yeah, we can rework it with tf.einsum magic, assuming the operation can be rewritten with Einstein notation -- in this case, it is possible! Check the implementation below, give it a try, and let us know if it helped with speed on a TPU (my debug runs confirmed that they are numerically equivalent)

```python
def take_along_axis(x, indices):
    # [B, S, P] -> [B, S, P, D]
    one_hot_indices = tf.one_hot(indices, depth=x.shape[-1], dtype=x.dtype)

    # if we ignore the first two dims, this is equivalent to multiplying a matrix (one hot) by a vector (x)
    # grossly abusing notation: [B, S, P, D] . [B, S, D] = [B, S, P]
    gathered = tf.einsum('ijkl,ijl->ijk', one_hot_indices, x)

    return gathered
```
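A quick numerical check of the einsum variant against `tf.gather(..., batch_dims=2)` on toy shapes (illustrative; the shapes are made up):

```python
import tensorflow as tf

x = tf.random.uniform([2, 8, 16])
idx = tf.random.uniform([2, 8, 4], minval=0, maxval=16, dtype=tf.int32)

one_hot = tf.one_hot(idx, depth=x.shape[-1], dtype=x.dtype)  # [2, 8, 4, 16]
via_einsum = tf.einsum("ijkl,ijl->ijk", one_hot, x)          # [2, 8, 4]
via_gather = tf.gather(x, idx, batch_dims=2)                 # [2, 8, 4]
tf.debugging.assert_near(via_einsum, via_gather)             # passes
```
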
WissamAntoun commented 2 years ago

@gante I tested the tf.einsum implementation. It gave me the same performance as the one-hot trick, about ~120 seq/s. I tried it with different batch sizes, but it still didn't change much.

This is a screenshot of the profiler: [profiler screenshot]

gante commented 2 years ago

I'm out of suggestions :( I suspect this is a good question for Google's XLA and TPU teams -- the problem is probably at a compiler/hardware level.

WissamAntoun commented 2 years ago

Yeah this is a weird and unexpected bug. Do you know someone we can get in contact with from Google's XLA or TPU team?

And thanks a lot for the efforts you guys put into this issue!

gante commented 2 years ago

@sanchit-gandhi do you know a good point of contact for TPU problems?

stefan-it commented 2 years ago

Ping @JackCaoG for help :)

JackCaoG commented 2 years ago

Thanks, I will try to take a look or find someone from my team to help.

nvm, this is TF2, I only know pt/xla lol

sanchit-gandhi commented 2 years ago

> @sanchit-gandhi do you know a good point of contact for TPU problems?

Only for JAX on TPU, I'll ask around and see if there is anyone who can help with TF!

github-actions[bot] commented 2 years ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.