facebookresearch / fairseq

Facebook AI Research Sequence-to-Sequence Toolkit written in Python.
MIT License

Shared weights are duplicated in checkpoints #2138

Open fansiawang opened 4 years ago

fansiawang commented 4 years ago

### 🐛 Bug

When I train a simple transformer model with `share_all_embeddings=True` and `share_decoder_input_output_embed=True`, the saved checkpoint has the same size as when both parameters are set to False. However, if I train on CPU, the sizes do differ: the shared checkpoint is roughly half the size of the unshared one. In my view, sharing embeddings should reduce not only the number of parameters but also the checkpoint size.

With the previous fairseq (0.8.0), I do get different checkpoint sizes depending on how these two parameters are set. I don't quite understand why this behaviour had to change.

The reason the checkpoint is the same size whether or not the embeddings are shared is that, in the GPU case, the addresses of the embedding tensors change when all state is converted to CPU.

I printed the addresses of the embedding tensors in `fairseq/checkpoint_utils.py`:

```python
print("************* before convert all state to CPU *********")
for k, v in model_state_dict.items():
    if "encoder.embed_tokens.weight" in k or "decoder.output_projection.weight" in k:
        print(k, "#########", v.data_ptr())

# convert all state to CPU
state_dict = utils.move_to_cpu(state_dict)

print("************* after convert all state to CPU *********")
for k, v in state_dict["model"].items():
    if "encoder.embed_tokens.weight" in k or "decoder.output_projection.weight" in k:
        print(k, "#########", v.data_ptr())
```

The result is:

```
************* before convert all state to CPU *********
encoder.embed_tokens.weight ######### 139854906130432
decoder.output_projection.weight ######### 139854906130432

************* after convert all state to CPU *********
encoder.embed_tokens.weight ######### 139849528733760
decoder.output_projection.weight ######### 139840124805184
```

The addresses are no longer identical after converting all state to CPU, so the shared tensor is duplicated in the checkpoint.
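
The effect can be reproduced outside fairseq: copying each tensor of a state dict to CPU independently (which is essentially what a recursive move-to-CPU helper does) turns one tied tensor into separate copies. This is only a minimal sketch in plain PyTorch, not fairseq's actual `move_to_cpu` code:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Two entries of a "state dict" that point at the same storage (tied weights).
shared = torch.randn(4, 8, device=device)
state = {
    "encoder.embed_tokens.weight": shared,
    "decoder.output_projection.weight": shared,
}
print(state["encoder.embed_tokens.weight"].data_ptr()
      == state["decoder.output_projection.weight"].data_ptr())  # True: tied

# Moving each tensor to CPU independently: for a CUDA tensor, every .cpu() call
# allocates a fresh CPU storage, so the tie is lost and both copies get saved.
# For a tensor that is already on CPU, .cpu() returns the same tensor and the
# tie survives.
cpu_state = {k: v.cpu() for k, v in state.items()}
print(cpu_state["encoder.embed_tokens.weight"].data_ptr()
      == cpu_state["decoder.output_projection.weight"].data_ptr())
```

Since `.cpu()` returns the same tensor when it is already on CPU, the tie survives a CPU-only run, which would explain why only the CPU-trained checkpoint ends up at roughly half the size.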

### To Reproduce

```bash
python train.py $data_dir \
    --dropout 0.1 \
    --clip-norm 0.1 \
    --max-tokens $max_tokens \
    --seed $seed \
    --num-workers 4 \
    --optimizer adafactor \
    --criterion label_smoothed_cross_entropy \
    --label-smoothing 0.1 \
    --weight-decay 0.0 \
    --lr 0.0003 \
    --lr-scheduler inverse_sqrt \
    --warmup-init-lr 1e-07 \
    --warmup-updates 4000 \
    --arch transformer \
    --save-dir $model_dir \
    --update-freq 8 \
    --me 100 \
    --log-interval 1 \
    --save-interval-updates 1 \
    --keep-interval-updates 1 \
    --no-progress-bar \
    --empty-cache-freq 64 \
    --ddp-backend=no_c10d
```


- train on CPU or GPU by setting the following environment variable:
`CPU: export CUDA_VISIBLE_DEVICES=""`
`GPU: export CUDA_VISIBLE_DEVICES="0"`

- set the sharing parameters in `fairseq/models/transformer.py`:

```python
args.share_decoder_input_output_embed = getattr(args, "share_decoder_input_output_embed", True)
args.share_all_embeddings = getattr(args, "share_all_embeddings", True)
```
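
To double-check that the flags take effect at runtime, one can compare data pointers on the built model before any checkpointing. A minimal sketch, assuming the transformer decoder exposes `output_projection` as in the prints above:

```python
def embeddings_are_tied(model):
    # True only if the encoder embedding and the decoder output projection
    # point at the same underlying storage.
    return (model.encoder.embed_tokens.weight.data_ptr()
            == model.decoder.output_projection.weight.data_ptr())
```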



### Expected behavior

CPU shared model: 659M
CPU not shared model: 1.6G
GPU shared model: 1.6G
GPU not shared model: 1.6G

But I think the GPU shared model size should be 659M.
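
To check whether a given checkpoint actually stores the embedding once or several times, one can load it back and compare storage pointers. This is only an illustrative sketch; the checkpoint path is a placeholder:

```python
import torch

# Placeholder path; adjust to the checkpoint under $model_dir.
state = torch.load("checkpoint_last.pt", map_location="cpu")["model"]

keys = [
    "encoder.embed_tokens.weight",
    "decoder.embed_tokens.weight",
    "decoder.output_projection.weight",
]
ptrs = {k: state[k].data_ptr() for k in keys if k in state}
# Identical pointers mean the tensor was stored once and re-aliased on load;
# distinct pointers mean it was duplicated in the file.
print(ptrs)

emb_bytes = sum(state[k].numel() * state[k].element_size() for k in ptrs)
print(f"embedding parameters take ~{emb_bytes / 2**20:.0f} MiB if stored separately")
```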

### Environment

 - fairseq Version (e.g., 1.0 or master): latest (master, 2020.5.12)
 - PyTorch Version (e.g., 1.0): 1.4.0
 - OS (e.g., Linux): CentOS 7.2
 - How you installed fairseq (`pip`, source):
 - Build command you used (if compiling from source): pip install --editable .
 - Python version: 3.7.3
 - CUDA/cuDNN version: 10.0
 - GPU models and configuration: 2080
 - Any other relevant information:
villmow commented 4 years ago

Does this bug cause all Transformers trained with --share-all-embeddings to currently not share embeddings after initialization?

fansiawang commented 4 years ago

When you train with `--share-all-embeddings`, the model does share all embeddings during GPU training. But when the checkpoint is saved, the same parameter is written three times (encoder embedding, decoder embedding, and decoder output projection). I think it would be better to save only one copy, since the embeddings are shared.
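
A possible workaround, sketched below (not part of fairseq; the key names are taken from the prints above and the helpers are hypothetical), is to drop the duplicated entries from the model state before saving and re-tie them after loading:

```python
# Keys tied to encoder.embed_tokens.weight when --share-all-embeddings is set.
MASTER_KEY = "encoder.embed_tokens.weight"
TIED_KEYS = ["decoder.embed_tokens.weight", "decoder.output_projection.weight"]

def strip_tied_weights(model_state):
    """Remove duplicated copies of the shared embedding before torch.save."""
    for k in TIED_KEYS:
        model_state.pop(k, None)
    return model_state

def restore_tied_weights(model_state):
    """Re-create the tied entries after torch.load, all pointing at one tensor."""
    for k in TIED_KEYS:
        model_state.setdefault(k, model_state[MASTER_KEY])
    return model_state
```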