zihangdai / xlnet

XLNet: Generalized Autoregressive Pretraining for Language Understanding
Apache License 2.0

OOM ERROR when using local batch size=128 on TPUv3-8 #259

Open GhaliaRehawi opened 4 years ago

GhaliaRehawi commented 4 years ago

Hi, I am trying to train XLNet on protein sequences. I am running into an OOM error when running the script train.py on a TPUv3-8 with train_batch_size=128. (I also get an OOM error with train batch sizes 64 and 48, but not with 32 or 16.) The paper says: "Specifically, we train on 512 TPU v3 chips for 500K steps with an Adam weight decay optimizer, linear learning rate decay, and a batch size of 8192, which takes about 5.5 days." If I understand this correctly, the local batch size used there is also 128 = 8192 / (512 / 8), so I shouldn't get an OOM error.

For context, I am using a TPUv3-8 (version 1.14.1.dev20190518) and a Cloud VM instance, both in us-central1-a, with TensorFlow 1.13.1. For data preprocessing I am using the script data_utils.py, and it runs with no problem. Here are the commands I am using for preprocessing and training:

python xlnet/data_utils.py \
  --use_tpu=True \
  --save_dir=proc_data_bsz128/example \
  --bsz_per_host=128 \
  --num_core_per_host=8 \
  --seq_len=512 \
  --reuse_len=256 \
  --input_glob=testdata_xlnet.txt \
  --num_passes=20 \
  --bi_data=True \
  --sp_path=sp.model \
  --mask_alpha=6 \
  --mask_beta=1 \
  --uncased=False \
  --num_predict=85

python xlnet/train.py \
  --use_tpu=True \
  --tpu=name \
  --record_info_dir=$DATA_DIR \
  --save_steps=1000 \
  --model_dir=$MODEL_DIR \
  --train_batch_size=128 \
  --seq_len=512 \
  --reuse_len=256 \
  --mem_len=384 \
  --perm_size=256 \
  --n_layer=24 \
  --d_model=1024 \
  --d_embed=1024 \
  --n_head=16 \
  --d_head=64 \
  --d_inner=4096 \
  --untie_r=True \
  --mask_alpha=6 \
  --mask_beta=1 \
  --num_predict=85

$DATA_DIR and $MODEL_DIR are Google Cloud Storage bucket directories. Is there something I am missing here? Thanks in advance for your help.
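For reference, here is a back-of-envelope sketch of how the train.py flags above might translate into per-core memory. It rests on assumptions that the issue does not confirm: `train_batch_size` is split evenly across the 8 cores of the single TPUv3-8, activations are float32, and only the attention-probability tensors of XLNet's two attention streams (content and query) are counted. Each of these tensors has shape [seq_len, seq_len + mem_len] per head and per sequence. Real usage is higher (hidden states, feed-forward activations, weights, Adam state), so treat the numbers as a rough lower bound, not a measurement. The helper names are hypothetical, not XLNet functions.

```python
# Rough per-core estimate for the train.py flags in this issue.
# Assumptions (hypothetical, not taken from the XLNet code): the batch is split
# evenly over the cores of one host, activations are float32, and only the
# attention-probability tensors of the two attention streams are counted.

SEQ_LEN = 512
MEM_LEN = 384
N_LAYER = 24
N_HEAD = 16
NUM_CORES = 8          # a TPUv3-8 exposes 8 cores
BYTES_PER_FLOAT = 4    # float32
NUM_STREAMS = 2        # content stream + query stream


def per_core_batch(batch_per_host: int, num_cores: int = NUM_CORES) -> int:
    """Sequences each core processes per step (assumes an even split)."""
    if batch_per_host % num_cores != 0:
        raise ValueError("batch size must be divisible by the number of cores")
    return batch_per_host // num_cores


def attention_prob_bytes(batch_per_core: int) -> int:
    """Bytes of attention probabilities across all layers on one core."""
    klen = SEQ_LEN + MEM_LEN  # each query attends to cached memory + current segment
    per_layer = batch_per_core * N_HEAD * SEQ_LEN * klen * BYTES_PER_FLOAT
    return per_layer * NUM_STREAMS * N_LAYER


if __name__ == "__main__":
    for bsz in (128, 64, 48, 32, 16):
        b = per_core_batch(bsz)
        gib = attention_prob_bytes(b) / 2**30
        print(f"train_batch_size={bsz:3d} -> {b:2d} seqs/core, ~{gib:5.1f} GiB of attention probs")
```

Under these assumptions, each sequence adds about 1.3 GiB of attention probabilities per core. At train_batch_size=128, that is 16 sequences per core, or about 21 GiB, which already exceeds the 16 GiB of HBM on a TPU v3 core before any other tensor is counted. By this measure alone the smaller settings look like they fit, so the remaining activations and optimizer state may explain why 64 and 48 still fail.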

fangli80 commented 3 years ago

I think they mean 512 TPUs, each with 8 cores, i.e. 512 × 8 = 4096 cores. So each TPU core processes two sequences (8192 / 4096 = 2).
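The two readings in this thread differ by a factor of 8 in per-core batch size. Here is a minimal sketch of both calculations, using only the numbers quoted in the thread: a global batch of 8192, 512 TPUs, and 8 cores per TPUv3-8 host. The function names are illustrative only and are not XLNet flags or APIs.

```python
# Compare the two readings of "512 TPU v3 chips, batch size 8192" from this thread.
GLOBAL_BATCH = 8192
NUM_TPUS = 512
CORES_PER_HOST = 8  # a TPUv3-8 host exposes 8 cores


def original_poster_reading() -> tuple[int, int]:
    """512 TPUs grouped into 8-core hosts: 512 / 8 = 64 hosts.

    Returns (sequences per 8-core host, sequences per core).
    """
    hosts = NUM_TPUS // CORES_PER_HOST
    per_host = GLOBAL_BATCH // hosts
    return per_host, per_host // CORES_PER_HOST


def commenter_reading() -> tuple[int, int]:
    """512 TPUs with 8 cores each: 512 * 8 = 4096 cores.

    Returns (equivalent sequences per 8-core host, sequences per core).
    """
    cores = NUM_TPUS * CORES_PER_HOST
    per_core = GLOBAL_BATCH // cores
    return per_core * CORES_PER_HOST, per_core


if __name__ == "__main__":
    print("original poster:", original_poster_reading())  # (128, 16)
    print("commenter:", commenter_reading())              # (16, 2)
```

Under the commenter's reading, a single TPUv3-8 would carry the equivalent of 16 sequences per step. That falls within the range the original poster reports as fitting (16 and 32).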