rusiaaman / XLnet-gen

XLNet for generating language.
MIT License

Inference time #1

astariul closed this issue 5 years ago

astariul commented 5 years ago

Thanks for the code and the article!

I've tried running your code, and it works, but inference is very slow: about 45 minutes for a single sample...

I understand this is because the attention for all non-target tokens has to be recomputed every 2 steps, and because it's running on a CPU, but it is still too long.

Can I run this code on a GPU or a Colab TPU? If so, how?

rusiaaman commented 5 years ago

As you pointed out, we have to recalculate the hidden states each time a new token is sampled (I will change it to recalculate once every two sampled tokens if that proves stable, which it should), but 45 minutes is still too slow. Reducing max_mem_length and num_toks_pred will help: setting both to 128 instead of 512 should, ideally, make it about sixteen times faster on a CPU.
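One way to see where "sixteen times" can come from is a rough cost model. The sketch below assumes, as a simplification not stated in the thread, that generation runs num_toks_pred steps and that each step does work proportional to a context of max_mem_length + num_toks_pred tokens. The function relative_cost is hypothetical and not part of XLnet-gen.

```python
def relative_cost(max_mem_length: int, num_toks_pred: int, exponent: int = 1) -> int:
    """Toy cost model (assumption): num_toks_pred steps, each processing the full
    context of max_mem_length + num_toks_pred tokens. exponent=1 treats per-step
    work as linear in context length; exponent=2 treats it as quadratic."""
    context = max_mem_length + num_toks_pred
    return num_toks_pred * context ** exponent


# Shrinking both settings from 512 to 128 under the linear assumption.
print(relative_cost(512, 512) / relative_cost(128, 128))  # 16.0

# If the quadratic attention term dominated instead.
print(relative_cost(512, 512, exponent=2) / relative_cost(128, 128, exponent=2))  # 64.0
```

Under the linear model, both the number of steps and the per-step context shrink by 4x, giving 4 × 4 = 16. Real timings also depend on fixed overheads, so treat this as an idealized estimate.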

On Colab's Tesla K80, one sample takes 98 seconds with max_mem_length and num_toks_pred both set to 256. It is probably much faster on a Tesla T4.

Colab notebook: https://colab.research.google.com/drive/12u-CmB9evMIASNOqJtDW26gmNvSgepBv

Running inference on a TPU is not worth it, for several reasons: the input data has to be uploaded to Google Cloud (gcloud), the batch size is fixed (you can't feed fewer examples than batch_size), TPUs are not always available, and starting a TPU cluster takes time.
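To illustrate the fixed batch_size constraint: if every run must contain exactly batch_size examples, a single prompt has to be padded with dummy entries, and most of the batch is wasted work. The sketch below is a generic illustration under that assumption; pad_to_batch is a hypothetical helper, not part of XLnet-gen or any TPU API.

```python
from typing import List, Tuple


def pad_to_batch(samples: List[str], batch_size: int, pad: str = "") -> Tuple[List[str], List[bool]]:
    """Pad samples up to a multiple of batch_size (hypothetical helper).

    Returns the padded list and a mask that is True for real samples and
    False for dummy entries whose outputs should be discarded."""
    remainder = len(samples) % batch_size
    num_pad = 0 if remainder == 0 else batch_size - remainder
    padded = samples + [pad] * num_pad
    mask = [True] * len(samples) + [False] * num_pad
    return padded, mask


padded, mask = pad_to_batch(["my single prompt"], batch_size=8)
print(len(padded), sum(mask))  # 8 1
```

With one prompt and batch_size=8, seven of the eight slots are padding, which is part of why a TPU buys little for one-sample generation.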

rusiaaman commented 5 years ago

This issue is no longer relevant. I discovered a mistake in my code and fixed it. The code no longer recalculates hidden states, and tokens can now be generated autoregressively without the output degrading (given a large enough context).
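The difference between recalculating hidden states at every step and reusing them can be seen by counting per-token state computations in a toy generator. This is a minimal sketch, not XLNet: it assumes that a token's hidden state does not change once later tokens are appended (the property that makes reuse valid), and count_states is a hypothetical function written for illustration.

```python
def count_states(prompt_len: int, num_new: int, use_cache: bool) -> int:
    """Count per-token hidden-state computations in a toy generator.

    Each step needs states for every position in the current context.
    Without a cache, all of them are recomputed; with a cache, only
    positions not seen before are computed (assumes states never change)."""
    cache = set()
    computed = 0
    context_len = prompt_len
    for _ in range(num_new):
        for pos in range(context_len):
            if use_cache and pos in cache:
                continue  # reuse the stored state for this position
            computed += 1
            cache.add(pos)
        context_len += 1  # the newly sampled token joins the context
    return computed


print(count_states(512, 512, use_cache=False))  # 392960
print(count_states(512, 512, use_cache=True))   # 1023
```

Recomputation grows quadratically with the number of generated tokens, while reuse grows linearly, which is why dropping the recalculation matters far more than any hardware change.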

astariul commented 5 years ago

Thank you for updating your code and the Medium article 👍