Hi @WissamAntoun, this is an interesting issue! I honestly have no idea what the cause could be, but it's notable that the profile highlights that function: the DeBERTa code was ported from PyTorch, and we wrote our own implementation of `take_along_axis` because TF didn't have one. One thing to try would be to edit the code to use `tf.experimental.numpy.take_along_axis` instead of that function. If that doesn't work, then we might have to see if we can do things in a different, more performant way.
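A minimal sketch of that swap, assuming `x` is `[B, S, D]` and `indices` is `[B, S, P]` with the gather along the last axis (the layout used later in this thread). The wrapper name `take_along_axis_tnp` and the toy shapes are hypothetical; the check only confirms that it agrees with `tf.gather(..., batch_dims=2)` and with NumPy's `np.take_along_axis`.

```python
import numpy as np
import tensorflow as tf


def take_along_axis_tnp(x, indices):
    # Hypothetical drop-in replacement: gather along the last axis via TF's NumPy API.
    return tf.experimental.numpy.take_along_axis(x, indices, axis=-1)


x = tf.random.uniform([2, 4, 8])  # [B, S, D]
indices = tf.random.uniform([2, 4, 3], minval=0, maxval=8, dtype=tf.int32)  # [B, S, P]

out_tnp = take_along_axis_tnp(x, indices)  # [B, S, P]
out_gather = tf.gather(x, indices, batch_dims=2)
out_numpy = np.take_along_axis(x.numpy(), indices.numpy(), axis=-1)
```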
Also, just in case XLA compilation is the issue, have you tried using `jit_compile=True` in `compile()` when running DeBERTa on GPU? If that also causes performance degradation, then the problem is caused by XLA and not TPUs, and we can investigate from there.
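For reference, a minimal sketch of enabling XLA through Keras `compile()`. The tiny dense model and random data are placeholders (assumptions, not the DeBERTa pretraining setup); only the `jit_compile=True` argument is the point.

```python
import numpy as np
import tensorflow as tf

# Placeholder model: stands in for the real pretraining model.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(16,)),
    tf.keras.layers.Dense(32, activation="relu"),
    tf.keras.layers.Dense(1),
])

# jit_compile=True asks Keras to compile its train/eval/predict steps with XLA.
model.compile(optimizer="adam", loss="mse", jit_compile=True)

features = np.random.rand(64, 16).astype("float32")
targets = np.random.rand(64, 1).astype("float32")
history = model.fit(features, targets, batch_size=16, epochs=1, verbose=0)
```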
Also cc @sanchit-gandhi because I'm not a TPU expert - don't worry about investigating this deeply, but if anything comes to mind when you read it, let me know!
@Rocketknight1 I read all the discussions that you had with Kamal about `torch.gather` and `take_along_axis`.
On GPUs I already enabled XLA via `tf.config.optimizer.set_jit` and via `TF_XLA_FLAGS="--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit"`, but I was reading that this isn't the optimal way to do it, so I'm now trying `jit_compile=True` and will report back.
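A small sketch contrasting the two styles, assuming a recent TF 2.x: `tf.config.optimizer.set_jit` and `TF_XLA_FLAGS` turn on auto-clustering, where TF decides which subgraphs to hand to XLA, while `tf.function(jit_compile=True)` requires the whole function to compile with XLA and raises an error if it cannot. The function name `gather_last_axis` is hypothetical.

```python
import tensorflow as tf

# Auto-clustering (global, TF picks the clusters), shown for contrast only:
# tf.config.optimizer.set_jit(True)


# Explicit compilation: this whole function is compiled by XLA or fails loudly.
@tf.function(jit_compile=True)
def gather_last_axis(x, indices):
    # x: [B, S, D], indices: [B, S, P] -> [B, S, P]
    return tf.gather(x, indices, batch_dims=2)


x = tf.reshape(tf.range(24, dtype=tf.float32), [2, 3, 4])
indices = tf.constant([[[3, 0]] * 3] * 2)  # shape [2, 3, 2]
out = gather_last_axis(x, indices)
```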
Also, I just finished testing `tf.experimental.numpy.take_along_axis`: on GPUs it improved performance by ~10%, yet on TPUs I still have the same issue. I will also test `jit_compile` on TPUs, but I don't think it will solve anything.
Thanks a lot for the replies and for the effort you put into converting the PyTorch code to TF.
Running the training with `jit_compile=True` on GPU revealed a new bug, so it is now an XLA/JIT issue, not a TPU one:
```text
2022-07-21 23:36:18.107830: W tensorflow/core/framework/op_kernel.cc:1745] OP_REQUIRES failed at bcast_ops.cc:50 :
INVALID_ARGUMENT:
Input 0 to node `pretraining_model/tf_deberta_v2_for_masked_lm/deberta/encoder/layer_._0/attention/self/BroadcastArgs`
with op BroadcastArgs must be a compile-time constant.
XLA compilation requires that operator arguments that represent shapes or dimensions be evaluated to concrete values at compile time.
This error means that a shape or dimension argument could not be evaluated at compile time,
usually because the value of the argument depends on a parameter to the computation, on a variable,
or on a stateful operation such as a random number generator.
Stack trace for op definition:
File "run_pretraining.py", line 204, in
```
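The log says the target shape of a broadcast must be known when XLA compiles the graph. A minimal sketch of the distinction, assuming the usual rule that shapes derived from input *shapes* resolve to compile-time constants while shapes derived from tensor *values* do not. The function `broadcast_rows` is hypothetical and is not taken from the DeBERTa code.

```python
import tensorflow as tf


@tf.function(jit_compile=True)
def broadcast_rows(row, like):
    # Compiles: tf.shape(like) comes from a static input shape, so XLA can
    # resolve it to a constant at compile time.
    return tf.broadcast_to(row, tf.shape(like))


# By contrast, a target shape computed from tensor *values*, e.g.
#     tf.broadcast_to(row, [tf.reduce_max(counts), 4])
# depends on runtime data, which is the kind of pattern that produces
# "must be a compile-time constant" errors under jit_compile=True.

row = tf.constant([1.0, 2.0, 3.0, 4.0])
like = tf.zeros([3, 4])
out = broadcast_rows(row, like)
```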
@WissamAntoun Confirmed reproduction of the issue here. Our TF DeBERTa implementation seems to have issues with XLA - I'm investigating now.
@WissamAntoun We have a potential fix - I've confirmed that I can compile `microsoft/deberta-v3-small` with XLA on my local machine. Can you try installing this branch and let me know if it fixes the problem for you? You can use `pip install git+https://github.com/huggingface/transformers.git@deberta-xla-fixes`
I can confirm it works on GPUs with XLA, and I got a ~20% speedup. I'm still testing on TPUs now and will let you know ASAP.
Weirdly enough, TPUs didn't seem to care about the changes 😅, even after we removed all the `if` branches.
Hmm. Can you check that you don't get the slowdown if you switch the model to another model, like BERT or ELECTRA, while keeping all of the other code the same (especially data loading)? I know the profiling indicates that `GatherV2` is the problem, but I'm a little suspicious!
I tried disabling `relative_attention` in DeBERTa, which makes the model a regular BERT, and the performance improved 40x 😅
@WissamAntoun So the issue really is in that gather! That's extremely interesting - with the simplified code, it's just a single call to `tf.gather`, but perhaps the `batch_dims` argument is not handled elegantly on TPU, or XLA converts it in a way that doesn't run well on TPU.
Is it possible that some kind of memory spill is occurring? Can you try lowering your batch size and increasing `steps_per_execution`?
If that isn't it, then I have no idea - maybe there's some way to rewrite the gather, but I don't really know what to try!
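A minimal sketch of where those two knobs live in Keras, with a placeholder model, data, and values (assumptions for illustration only): `batch_size` goes to `fit()`, and `steps_per_execution` goes to `compile()`, where it sets how many batches run per call into the compiled training function, which reduces host-side overhead.

```python
import numpy as np
import tensorflow as tf

# Placeholder model: stands in for the real pretraining model.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(16,)),
    tf.keras.layers.Dense(32, activation="relu"),
    tf.keras.layers.Dense(1),
])

# Run 8 training steps per call of the compiled train function.
model.compile(optimizer="adam", loss="mse", steps_per_execution=8)

features = np.random.rand(128, 16).astype("float32")
targets = np.random.rand(128, 1).astype("float32")

# Smaller per-step batch: 128 / 8 = 16 steps per epoch.
history = model.fit(features, targets, batch_size=8, epochs=1, verbose=0)
```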
@Rocketknight1 I tried your suggestions without any success, sadly!
Then I tried replacing the whole `take_along_axis` function with `tf.gather(..., ..., batch_dims=2)`, which is equivalent according to this test I made. GPU still runs fine; TPU still has the same issue 😔.
I also ran out of ideas to try, now I'm just waiting for the TPU gods 😅
```python
import tensorflow as tf

x_shape = [32, 128, 512]
indices_shape = [32, 128, 128]
x = tf.random.uniform(shape=x_shape)
indices = tf.random.uniform(shape=indices_shape, minval=1, maxval=128, dtype=tf.int32)

# Manual version: flatten the two leading dims, then gather with batch_dims=1
flat_x = tf.reshape(x, (-1, x_shape[-1]))
print(flat_x.shape)  # (4096, 512)
flat_indices = tf.reshape(indices, (-1, indices_shape[-1]))
print(flat_indices.shape)  # (4096, 128)

gathered = tf.gather(
    params=flat_x, indices=flat_indices, batch_dims=1, validate_indices=None
)
print(gathered.shape)  # (4096, 128)
gathered_reshaped = tf.reshape(gathered, indices.shape)
print(gathered_reshaped.shape)  # (32, 128, 128)

# Single call: gather along the last axis with batch_dims=2
gathered2 = tf.gather(params=x, indices=indices, batch_dims=2, validate_indices=None)
print(gathered2.shape)  # (32, 128, 128)

tf.assert_equal(gathered2, gathered_reshaped)  # passes
```
I'm clueless in that case - @patrickvonplaten @sanchit-gandhi do you have any idea why a `gather` or `take_along_axis` op that is performant on GPU and compiles with XLA would become a huge bottleneck on TPU?
In our JAX BLOOM experiments, we saw significant performance improvements by changing how we indexed. Swapping scatter ops for one-hot broadcasts, we obtained 3-4x speed-ups in practice. The logic is largely lifted from T5X: https://github.com/google-research/t5x/blob/63d9addf628c6d8c547a407a32095fcb527bb20b/t5x/examples/scalable_t5/layers.py#L280-L284
I wonder if applying similar logic here and swapping the gather op to one-hot indexing might help?
Do you mean something like BERT's one-hot embeddings? https://github.com/tensorflow/models/blob/master/official/nlp/modeling/layers/on_device_embedding.py#L79
Simply modify the bottleneck function: https://github.com/huggingface/transformers/blob/f4e172716b91b477ce3cddc9a253094b7121a4b8/src/transformers/models/deberta_v2/modeling_tf_deberta_v2.py#L525
to use `one_hot` encodings as opposed to a `gather` op. The example you've linked looks like the right idea! Worth a try IMO!
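The idea behind one-hot lookups, as a minimal sketch: multiplying a one-hot encoding of the ids by the table selects the same rows that `tf.gather` would, but as a dense matrix contraction, the kind of op TPUs are built around. The helper `one_hot_lookup` and the toy sizes are hypothetical; this is similar in spirit to the linked TF Models layer, not a copy of it.

```python
import tensorflow as tf


def one_hot_lookup(table, ids):
    # table: [V, H] embedding table; ids: integer tensor of any shape [...].
    one_hot = tf.one_hot(ids, depth=table.shape[0], dtype=table.dtype)  # [..., V]
    # Contract the vocab axis: [..., V] x [V, H] -> [..., H]
    return tf.tensordot(one_hot, table, axes=1)


table = tf.random.uniform([10, 4])
ids = tf.constant([[1, 3, 9], [0, 0, 5]])

via_one_hot = one_hot_lookup(table, ids)  # [2, 3, 4]
via_gather = tf.gather(table, ids)        # [2, 3, 4]
```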
I tried this, although I'm not sure if it's the best implementation:
```python
import tensorflow as tf


def take_along_axis(x, indices):
    one_hot_indices = tf.one_hot(indices, depth=x.shape[-1], dtype=x.dtype)  # [B, S, P, D] => [B, 128, 128, 512]
    # [B, S, P, D] . [B, S, D, 1] = [B, S, P, 1]
    gathered = tf.squeeze(tf.matmul(one_hot_indices, tf.expand_dims(x, axis=-1)), axis=-1)
    return gathered
```
It improved the speed from 20 seq/s to 110 seq/s. For reference, regular ELECTRA/BERT got ~800 seq/s.
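A back-of-the-envelope sketch of why the one-hot version can still cost far more than a plain gather, using the shapes from the comment in the code above (`[B, S, P, D] => [B, 128, 128, 512]`, with B = 32 borrowed from the earlier test) and assuming float32. It counts the one-hot tensor and its multiply-adds as if fully materialized; XLA may fuse some of this away, so treat the numbers as upper bounds, not measurements. The helper name is hypothetical.

```python
def one_hot_gather_cost(batch, seq, positions, depth, bytes_per_elem=4):
    """Rough cost of the one-hot take_along_axis for x [B, S, D] and indices [B, S, P]."""
    one_hot_elems = batch * seq * positions * depth
    return {
        # Size of the [B, S, P, D] one-hot tensor if it is materialized.
        "one_hot_bytes": one_hot_elems * bytes_per_elem,
        # The matmul/einsum does one multiply-add per one-hot entry.
        "multiply_adds": one_hot_elems,
        # A gather only has to produce the [B, S, P] output.
        "output_elems": batch * seq * positions,
    }


cost = one_hot_gather_cost(32, 128, 128, 512)
print(cost["one_hot_bytes"] / 2**30)  # 1.0 (GiB)
print(cost["multiply_adds"] // cost["output_elems"])  # 512
```

In other words, each output element costs D = 512 multiply-adds instead of a single read, which is consistent with the one-hot trick helping without reaching ELECTRA/BERT throughput.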
Now it's the reshape and squeeze operations that are "wasting" time.
@sanchit-gandhi is there a better implementation than mine, without `expand_dims` or `squeeze`, since these are unfavorable operations on TPUs?
Nice! A 5x speed up is a good start. If we can get another 5x we'll be in business. Thanks for linking the Tensorboard profile! Super helpful in identifying bottlenecks like these 🙏
Interesting to see that `expand_dims` and `squeeze` are now accruing large amounts of runtime. I'm not a TF user (it's mainly JAX on TPU for me!), so I'm not up to speed with implementation details, but my impression from the profile is that the shapes are unfavourable for XLA. Perhaps you could have a play around and see whether changing the tensor shapes or the choice of TF ops has any effect? In the past, I've found that using tensors of a different shape can give big speed-ups. Is there a repo you could reference for XLA-optimised TF code? For JAX, we usually look to the T5X repo when deciding on tensor shapes and trying out 'hacks' like these: https://github.com/google-research/t5x/tree/main/t5x
cc @Rocketknight1 who's more up to speed in the TF sphere!
Hey @WissamAntoun! Any luck with this? Maybe also worth trying https://www.tensorflow.org/api_docs/python/tf/experimental/numpy/take_along_axis
Hey @sanchit-gandhi, I have already tried the experimental NumPy function, with no improvement at all compared to `gather` with `batch_dims=2`.
I also tried going up to a sequence length of 512 and got the exact same speedup, but it is still much slower than expected (around 20 seq/s at sequence length 512). I also changed batch sizes, with no effect at all.
Okay, probably worth sticking with the one-hot encoding hack then; it seems most promising! I'm not a TF user, so I can't comment on the exact implementation changes you could make to the `expand_dims` or `squeeze` ops. Perhaps @gante could take a look here with his experience using TF and XLA?
> Now it's the reshape and squeeze operations that are "wasting" time

Interesting -- I spent some time with TPU profiling on a different application (TF text generation with a myriad of models), and found that those two operations were part of the bottleneck (along with XLA's `dynamic_update_slice`). They accounted for 50-70% of the execution time. Do you know if it is also a bottleneck for FLAX, @sanchit-gandhi (e.g. the cache updates here)?
For JAX BLOOM we couldn't even compile the 176B-parameter model with the naive implementation of `concatenate_to_cache`, let alone benchmark which operations consumed the bulk of the execution time! We swapped it for this more efficient implementation (with one-hot encodings etc.): https://github.com/huggingface/bloom-jax-inference/blob/2a04aa519d262729d54adef3d19d63879f81ea89/bloom_inference/modeling_bloom/modeling_bloom.py#L119
Coincidentally, we've just run the JAX profiler for this implementation and are going through the traceback with some of the Google JAX guys later today. Will report back on how performance fares!
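A hypothetical TF translation of the one-hot cache-update idea (the BLOOM code linked above is JAX, and this is not a port of it): instead of writing into the cache at a dynamic position with a slice update, build a one-hot mask over the time axis and blend the new entry in with broadcasting, so every step uses the same static shapes. The function name and shapes are assumptions for illustration.

```python
import numpy as np
import tensorflow as tf


def one_hot_cache_update(cache, new_entry, position):
    # cache: [B, T, D]; new_entry: [B, 1, D]; position: scalar int in [0, T).
    # Mask of shape [1, T, 1] that is 1.0 only at `position`.
    mask = tf.one_hot(position, depth=cache.shape[1], dtype=cache.dtype)[None, :, None]
    # Keep the old cache everywhere except `position`, where new_entry is broadcast in.
    return cache * (1.0 - mask) + new_entry * mask


cache = tf.random.uniform([2, 5, 3])
new_entry = tf.random.uniform([2, 1, 3])
updated = one_hot_cache_update(cache, new_entry, tf.constant(2))

# Reference result via a plain NumPy slice assignment.
expected = cache.numpy().copy()
expected[:, 2, :] = new_entry.numpy()[:, 0, :]
```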
@gante Do you think the one-hot trick can be done without the `expand_dims` and `squeeze`? Maybe then we can just dodge the whole problem.
@sanchit-gandhi that's interesting! I'd be interested in knowing the pro tips for XLA (which should also apply to TF)
@WissamAntoun Yeah, we can rework it with `tf.einsum` magic, assuming the operation can be rewritten with Einstein notation -- in this case, it is possible! Check the implementation below, give it a try, and let us know if it helped with speed on a TPU (my debug runs confirmed that they are numerically equivalent).
```python
import tensorflow as tf


def take_along_axis(x, indices):
    # [B, S, P] -> [B, S, P, D]
    one_hot_indices = tf.one_hot(indices, depth=x.shape[-1], dtype=x.dtype)
    # if we ignore the first two dims, this is equivalent to multiplying a matrix (one hot) by a vector (x)
    # grossly abusing notation: [B, S, P, D] . [B, S, D] = [B, S, P]
    gathered = tf.einsum('ijkl,ijl->ijk', one_hot_indices, x)
    return gathered
```
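A self-contained check of the einsum version with small, hypothetical shapes: it compares the result against `tf.gather(..., batch_dims=2)` both eagerly and under `tf.function(jit_compile=True)`. On CPU/GPU this only shows numerical equivalence and that XLA accepts the op; it says nothing about TPU speed.

```python
import numpy as np
import tensorflow as tf


def take_along_axis(x, indices):
    # [B, S, P] -> [B, S, P, D]
    one_hot_indices = tf.one_hot(indices, depth=x.shape[-1], dtype=x.dtype)
    # [B, S, P, D] . [B, S, D] -> [B, S, P]
    return tf.einsum("ijkl,ijl->ijk", one_hot_indices, x)


x = tf.random.uniform([2, 4, 8])
indices = tf.random.uniform([2, 4, 3], minval=0, maxval=8, dtype=tf.int32)

reference = tf.gather(x, indices, batch_dims=2)
eager_out = take_along_axis(x, indices)
xla_out = tf.function(take_along_axis, jit_compile=True)(x, indices)

np.testing.assert_allclose(eager_out.numpy(), reference.numpy())
np.testing.assert_allclose(xla_out.numpy(), reference.numpy())
```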
@gante I tested the `tf.einsum` implementation. It gave me the same performance as the `one_hot` trick, about 120 seq/s.
I tried it with different batch sizes, but it still didn't change much.
This is a screenshot of the profiler:
I'm out of suggestions :( I suspect this is a good question for Google's XLA and TPU teams -- the problem is probably at a compiler/hardware level.
Yeah this is a weird and unexpected bug. Do you know someone we can get in contact with from Google's XLA or TPU team?
And thanks a lot for the efforts you guys put into this issue!
@sanchit-gandhi do you know a good point of contact for TPU problems?
Ping @JackCaoG for help :)
Thanks, I will try to take a look or find someone from my team to help.
nvm, this is TF2, I only know PT/XLA lol
> @sanchit-gandhi do you know a good point of contact for TPU problems?
Only for JAX on TPU, I'll ask around and see if there is anyone who can help with TF!
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
System Info
Latest version of transformers, Colab TPU, TensorFlow 2
Who can help?
@kamalkraj @Rocketknight1 @BigBird01
Reproduction
It's currently hard to share the code and access to the Google bucket, but I believe any TF2 DeBERTaV2 code running on TPUs will have this issue.
Expected behavior
I've been trying to train a DeBERTa v3 model on GPUs and TPUs. I got it to work on multi-node and multi-GPU setups using NVIDIA's DeepLearningExamples library: https://github.com/NVIDIA/DeepLearningExamples/blob/master/TensorFlow2/LanguageModeling/ I basically used the training setup and loop from the BERT code, the dataset utils from the ELECTRA code, and the model from Hugging Face transformers, with some changes in order to share embeddings.
On 6x A40 45 GB GPUs I get around 1370 sentences per second during training (which is lower than what NVIDIA gets for ELECTRA, but it's fine).
OK, now the problem... on TPU I get 20 sentences per second.
I traced the issue back to the tf.gather function here https://github.com/huggingface/transformers/blob/main/src/transformers/models/deberta_v2/modeling_tf_deberta_v2.py#L525
I ran TPU profiling and this is the output:
GatherV2 takes most of the time:
Zoomed-in pictures of the fast ops:
Also, I'm not sure if this is TPU-specific, since on GPUs the training is ~30% slower compared to regular ELECTRA.