keras-team / keras-nlp

Modular Natural Language Processing workflows with Keras
Apache License 2.0

DeBERTa runs very slowly #606

Open chenmoneygithub opened 1 year ago

chenmoneygithub commented 1 year ago

DeBERTa runs very slowly on both TPU and GPU:

Comparatively, for BERT small:

The GPU numbers might be explained by the difference in model size, but TPU is not behaving normally. I suspect something in the model is not compatible with XLA.
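When comparing step times across devices, the first calls on an XLA-backed device usually include tracing and compilation, so it helps to time only steady-state steps. Below is a minimal, framework-agnostic sketch. `benchmark_step` and its `warmup`/`iters` parameters are hypothetical, and `step_fn` stands in for one training or inference step; on asynchronous devices it must block until the result is actually ready (for example by fetching the output to host memory), or the timings will be misleading.

```python
import time
from statistics import mean
from typing import Callable


def benchmark_step(step_fn: Callable[[], object], warmup: int = 3, iters: int = 10) -> float:
    """Return the mean wall-clock seconds per call of step_fn, excluding warm-up.

    Warm-up calls absorb one-off costs (e.g. tracing or XLA compilation), so the
    measured calls reflect steady-state step time. step_fn is assumed to block
    until its work is finished.
    """
    if iters < 1:
        raise ValueError("iters must be >= 1")
    # Run and discard warm-up calls so compilation time is not counted.
    for _ in range(warmup):
        step_fn()
    timings = []
    for _ in range(iters):
        start = time.perf_counter()
        step_fn()
        timings.append(time.perf_counter() - start)
    return mean(timings)


if __name__ == "__main__":
    # Toy stand-in for a model step; replace with a real, blocking step function.
    avg = benchmark_step(lambda: sum(range(100_000)))
    print(f"mean step time: {avg * 1e3:.3f} ms")
```

If the steady-state time on TPU is still far above BERT small's after warm-up, compilation cost alone would not explain the gap.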

mattdangerw commented 1 year ago

I've also noticed way more parameters than the advertised counts at https://github.com/microsoft/DeBERTa.

For example, according to the GitHub repo, the base variant should have 86M parameters, but I see 183M when I print them out. The same count shows up on Hugging Face too, so this may be unrelated to the issues Chen is pointing out.

mattdangerw commented 1 year ago

@abheesht17 cc on this one too. Let us know if you have any thoughts!

abheesht17 commented 1 year ago

Hey, @mattdangerw, @chenmoneygithub! I am not sure why this is happening, but I may have some answers.

This is the same number as returned by model.summary() on our model:

...
==================================================================================================
Total params: 70,682,112
Trainable params: 70,682,112
Non-trainable params: 0
__________________________________________________________________________________________________

Out of curiosity, I did some calculations. The numbers in parentheses are the advertised parameter counts.

# xsmall
total_params - token_emb_params = 70,682,112 - 49,190,400 = 21,491,712 (22M)

# small
total_params - token_emb_params = 141,304,320 - 98,380,800 = 42,923,520 (44M)

# base
total_params - token_emb_params = 183,831,552 - 98,380,800 = 85,450,752 (86M)

The advertised numbers in the repo omit the token embedding parameters, so I don't think this is an issue.
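The subtraction can be reproduced with a short Python sketch. Beyond what the thread states, it assumes that the DeBERTaV3 token embedding is a 128,100 × hidden matrix, with a hidden size of 384 for xsmall and 768 for small and base; under those assumptions the products match the token_emb_params values quoted. `non_embedding_params` is a hypothetical helper, not a keras-nlp API.

```python
# ASSUMPTION: DeBERTaV3 vocabulary size is 128,100 and the hidden size is
# 384 (xsmall) or 768 (small/base). These values are not stated in the thread.
VOCAB_SIZE = 128_100

# name: (total params reported in the thread, assumed hidden size, advertised size in M)
variants = {
    "xsmall": (70_682_112, 384, 22),
    "small": (141_304_320, 768, 44),
    "base": (183_831_552, 768, 86),
}


def non_embedding_params(total: int, hidden: int, vocab: int = VOCAB_SIZE) -> int:
    """Subtract the token embedding matrix (vocab x hidden) from the total count."""
    return total - vocab * hidden


results = {}
for name, (total, hidden, advertised_m) in variants.items():
    embedding = VOCAB_SIZE * hidden
    rest = non_embedding_params(total, hidden)
    results[name] = rest
    print(f"{name}: {total:,} - {embedding:,} = {rest:,} "
          f"(~{rest / 1e6:.1f}M vs advertised {advertised_m}M)")
```

The remaining counts land close to the advertised figures (within about 1M), which is consistent with the repo reporting sizes without the token embedding table.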

chenmoneygithub commented 1 year ago

Our GPU performance seems to be fine; check this Colab for a comparison: https://colab.research.google.com/gist/chenmoneygithub/ca38f7132fc17c85511e612d09ed686c/deberta-checks.ipynb