Open fansiawang opened 4 years ago
Does this bug cause all Transformers trained with --share-all-embeddings
to currently not share embeddings after initialization?
When you train with --share-all-embeddings, the model does share all embeddings during training on GPU. But when saving a checkpoint, it writes three identical copies of the embedding parameter into the file. I think it would be better to save only one copy, since the embeddings are shared.
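One way to check this is to load a saved checkpoint and look at the embedding entries directly. This is just a sketch: the "model" entry and the key names below (encoder.embed_tokens.weight, decoder.embed_tokens.weight, decoder.output_projection.weight) are what my checkpoints contain and may differ across fairseq versions.

import torch

ckpt = torch.load("checkpoint_last.pt", map_location="cpu")
state = ckpt["model"]

for key in ("encoder.embed_tokens.weight",
            "decoder.embed_tokens.weight",
            "decoder.output_projection.weight"):
    if key in state:
        t = state[key]
        # Same values but different data_ptr() means the checkpoint holds
        # duplicated copies instead of one shared tensor.
        print(key, tuple(t.shape), t.data_ptr())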
🐛 Bug
When I trained a simple transformer model with share_all_embeddings=True, share_decoder_input_output_embed=True, I got the same model size as when setting these two parameters to False. However, if I train the model on CPU, the sizes do differ: the shared model is roughly half the size of the unshared one. Sharing embeddings should reduce not only the number of parameters but also the model size on disk. With the previous fairseq (0.8.0) I did get different model sizes depending on these two settings, so I don't understand why this implementation was changed.
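For what it's worth, torch.save only serializes a shared storage once, so if the three embedding parameters really aliased one tensor in the saved state dict, the file would shrink accordingly. A standalone sketch of that behaviour (the shapes here are arbitrary, not from the model above):

import os, tempfile
import torch

emb = torch.randn(30000, 512)
shared = {"encoder.embed": emb, "decoder.embed": emb}            # one tensor under two keys
separate = {"encoder.embed": emb, "decoder.embed": emb.clone()}  # two independent tensors

with tempfile.TemporaryDirectory() as d:
    torch.save(shared, os.path.join(d, "shared.pt"))
    torch.save(separate, os.path.join(d, "separate.pt"))
    # The shared file is roughly half the size, because the storage is written only once.
    print(os.path.getsize(os.path.join(d, "shared.pt")), "bytes (shared)")
    print(os.path.getsize(os.path.join(d, "separate.pt")), "bytes (separate)")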
The reason I got the same size whether the embeddings are shared or not is that, in the GPU case, the addresses of the embedding tensors change after all of the state is converted to CPU, so the shared parameter ends up as separate copies in the saved state dict.
I printed the address of the embedding tensor in fairseq/checkpoint_utils.py. The result: the address changed after converting all of the state to CPU.
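This matches how .cpu() behaves in PyTorch: converting the state dict to CPU tensor by tensor allocates a new tensor for every key, which silently breaks the aliasing between shared parameters. A minimal standalone sketch (not fairseq code, and it assumes a GPU is available):

import torch

# Two state-dict entries that alias the same CUDA tensor, as with --share-all-embeddings.
weight = torch.randn(1000, 512, device="cuda")
state = {"encoder.embed_tokens.weight": weight, "decoder.embed_tokens.weight": weight}
print(state["encoder.embed_tokens.weight"].data_ptr()
      == state["decoder.embed_tokens.weight"].data_ptr())    # True: shared storage

# Moving each tensor to CPU independently creates a fresh copy for every key,
# so the sharing is lost before torch.save ever runs.
cpu_state = {k: v.cpu() for k, v in state.items()}
print(cpu_state["encoder.embed_tokens.weight"].data_ptr()
      == cpu_state["decoder.embed_tokens.weight"].data_ptr())  # False: separate copies

# For a CPU-trained model, .cpu() returns the same tensor object, the aliasing
# survives, and torch.save stores the shared storage only once, which is why
# the CPU-trained checkpoint is smaller.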
To Reproduce
python train.py $data_dir \
    --dropout 0.1 \
    --clip-norm 0.1 \
    --max-tokens $max_tokens \
    --seed $seed \
    --num-workers 4 \
    --optimizer adafactor \
    --criterion label_smoothed_cross_entropy \
    --label-smoothing 0.1 \
    --weight-decay 0.0 \
    --lr 0.0003 \
    --lr-scheduler inverse_sqrt \
    --warmup-init-lr 1e-07 \
    --warmup-updates 4000 \
    --arch transformer \
    --save-dir $model_dir \
    --update-freq 8 \
    --me 100 \
    --log-interval 1 \
    --save-interval-updates 1 \
    --keep-interval-updates 1 \
    --no-progress-bar \
    --empty-cache-freq 64 \
    --ddp-backend=no_c10d
args.share_decoder_input_output_embed = getattr(args, "share_decoder_input_output_embed", True)
args.share_all_embeddings = getattr(args, "share_all_embeddings", True)