huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

weird large memory usage of mbert model #10833

Closed dorost1234 closed 3 years ago

dorost1234 commented 3 years ago

Environment info

Who can help

albert, bert, xlm: @LysandreJik

Information

I am using the mbert model, which has 110M parameters. I am testing the pretraining code with mt5-small and comparing it against mbert, and mbert uses a surprisingly large amount of memory: on the same machine I need to cut the batch size in half compared to mt5-small, even though mt5-small is 3x larger than mbert. It is strange that mbert, at 1/3 the size of mt5-small, requires more memory.

To reproduce

Steps to reproduce the behavior:

python run_mlm.py --model_name_or_path bert-base-multilingual-uncased --dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 --do_train --do_eval --output_dir output --per_device_train_batch_size 88 --fp16 --max_seq_length 128
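
A quicker way to see the gap than a full run_mlm.py run is to measure peak GPU memory for a single training step of each model. Below is a minimal sketch, assuming PyTorch, the transformers Auto* classes, and a CUDA GPU; the batch size and sequence lengths are illustrative, not the values from the command above:

```python
import torch
from transformers import AutoModelForMaskedLM, AutoModelForSeq2SeqLM

def peak_memory_mb(model, **batch):
    """Run one forward/backward pass and report peak GPU memory in MiB."""
    model.cuda().train()
    batch = {k: v.cuda() for k, v in batch.items()}
    torch.cuda.reset_peak_memory_stats()
    model(**batch).loss.backward()
    return torch.cuda.max_memory_allocated() / 2**20

bsz, seq_len = 8, 128

# mBERT: the MLM head projects every position onto a ~106k-token vocabulary.
# (Using the inputs as labels is only for measurement; the logits tensor is the
# same size either way.)
mbert = AutoModelForMaskedLM.from_pretrained("bert-base-multilingual-uncased")
ids = torch.randint(0, mbert.config.vocab_size, (bsz, seq_len))
print("mbert peak MiB:", peak_memory_mb(mbert, input_ids=ids, labels=ids))
del mbert
torch.cuda.empty_cache()

# mT5-small: encoder-decoder, logits are only computed over the (short) target sequence.
mt5 = AutoModelForSeq2SeqLM.from_pretrained("google/mt5-small")
src = torch.randint(0, mt5.config.vocab_size, (bsz, seq_len))
tgt = torch.randint(0, mt5.config.vocab_size, (bsz, 32))
print("mt5   peak MiB:", peak_memory_mb(mt5, input_ids=src, labels=tgt))
```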

Expected behavior

A larger batch size should be possible with mbert, given that it is roughly 1/3 the size of mt5-small.

github-actions[bot] commented 3 years ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

yurymalkov commented 3 years ago

One possible reason is mbert's larger vocabulary (tokenizer cardinality), combined with the fact that the HF code computes MLM predictions for all positions, not only the masked ones, so the projection onto the vocabulary can dominate overall runtime/memory while producing no useful output.
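
To put rough numbers on that claim (all values below are assumptions: the batch size and sequence length from the reproduce command, a ~106k-token vocabulary for bert-base-multilingual-uncased, and fp16 activations):

```python
# Back-of-the-envelope size of the MLM logits tensor alone (gradients and the
# softmax/cross-entropy intermediates add more on top of this).
batch, seq_len, vocab, bytes_per_elem = 88, 128, 105_879, 2

full_logits_gib = batch * seq_len * vocab * bytes_per_elem / 2**30
print(f"logits over all positions:    {full_logits_gib:.2f} GiB")        # ~2.22 GiB

# With the usual ~15% masking rate, predicting only the masked positions shrinks
# the logits (and everything derived from them) by roughly the same factor.
print(f"logits over masked positions: {full_logits_gib * 0.15:.2f} GiB")  # ~0.33 GiB
```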

In my case that was the issue, and I made a small fix here https://github.com/yurymalkov/transformers/commit/0fe0725c0f7fcc13df698bba1bd01847c1494e43 that gave me a 6x larger batch at the same memory usage (GTX 1060) and 3-4x faster training of a small mbert model. However, it causes some tests to fail (I think mainly due to a PyTorch JIT failure when it sees a named tensor in a slice, and secondly due to the missing output for non-masked tokens - though who cares about those?).
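
For readers skimming the thread, the gist of such a change is to gather the masked positions before the expensive vocabulary projection. The sketch below illustrates the idea only; it is not the linked commit, and names like masked_lm_head are stand-ins, not transformers internals:

```python
import torch
import torch.nn.functional as F

def masked_only_mlm_loss(hidden_states, labels, masked_lm_head):
    # hidden_states: (batch, seq_len, hidden); labels: (batch, seq_len) with -100 at
    # unmasked positions (the usual MLM convention, e.g. DataCollatorForLanguageModeling).
    mask = labels != -100                            # (batch, seq_len) bool
    selected_hidden = hidden_states[mask]            # (num_masked, hidden)
    selected_labels = labels[mask]                   # (num_masked,)
    logits = masked_lm_head(selected_hidden)         # (num_masked, vocab) instead of
    return F.cross_entropy(logits, selected_labels)  # (batch, seq_len, vocab)
```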