facebookresearch / fairseq

Facebook AI Research Sequence-to-Sequence Toolkit written in Python.
MIT License
30.29k stars 6.39k forks source link

[FSDP] Size mismatch when finetuning mBART with `translation_multi_simple_epoch` in FSDP #3464

Open thpun opened 3 years ago

thpun commented 3 years ago

🐛 Bug

Got errors when loading mBART.cc25 pretrained model for fine-tuning on translation_multi_simple_epoch in FSDP.

To Reproduce

Steps to reproduce the behavior (always include the command you ran):

  1. Run cmd
    lang_pairs=<comma-separated list of lang pairs to be trained>
    PREFIX=mbart-fsdp
    DATA=/path/to/train/data
    lang_list=models/$PREFIX/lang_list
    MODEL=models/mbart.cc25/model.pt
    CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train $DATA \
    --finetune-from-model $MODEL \
    --encoder-normalize-before --decoder-normalize-before \
    --arch mbart_large --layernorm-embedding \
    --task translation_multi_simple_epoch \
    --sampling-method "temperature" \
    --sampling-temperature 5 \
    --encoder-langtok "src" --decoder-langtok \
    --lang-dict "$lang_list" --lang-pairs "$lang_pairs" \
    --source-dict $DATA/dict.en_XX.txt --target-dict $DATA/dict.en_XX.txt \
    --checkpoint-activations --fp16 --no-reshard-after-forward \
    --criterion label_smoothed_cross_entropy --label-smoothing 0.2 \
    --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \
    --lr-scheduler inverse_sqrt --lr 6e-05 --stop-min-lr -1 --warmup-updates 2000 \
    --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \
    --max-tokens 1920 --update-freq 2 --upsample-primary 2 \
    --save-interval-updates 5000 --keep-interval-updates 10 --keep-best-checkpoints 10 \
    --patience 5 --max-epoch 150 \
    --seed 222 --log-format simple --log-interval 10 --ddp-backend fully_sharded \
    --save-dir models/$PREFIX
  2. See error
    
    2021-04-11 09:43:04 | INFO | fairseq.data.multilingual.multilingual_data_manager | main:en_XX-zh_CN src_langtok: 250004; tgt_langtok: 250025
    2021-04-11 09:43:04 | INFO | fairseq.data.data_utils | loaded 5,983 examples from: train-data/multilingual_v1-tagged/valid.en_XX-zh_CN.en_XX
    2021-04-11 09:43:04 | INFO | fairseq.data.data_utils | loaded 5,983 examples from: train-data/multilingual_v1-tagged/valid.en_XX-zh_CN.zh_CN
    2021-04-11 09:43:04 | INFO | fairseq.data.multilingual.multilingual_data_manager | train-data/multilingual_v1-tagged valid en_XX-zh_CN 5983 examples
    2021-04-11 09:43:04 | INFO | fairseq.data.multilingual.multilingual_data_manager | main:zh_CN-en_XX src_langtok: 250025; tgt_langtok: 250004
    2021-04-11 09:43:04 | INFO | fairseq.data.data_utils | loaded 5,982 examples from: train-data/multilingual_v1-tagged/valid.zh_CN-en_XX.zh_CN
    2021-04-11 09:43:04 | INFO | fairseq.data.data_utils | loaded 5,982 examples from: train-data/multilingual_v1-tagged/valid.zh_CN-en_XX.en_XX
    2021-04-11 09:43:04 | INFO | fairseq.data.multilingual.multilingual_data_manager | train-data/multilingual_v1-tagged valid zh_CN-en_XX 5982 examples
    2021-04-11 09:43:04 | INFO | fairseq.data.multilingual.multilingual_data_manager | main:en_XX-es_XX src_langtok: 250004; tgt_langtok: 250005
    2021-04-11 09:43:04 | INFO | fairseq.data.data_utils | loaded 3,003 examples from: train-data/multilingual_v1-tagged/valid.en_XX-es_XX.en_XX
    2021-04-11 09:43:04 | INFO | fairseq.data.data_utils | loaded 3,003 examples from: train-data/multilingual_v1-tagged/valid.en_XX-es_XX.es_XX
    2021-04-11 09:43:04 | INFO | fairseq.data.multilingual.multilingual_data_manager | train-data/multilingual_v1-tagged valid en_XX-es_XX 3003 examples
    2021-04-11 09:43:04 | INFO | fairseq.data.multilingual.multilingual_data_manager | main:es_XX-en_XX src_langtok: 250005; tgt_langtok: 250004
    2021-04-11 09:43:04 | INFO | fairseq.data.data_utils | loaded 3,003 examples from: train-data/multilingual_v1-tagged/valid.en_XX-es_XX.es_XX
    2021-04-11 09:43:04 | INFO | fairseq.data.data_utils | loaded 3,003 examples from: train-data/multilingual_v1-tagged/valid.en_XX-es_XX.en_XX
    2021-04-11 09:43:04 | INFO | fairseq.data.multilingual.multilingual_data_manager | train-data/multilingual_v1-tagged valid es_XX-en_XX 3003 examples
    2021-04-11 09:43:04 | INFO | fairseq.utils | ***********************CUDA enviroments for all 4 workers***********************
    2021-04-11 09:43:04 | INFO | fairseq.utils | rank   0: capabilities =  7.0  ; total memory = 31.719 GB ; name = Tesla V100-SXM3-32GB
    2021-04-11 09:43:04 | INFO | fairseq.utils | rank   1: capabilities =  7.0  ; total memory = 31.719 GB ; name = Tesla V100-SXM3-32GB
    2021-04-11 09:43:04 | INFO | fairseq.utils | rank   2: capabilities =  7.0  ; total memory = 31.719 GB ; name = Tesla V100-SXM3-32GB
    2021-04-11 09:43:04 | INFO | fairseq.utils | rank   3: capabilities =  7.0  ; total memory = 31.719 GB ; name = Tesla V100-SXM3-32GB
    2021-04-11 09:43:04 | INFO | fairseq.utils | ***********************CUDA enviroments for all 4 workers***********************
    2021-04-11 09:43:04 | INFO | fairseq_cli.train | training on 4 devices (GPUs/TPUs)
    2021-04-11 09:43:04 | INFO | fairseq_cli.train | max tokens per device = 1920 and max sentences per device = None
    2021-04-11 09:43:04 | INFO | fairseq.checkpoint_utils | loading pretrained model from models/mbart.cc25/model.pt: optimizer, lr scheduler, meters, dataloader will be reset
    2021-04-11 09:43:04 | INFO | fairseq.trainer | Preparing to load checkpoint models/mbart.cc25/model.pt
    Traceback (most recent call last):
    File "/opt/conda/bin/fairseq-train", line 33, in <module>
    sys.exit(load_entry_point('fairseq', 'console_scripts', 'fairseq-train')())
    File "/workspace/fairseq/fairseq_cli/train.py", line 491, in cli_main
    distributed_utils.call_main(cfg, main)
    File "/workspace/fairseq/fairseq/distributed/utils.py", line 344, in call_main
    torch.multiprocessing.spawn(
    File "/opt/conda/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 230, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
    File "/opt/conda/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 188, in start_processes
    while not context.join():
    File "/opt/conda/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 150, in join
    raise ProcessRaisedException(msg, error_index, failed_process.pid)
    torch.multiprocessing.spawn.ProcessRaisedException:

-- Process 3 terminated with the following error: Traceback (most recent call last): File "/workspace/fairseq/fairseq/trainer.py", line 453, in load_checkpoint self.model.load_state_dict( File "/workspace/fairseq/fairseq/distributed/fully_sharded_data_parallel.py", line 76, in load_state_dict return super().load_state_dict(state_dict, strict=strict) File "/workspace/fairscale/fairscale/nn/data_parallel/fully_sharded_data_parallel.py", line 604, in load_state_dict return self.module.load_state_dict(state_dict, strict) File "/workspace/fairscale/fairscale/nn/misc/flatten_params_wrapper.py", line 242, in load_state_dict return super().load_state_dict(state_dict, strict) File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1215, in load_state_dict raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( RuntimeError: Error(s) in loading state_dict for FlattenParamsWrapper: Missing key(s) in state_dict: "_fpw_module.encoder.layers.0._fsdp_wrapped_module._fpw_module.self_attn.k_proj.weight", "_fpw_module.encoder.layers.0._fsdp_wrapped_module._fpw_module.self_attn.k_proj.bias", "_fpw_module.encoder.layers.0._fsdp_wrapped_module._fpw_module.self_attn.v_proj.weight", "_fpw_module.encoder.layers.0._fsdp_wrapped_module._fpw_module.self_attn.v_proj.bias", "_fpw_module.encoder.layers.0._fsdp_wrapped_module._fpw_module.self_attn.q_proj.weight", "_fpw_module.encoder.layers.0._fsdp_wrapped_module._fpw_module.self_attn.q_proj.bias", "_fpw_module.encoder.layers.0._fsdp_wrapped_module._fpw_module.self_attn_layer_norm.weight", "_fpw_module.encoder.layers.0._fsdp_wrapped_module._fpw_module.self_attn_layer_norm.bias", "_fpw_module.encoder.layers.0._fsdp_wrapped_module._fpw_module.final_layer_norm.weight", "_fpw_module.encoder.layers.0._fsdp_wrapped_module._fpw_module.final_layer_norm.bias", "_fpw_module.encoder.layers.1._fsdp_wrapped_module._fpw_module.self_attn.k_proj.weight", "_fpw_module.encoder.layers.1._fsdp_wrapped_module._fpw_module.self_attn.k_proj.bias", "_fpw_module.encoder.layers.1._fsdp_wrapped_module._fpw_module.self_attn.v_proj.weight", "_fpw_module.encoder.layers.1._fsdp_wrapped_module._fpw_module.self_attn.v_proj.bias", "_fpw_module.encoder.layers.1._fsdp_wrapped_module._fpw_module.self_attn.q_proj.weight", "_fpw_module.encoder.layers.1._fsdp_wrapped_module._fpw_module.self_attn.q_proj.bias", "_fpw_module.encoder.layers.1._fsdp_wrapped_module._fpw_module.self_attn_layer_norm.weight", "_fpw_module.encoder.layers.1._fsdp_wrapped_module._fpw_module.self_attn_layer_norm.bias", "_fpw_module.encoder.layers.1._fsdp_wrapped_module._fpw_module.final_layer_norm.weight", "_fpw_module.encoder.layers.1._fsdp_wrapped_module._fpw_module.final_layer_norm.bias", "_fpw_module.encoder.layers.2._fsdp_wrapped_module._fpw_module.self_attn.k_proj.weight", "_fpw_module.encoder.layers.2._fsdp_wrapped_module._fpw_module.self_attn.k_proj.bias", "_fpw_module.encoder.layers.2._fsdp_wrapped_module._fpw_module.self_attn.v_proj.weight", "_fpw_module.encoder.layers.2._fsdp_wrapped_module._fpw_module.self_attn.v_proj.bias", "_fpw_module.encoder.layers.2._fsdp_wrapped_module._fpw_module.self_attn.q_proj.weight", "_fpw_module.encoder.layers.2._fsdp_wrapped_module._fpw_module.self_attn.q_proj.bias", "_fpw_module.encoder.layers.2._fsdp_wrapped_module._fpw_module.self_attn_layer_norm.weight", "_fpw_module.encoder.layers.2._fsdp_wrapped_module._fpw_module.self_attn_layer_norm.bias", "_fpw_module.encoder.layers.2._fsdp_wrapped_module._fpw_module.final_layer_norm.weight", "_fpw_module.encoder.layers.2._fsdp_wrapped_module._fpw_module.final_layer_norm.bias", "_fpw_module.encoder.layers.3._fsdp_wrapped_module._fpw_module.self_attn.k_proj.weight", "_fpw_module.encoder.layers.3._fsdp_wrapped_module._fpw_module.self_attn.k_proj.bias", "_fpw_module.encoder.layers.3._fsdp_wrapped_module._fpw_module.self_attn.v_proj.weight", "_fpw_module.encoder.layers.3._fsdp_wrapped_module._fpw_module.self_attn.v_proj.bias", "_fpw_module.encoder.layers.3._fsdp_wrapped_module._fpw_module.self_attn.q_proj.weight", "_fpw_module.encoder.layers.3._fsdp_wrapped_module._fpw_module.self_attn.q_proj.bias", "_fpw_module.encoder.layers.3._fsdp_wrapped_module._fpw_module.self_attn_layer_norm.weight", "_fpw_module.encoder.layers.3._fsdp_wrapped_module._fpw_module.self_attn_layer_norm.bias", "_fpw_module.encoder.layers.3._fsdp_wrapped_module._fpw_module.final_layer_norm.weight", "_fpw_module.encoder.layers.3._fsdp_wrapped_module._fpw_module.final_layer_norm.bias", "_fpw_module.encoder.layers.4._fsdp_wrapped_module._fpw_module.self_attn.k_proj.weight", "_fpw_module.encoder.layers.4._fsdp_wrapped_module._fpw_module.self_attn.k_proj.bias", "_fpw_module.encoder.layers.4._fsdp_wrapped_module._fpw_module.self_attn.v_proj.weight", "_fpw_module.encoder.layers.4._fsdp_wrapped_module._fpw_module.self_attn.v_proj.bias", "_fpw_module.encoder.layers.4._fsdp_wrapped_module._fpw_module.self_attn.q_proj.weight", "_fpw_module.encoder.layers.4._fsdp_wrapped_module._fpw_module.self_attn.q_proj.bias", "_fpw_module.encoder.layers.4._fsdp_wrapped_module._fpw_module.self_attn_layer_norm.weight", "_fpw_module.encoder.layers.4._fsdp_wrapped_module._fpw_module.self_attn_layer_norm.bias", "_fpw_module.encoder.layers.4._fsdp_wrapped_module._fpw_module.final_layer_norm.weight", "_fpw_module.encoder.layers.4._fsdp_wrapped_module._fpw_module.final_layer_norm.bias", "_fpw_module.encoder.layers.5._fsdp_wrapped_module._fpw_module.self_attn.k_proj.weight", "_fpw_module.encoder.layers.5._fsdp_wrapped_module._fpw_module.self_attn.k_proj.bias", "_fpw_module.encoder.layers.5._fsdp_wrapped_module._fpw_module.self_attn.v_proj.weight", "_fpw_module.encoder.layers.5._fsdp_wrapped_module._fpw_module.self_attn.v_proj.bias", "_fpw_module.encoder.layers.5._fsdp_wrapped_module._fpw_module.self_attn.q_proj.weight", "_fpw_module.encoder.layers.5._fsdp_wrapped_module._fpw_module.self_attn.q_proj.bias", "_fpw_module.encoder.layers.5._fsdp_wrapped_module._fpw_module.self_attn_layer_norm.weight", "_fpw_module.encoder.layers.5._fsdp_wrapped_module._fpw_module.self_attn_layer_norm.bias", "_fpw_module.encoder.layers.5._fsdp_wrapped_module._fpw_module.final_layer_norm.weight", "_fpw_module.encoder.layers.5._fsdp_wrapped_module._fpw_module.final_layer_norm.bias", "_fpw_module.encoder.layers.6._fsdp_wrapped_module._fpw_module.self_attn.k_proj.weight", "_fpw_module.encoder.layers.6._fsdp_wrapped_module._fpw_module.self_attn.k_proj.bias", "_fpw_module.encoder.layers.6._fsdp_wrapped_module._fpw_module.self_attn.v_proj.weight", "_fpw_module.encoder.layers.6._fsdp_wrapped_module._fpw_module.self_attn.v_proj.bias", "_fpw_module.encoder.layers.6._fsdp_wrapped_module._fpw_module.self_attn.q_proj.weight", "_fpw_module.encoder.layers.6._fsdp_wrapped_module._fpw_module.self_attn.q_proj.bias", "_fpw_module.encoder.layers.6._fsdp_wrapped_module._fpw_module.self_attn_layer_norm.weight", "_fpw_module.encoder.layers.6._fsdp_wrapped_module._fpw_module.self_attn_layer_norm.bias", "_fpw_module.encoder.layers.6._fsdp_wrapped_module._fpw_module.final_layer_norm.weight", "_fpw_module.encoder.layers.6._fsdp_wrapped_module._fpw_module.final_layer_norm.bias", "_fpw_module.encoder.layers.7._fsdp_wrapped_module._fpw_module.self_attn.k_proj.weight", "_fpw_module.encoder.layers.7._fsdp_wrapped_module._fpw_module.self_attn.k_proj.bias", "_fpw_module.encoder.layers.7._fsdp_wrapped_module._fpw_module.self_attn.v_proj.weight", "_fpw_module.encoder.layers.7._fsdp_wrapped_module._fpw_module.self_attn.v_proj.bias", "_fpw_module.encoder.layers.7._fsdp_wrapped_module._fpw_module.self_attn.q_proj.weight", "_fpw_module.encoder.layers.7._fsdp_wrapped_module._fpw_module.self_attn.q_proj.bias", "_fpw_module.encoder.layers.7._fsdp_wrapped_module._fpw_module.self_attn_layer_norm.weight", "_fpw_module.encoder.layers.7._fsdp_wrapped_module._fpw_module.self_attn_layer_norm.bias", "_fpw_module.encoder.layers.7._fsdp_wrapped_module._fpw_module.final_layer_norm.weight", "_fpw_module.encoder.layers.7._fsdp_wrapped_module._fpw_module.final_layer_norm.bias", "_fpw_module.encoder.layers.8._fsdp_wrapped_module._fpw_module.self_attn.k_proj.weight", "_fpw_module.encoder.layers.8._fsdp_wrapped_module._fpw_module.self_attn.k_proj.bias", "_fpw_module.encoder.layers.8._fsdp_wrapped_module._fpw_module.self_attn.v_proj.weight", "_fpw_module.encoder.layers.8._fsdp_wrapped_module._fpw_module.self_attn.v_proj.bias", "_fpw_module.encoder.layers.8._fsdp_wrapped_module._fpw_module.self_attn.q_proj.weight", "_fpw_module.encoder.layers.8._fsdp_wrapped_module._fpw_module.self_attn.q_proj.bias", "_fpw_module.encoder.layers.8._fsdp_wrapped_module._fpw_module.self_attn_layer_norm.weight", "_fpw_module.encoder.layers.8._fsdp_wrapped_module._fpw_module.self_attn_layer_norm.bias", "_fpw_module.encoder.layers.8._fsdp_wrapped_module._fpw_module.final_layer_norm.weight", "_fpw_module.encoder.layers.8._fsdp_wrapped_module._fpw_module.final_layer_norm.bias", "_fpw_module.encoder.layers.9._fsdp_wrapped_module._fpw_module.self_attn.k_proj.weight", "_fpw_module.encoder.layers.9._fsdp_wrapped_module._fpw_module.self_attn.k_proj.bias", "_fpw_module.encoder.layers.9._fsdp_wrapped_module._fpw_module.self_attn.v_proj.weight", "_fpw_module.encoder.layers.9._fsdp_wrapped_module._fpw_module.self_attn.v_proj.bias", "_fpw_module.encoder.layers.9._fsdp_wrapped_module._fpw_module.self_attn.q_proj.weight", "_fpw_module.encoder.layers.9._fsdp_wrapped_module._fpw_module.self_attn.q_proj.bias", "_fpw_module.encoder.layers.9._fsdp_wrapped_module._fpw_module.self_attn_layer_norm.weight", "_fpw_module.encoder.layers.9._fsdp_wrapped_module._fpw_module.self_attn_layer_norm.bias", "_fpw_module.encoder.layers.9._fsdp_wrapped_module._fpw_module.final_layer_norm.weight", "_fpw_module.encoder.layers.9._fsdp_wrapped_module._fpw_module.final_layer_norm.bias", "_fpw_module.encoder.layers.10._fsdp_wrapped_module._fpw_module.self_attn.k_proj.weight", "_fpw_module.encoder.layers.10._fsdp_wrapped_module._fpw_module.self_attn.k_proj.bias", "_fpw_module.encoder.layers.10._fsdp_wrapped_module._fpw_module.self_attn.v_proj.weight", "_fpw_module.encoder.layers.10._fsdp_wrapped_module._fpw_module.self_attn.v_proj.bias", "_fpw_module.encoder.layers.10._fsdp_wrapped_module._fpw_module.self_attn.q_proj.weight", "_fpw_module.encoder.layers.10._fsdp_wrapped_module._fpw_module.self_attn.q_proj.bias", "_fpw_module.encoder.layers.10._fsdp_wrapped_module._fpw_module.self_attn_layer_norm.weight", "_fpw_module.encoder.layers.10._fsdp_wrapped_module._fpw_module.self_attn_layer_norm.bias", "_fpw_module.encoder.layers.10._fsdp_wrapped_module._fpw_module.final_layer_norm.weight", "_fpw_module.encoder.layers.10._fsdp_wrapped_module._fpw_module.final_layer_norm.bias", "_fpw_module.encoder.layers.11._fsdp_wrapped_module._fpw_module.self_attn.k_proj.weight", "_fpw_module.encoder.layers.11._fsdp_wrapped_module._fpw_module.self_attn.k_proj.bias", "_fpw_module.encoder.layers.11._fsdp_wrapped_module._fpw_module.self_attn.v_proj.weight", "_fpw_module.encoder.layers.11._fsdp_wrapped_module._fpw_module.self_attn.v_proj.bias", "_fpw_module.encoder.layers.11._fsdp_wrapped_module._fpw_module.self_attn.q_proj.weight", "_fpw_module.encoder.layers.11._fsdp_wrapped_module._fpw_module.self_attn.q_proj.bias", "_fpw_module.encoder.layers.11._fsdp_wrapped_module._fpw_module.self_attn_layer_norm.weight", "_fpw_module.encoder.layers.11._fsdp_wrapped_module._fpw_module.self_attn_layer_norm.bias", "_fpw_module.encoder.layers.11._fsdp_wrapped_module._fpw_module.final_layer_norm.weight", "_fpw_module.encoder.layers.11._fsdp_wrapped_module._fpw_module.final_layer_norm.bias", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.self_attn.k_proj.weight", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.self_attn.k_proj.bias", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.self_attn.v_proj.weight", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.self_attn.v_proj.bias", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.self_attn.q_proj.weight", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.self_attn.q_proj.bias", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.encoder_attn.k_proj.weight", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.encoder_attn.k_proj.bias", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.encoder_attn.v_proj.weight", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.encoder_attn.v_proj.bias", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.encoder_attn.q_proj.weight", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.encoder_attn.q_proj.bias", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.self_attn.k_proj.weight", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.self_attn.k_proj.bias", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.self_attn.v_proj.weight", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.self_attn.v_proj.bias", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.self_attn.q_proj.weight", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.self_attn.q_proj.bias", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.encoder_attn.k_proj.weight", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.encoder_attn.k_proj.bias", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.encoder_attn.v_proj.weight", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.encoder_attn.v_proj.bias", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.encoder_attn.q_proj.weight", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.encoder_attn.q_proj.bias", "_fpw_module.decoder.layers.2._fsdp_wrapped_module._fpw_module.self_attn.k_proj.weight", "_fpw_module.decoder.layers.2._fsdp_wrapped_module._fpw_module.self_attn.k_proj.bias", "_fpw_module.decoder.layers.2._fsdp_wrapped_module._fpw_module.self_attn.v_proj.weight", "_fpw_module.decoder.layers.2._fsdp_wrapped_module._fpw_module.self_attn.v_proj.bias", "_fpw_module.decoder.layers.2._fsdp_wrapped_module._fpw_module.self_attn.q_proj.weight", "_fpw_module.decoder.layers.2._fsdp_wrapped_module._fpw_module.self_attn.q_proj.bias", "_fpw_module.decoder.layers.2._fsdp_wrapped_module._fpw_module.encoder_attn.k_proj.weight", "_fpw_module.decoder.layers.2._fsdp_wrapped_module._fpw_module.encoder_attn.k_proj.bias", "_fpw_module.decoder.layers.2._fsdp_wrapped_module._fpw_module.encoder_attn.v_proj.weight", "_fpw_module.decoder.layers.2._fsdp_wrapped_module._fpw_module.encoder_attn.v_proj.bias", "_fpw_module.decoder.layers.2._fsdp_wrapped_module._fpw_module.encoder_attn.q_proj.weight", "_fpw_module.decoder.layers.2._fsdp_wrapped_module._fpw_module.encoder_attn.q_proj.bias", "_fpw_module.decoder.layers.3._fsdp_wrapped_module._fpw_module.self_attn.k_proj.weight", "_fpw_module.decoder.layers.3._fsdp_wrapped_module._fpw_module.self_attn.k_proj.bias", "_fpw_module.decoder.layers.3._fsdp_wrapped_module._fpw_module.self_attn.v_proj.weight", "_fpw_module.decoder.layers.3._fsdp_wrapped_module._fpw_module.self_attn.v_proj.bias", "_fpw_module.decoder.layers.3._fsdp_wrapped_module._fpw_module.self_attn.q_proj.weight", "_fpw_module.decoder.layers.3._fsdp_wrapped_module._fpw_module.self_attn.q_proj.bias", "_fpw_module.decoder.layers.3._fsdp_wrapped_module._fpw_module.encoder_attn.k_proj.weight", "_fpw_module.decoder.layers.3._fsdp_wrapped_module._fpw_module.encoder_attn.k_proj.bias", "_fpw_module.decoder.layers.3._fsdp_wrapped_module._fpw_module.encoder_attn.v_proj.weight", "_fpw_module.decoder.layers.3._fsdp_wrapped_module._fpw_module.encoder_attn.v_proj.bias", "_fpw_module.decoder.layers.3._fsdp_wrapped_module._fpw_module.encoder_attn.q_proj.weight", "_fpw_module.decoder.layers.3._fsdp_wrapped_module._fpw_module.encoder_attn.q_proj.bias", "_fpw_module.decoder.layers.4._fsdp_wrapped_module._fpw_module.self_attn.k_proj.weight", "_fpw_module.decoder.layers.4._fsdp_wrapped_module._fpw_module.self_attn.k_proj.bias", "_fpw_module.decoder.layers.4._fsdp_wrapped_module._fpw_module.self_attn.v_proj.weight", "_fpw_module.decoder.layers.4._fsdp_wrapped_module._fpw_module.self_attn.v_proj.bias", "_fpw_module.decoder.layers.4._fsdp_wrapped_module._fpw_module.self_attn.q_proj.weight", "_fpw_module.decoder.layers.4._fsdp_wrapped_module._fpw_module.self_attn.q_proj.bias", "_fpw_module.decoder.layers.4._fsdp_wrapped_module._fpw_module.encoder_attn.k_proj.weight", "_fpw_module.decoder.layers.4._fsdp_wrapped_module._fpw_module.encoder_attn.k_proj.bias", "_fpw_module.decoder.layers.4._fsdp_wrapped_module._fpw_module.encoder_attn.v_proj.weight", "_fpw_module.decoder.layers.4._fsdp_wrapped_module._fpw_module.encoder_attn.v_proj.bias", "_fpw_module.decoder.layers.4._fsdp_wrapped_module._fpw_module.encoder_attn.q_proj.weight", "_fpw_module.decoder.layers.4._fsdp_wrapped_module._fpw_module.encoder_attn.q_proj.bias", "_fpw_module.decoder.layers.5._fsdp_wrapped_module._fpw_module.self_attn.k_proj.weight", "_fpw_module.decoder.layers.5._fsdp_wrapped_module._fpw_module.self_attn.k_proj.bias", "_fpw_module.decoder.layers.5._fsdp_wrapped_module._fpw_module.self_attn.v_proj.weight", "_fpw_module.decoder.layers.5._fsdp_wrapped_module._fpw_module.self_attn.v_proj.bias", "_fpw_module.decoder.layers.5._fsdp_wrapped_module._fpw_module.self_attn.q_proj.weight", "_fpw_module.decoder.layers.5._fsdp_wrapped_module._fpw_module.self_attn.q_proj.bias", "_fpw_module.decoder.layers.5._fsdp_wrapped_module._fpw_module.encoder_attn.k_proj.weight", "_fpw_module.decoder.layers.5._fsdp_wrapped_module._fpw_module.encoder_attn.k_proj.bias", "_fpw_module.decoder.layers.5._fsdp_wrapped_module._fpw_module.encoder_attn.v_proj.weight", "_fpw_module.decoder.layers.5._fsdp_wrapped_module._fpw_module.encoder_attn.v_proj.bias", "_fpw_module.decoder.layers.5._fsdp_wrapped_module._fpw_module.encoder_attn.q_proj.weight", "_fpw_module.decoder.layers.5._fsdp_wrapped_module._fpw_module.encoder_attn.q_proj.bias", "_fpw_module.decoder.layers.6._fsdp_wrapped_module._fpw_module.self_attn.k_proj.weight", "_fpw_module.decoder.layers.6._fsdp_wrapped_module._fpw_module.self_attn.k_proj.bias", "_fpw_module.decoder.layers.6._fsdp_wrapped_module._fpw_module.self_attn.v_proj.weight", "_fpw_module.decoder.layers.6._fsdp_wrapped_module._fpw_module.self_attn.v_proj.bias", "_fpw_module.decoder.layers.6._fsdp_wrapped_module._fpw_module.self_attn.q_proj.weight", "_fpw_module.decoder.layers.6._fsdp_wrapped_module._fpw_module.self_attn.q_proj.bias", "_fpw_module.decoder.layers.6._fsdp_wrapped_module._fpw_module.encoder_attn.k_proj.weight", "_fpw_module.decoder.layers.6._fsdp_wrapped_module._fpw_module.encoder_attn.k_proj.bias", "_fpw_module.decoder.layers.6._fsdp_wrapped_module._fpw_module.encoder_attn.v_proj.weight", "_fpw_module.decoder.layers.6._fsdp_wrapped_module._fpw_module.encoder_attn.v_proj.bias", "_fpw_module.decoder.layers.6._fsdp_wrapped_module._fpw_module.encoder_attn.q_proj.weight", "_fpw_module.decoder.layers.6._fsdp_wrapped_module._fpw_module.encoder_attn.q_proj.bias", "_fpw_module.decoder.layers.7._fsdp_wrapped_module._fpw_module.self_attn.k_proj.weight", "_fpw_module.decoder.layers.7._fsdp_wrapped_module._fpw_module.self_attn.k_proj.bias", "_fpw_module.decoder.layers.7._fsdp_wrapped_module._fpw_module.self_attn.v_proj.weight", "_fpw_module.decoder.layers.7._fsdp_wrapped_module._fpw_module.self_attn.v_proj.bias", "_fpw_module.decoder.layers.7._fsdp_wrapped_module._fpw_module.self_attn.q_proj.weight", "_fpw_module.decoder.layers.7._fsdp_wrapped_module._fpw_module.self_attn.q_proj.bias", "_fpw_module.decoder.layers.7._fsdp_wrapped_module._fpw_module.encoder_attn.k_proj.weight", "_fpw_module.decoder.layers.7._fsdp_wrapped_module._fpw_module.encoder_attn.k_proj.bias", "_fpw_module.decoder.layers.7._fsdp_wrapped_module._fpw_module.encoder_attn.v_proj.weight", "_fpw_module.decoder.layers.7._fsdp_wrapped_module._fpw_module.encoder_attn.v_proj.bias", "_fpw_module.decoder.layers.7._fsdp_wrapped_module._fpw_module.encoder_attn.q_proj.weight", "_fpw_module.decoder.layers.7._fsdp_wrapped_module._fpw_module.encoder_attn.q_proj.bias", "_fpw_module.decoder.layers.8._fsdp_wrapped_module._fpw_module.self_attn.k_proj.weight", "_fpw_module.decoder.layers.8._fsdp_wrapped_module._fpw_module.self_attn.k_proj.bias", "_fpw_module.decoder.layers.8._fsdp_wrapped_module._fpw_module.self_attn.v_proj.weight", "_fpw_module.decoder.layers.8._fsdp_wrapped_module._fpw_module.self_attn.v_proj.bias", "_fpw_module.decoder.layers.8._fsdp_wrapped_module._fpw_module.self_attn.q_proj.weight", "_fpw_module.decoder.layers.8._fsdp_wrapped_module._fpw_module.self_attn.q_proj.bias", "_fpw_module.decoder.layers.8._fsdp_wrapped_module._fpw_module.encoder_attn.k_proj.weight", "_fpw_module.decoder.layers.8._fsdp_wrapped_module._fpw_module.encoder_attn.k_proj.bias", "_fpw_module.decoder.layers.8._fsdp_wrapped_module._fpw_module.encoder_attn.v_proj.weight", "_fpw_module.decoder.layers.8._fsdp_wrapped_module._fpw_module.encoder_attn.v_proj.bias", "_fpw_module.decoder.layers.8._fsdp_wrapped_module._fpw_module.encoder_attn.q_proj.weight", "_fpw_module.decoder.layers.8._fsdp_wrapped_module._fpw_module.encoder_attn.q_proj.bias", "_fpw_module.decoder.layers.9._fsdp_wrapped_module._fpw_module.self_attn.k_proj.weight", "_fpw_module.decoder.layers.9._fsdp_wrapped_module._fpw_module.self_attn.k_proj.bias", "_fpw_module.decoder.layers.9._fsdp_wrapped_module._fpw_module.self_attn.v_proj.weight", "_fpw_module.decoder.layers.9._fsdp_wrapped_module._fpw_module.self_attn.v_proj.bias", "_fpw_module.decoder.layers.9._fsdp_wrapped_module._fpw_module.self_attn.q_proj.weight", "_fpw_module.decoder.layers.9._fsdp_wrapped_module._fpw_module.self_attn.q_proj.bias", "_fpw_module.decoder.layers.9._fsdp_wrapped_module._fpw_module.encoder_attn.k_proj.weight", "_fpw_module.decoder.layers.9._fsdp_wrapped_module._fpw_module.encoder_attn.k_proj.bias", "_fpw_module.decoder.layers.9._fsdp_wrapped_module._fpw_module.encoder_attn.v_proj.weight", "_fpw_module.decoder.layers.9._fsdp_wrapped_module._fpw_module.encoder_attn.v_proj.bias", "_fpw_module.decoder.layers.9._fsdp_wrapped_module._fpw_module.encoder_attn.q_proj.weight", "_fpw_module.decoder.layers.9._fsdp_wrapped_module._fpw_module.encoder_attn.q_proj.bias", "_fpw_module.decoder.layers.10._fsdp_wrapped_module._fpw_module.self_attn.k_proj.weight", "_fpw_module.decoder.layers.10._fsdp_wrapped_module._fpw_module.self_attn.k_proj.bias", "_fpw_module.decoder.layers.10._fsdp_wrapped_module._fpw_module.self_attn.v_proj.weight", "_fpw_module.decoder.layers.10._fsdp_wrapped_module._fpw_module.self_attn.v_proj.bias", "_fpw_module.decoder.layers.10._fsdp_wrapped_module._fpw_module.self_attn.q_proj.weight", "_fpw_module.decoder.layers.10._fsdp_wrapped_module._fpw_module.self_attn.q_proj.bias", "_fpw_module.decoder.layers.10._fsdp_wrapped_module._fpw_module.encoder_attn.k_proj.weight", "_fpw_module.decoder.layers.10._fsdp_wrapped_module._fpw_module.encoder_attn.k_proj.bias", "_fpw_module.decoder.layers.10._fsdp_wrapped_module._fpw_module.encoder_attn.v_proj.weight", "_fpw_module.decoder.layers.10._fsdp_wrapped_module._fpw_module.encoder_attn.v_proj.bias", "_fpw_module.decoder.layers.10._fsdp_wrapped_module._fpw_module.encoder_attn.q_proj.weight", "_fpw_module.decoder.layers.10._fsdp_wrapped_module._fpw_module.encoder_attn.q_proj.bias", "_fpw_module.decoder.layers.11._fsdp_wrapped_module._fpw_module.self_attn.k_proj.weight", "_fpw_module.decoder.layers.11._fsdp_wrapped_module._fpw_module.self_attn.k_proj.bias", "_fpw_module.decoder.layers.11._fsdp_wrapped_module._fpw_module.self_attn.v_proj.weight", "_fpw_module.decoder.layers.11._fsdp_wrapped_module._fpw_module.self_attn.v_proj.bias", "_fpw_module.decoder.layers.11._fsdp_wrapped_module._fpw_module.self_attn.q_proj.weight", "_fpw_module.decoder.layers.11._fsdp_wrapped_module._fpw_module.self_attn.q_proj.bias", "_fpw_module.decoder.layers.11._fsdp_wrapped_module._fpw_module.encoder_attn.k_proj.weight", "_fpw_module.decoder.layers.11._fsdp_wrapped_module._fpw_module.encoder_attn.k_proj.bias", "_fpw_module.decoder.layers.11._fsdp_wrapped_module._fpw_module.encoder_attn.v_proj.weight", "_fpw_module.decoder.layers.11._fsdp_wrapped_module._fpw_module.encoder_attn.v_proj.bias", "_fpw_module.decoder.layers.11._fsdp_wrapped_module._fpw_module.encoder_attn.q_proj.weight", "_fpw_module.decoder.layers.11._fsdp_wrapped_module._fpw_module.encoder_attn.q_proj.bias", "_fpw_module.decoder.output_projection.weight". Unexpected key(s) in state_dict: "_fpw_module.encoder.layers.0._fsdp_wrapped_module._fpw_module.layer_norms.0.weight", "_fpw_module.encoder.layers.0._fsdp_wrapped_module._fpw_module.layer_norms.0.bias", "_fpw_module.encoder.layers.0._fsdp_wrapped_module._fpw_module.layer_norms.1.weight", "_fpw_module.encoder.layers.0._fsdp_wrapped_module._fpw_module.layer_norms.1.bias", "_fpw_module.encoder.layers.0._fsdp_wrapped_module._fpw_module.self_attn.in_proj_weight", "_fpw_module.encoder.layers.0._fsdp_wrapped_module._fpw_module.self_attn.in_proj_bias", "_fpw_module.encoder.layers.1._fsdp_wrapped_module._fpw_module.layer_norms.0.weight", "_fpw_module.encoder.layers.1._fsdp_wrapped_module._fpw_module.layer_norms.0.bias", "_fpw_module.encoder.layers.1._fsdp_wrapped_module._fpw_module.layer_norms.1.weight", "_fpw_module.encoder.layers.1._fsdp_wrapped_module._fpw_module.layer_norms.1.bias", "_fpw_module.encoder.layers.1._fsdp_wrapped_module._fpw_module.self_attn.in_proj_weight", "_fpw_module.encoder.layers.1._fsdp_wrapped_module._fpw_module.self_attn.in_proj_bias", "_fpw_module.encoder.layers.2._fsdp_wrapped_module._fpw_module.layer_norms.0.weight", "_fpw_module.encoder.layers.2._fsdp_wrapped_module._fpw_module.layer_norms.0.bias", "_fpw_module.encoder.layers.2._fsdp_wrapped_module._fpw_module.layer_norms.1.weight", "_fpw_module.encoder.layers.2._fsdp_wrapped_module._fpw_module.layer_norms.1.bias", "_fpw_module.encoder.layers.2._fsdp_wrapped_module._fpw_module.self_attn.in_proj_weight", "_fpw_module.encoder.layers.2._fsdp_wrapped_module._fpw_module.self_attn.in_proj_bias", "_fpw_module.encoder.layers.3._fsdp_wrapped_module._fpw_module.layer_norms.0.weight", "_fpw_module.encoder.layers.3._fsdp_wrapped_module._fpw_module.layer_norms.0.bias", "_fpw_module.encoder.layers.3._fsdp_wrapped_module._fpw_module.layer_norms.1.weight", "_fpw_module.encoder.layers.3._fsdp_wrapped_module._fpw_module.layer_norms.1.bias", "_fpw_module.encoder.layers.3._fsdp_wrapped_module._fpw_module.self_attn.in_proj_weight", "_fpw_module.encoder.layers.3._fsdp_wrapped_module._fpw_module.self_attn.in_proj_bias", "_fpw_module.encoder.layers.4._fsdp_wrapped_module._fpw_module.layer_norms.0.weight", "_fpw_module.encoder.layers.4._fsdp_wrapped_module._fpw_module.layer_norms.0.bias", "_fpw_module.encoder.layers.4._fsdp_wrapped_module._fpw_module.layer_norms.1.weight", "_fpw_module.encoder.layers.4._fsdp_wrapped_module._fpw_module.layer_norms.1.bias", "_fpw_module.encoder.layers.4._fsdp_wrapped_module._fpw_module.self_attn.in_proj_weight", "_fpw_module.encoder.layers.4._fsdp_wrapped_module._fpw_module.self_attn.in_proj_bias", "_fpw_module.encoder.layers.5._fsdp_wrapped_module._fpw_module.layer_norms.0.weight", "_fpw_module.encoder.layers.5._fsdp_wrapped_module._fpw_module.layer_norms.0.bias", "_fpw_module.encoder.layers.5._fsdp_wrapped_module._fpw_module.layer_norms.1.weight", "_fpw_module.encoder.layers.5._fsdp_wrapped_module._fpw_module.layer_norms.1.bias", "_fpw_module.encoder.layers.5._fsdp_wrapped_module._fpw_module.self_attn.in_proj_weight", "_fpw_module.encoder.layers.5._fsdp_wrapped_module._fpw_module.self_attn.in_proj_bias", "_fpw_module.encoder.layers.6._fsdp_wrapped_module._fpw_module.layer_norms.0.weight", "_fpw_module.encoder.layers.6._fsdp_wrapped_module._fpw_module.layer_norms.0.bias", "_fpw_module.encoder.layers.6._fsdp_wrapped_module._fpw_module.layer_norms.1.weight", "_fpw_module.encoder.layers.6._fsdp_wrapped_module._fpw_module.layer_norms.1.bias", "_fpw_module.encoder.layers.6._fsdp_wrapped_module._fpw_module.self_attn.in_proj_weight", "_fpw_module.encoder.layers.6._fsdp_wrapped_module._fpw_module.self_attn.in_proj_bias", "_fpw_module.encoder.layers.7._fsdp_wrapped_module._fpw_module.layer_norms.0.weight", "_fpw_module.encoder.layers.7._fsdp_wrapped_module._fpw_module.layer_norms.0.bias", "_fpw_module.encoder.layers.7._fsdp_wrapped_module._fpw_module.layer_norms.1.weight", "_fpw_module.encoder.layers.7._fsdp_wrapped_module._fpw_module.layer_norms.1.bias", "_fpw_module.encoder.layers.7._fsdp_wrapped_module._fpw_module.self_attn.in_proj_weight", "_fpw_module.encoder.layers.7._fsdp_wrapped_module._fpw_module.self_attn.in_proj_bias", "_fpw_module.encoder.layers.8._fsdp_wrapped_module._fpw_module.layer_norms.0.weight", "_fpw_module.encoder.layers.8._fsdp_wrapped_module._fpw_module.layer_norms.0.bias", "_fpw_module.encoder.layers.8._fsdp_wrapped_module._fpw_module.layer_norms.1.weight", "_fpw_module.encoder.layers.8._fsdp_wrapped_module._fpw_module.layer_norms.1.bias", "_fpw_module.encoder.layers.8._fsdp_wrapped_module._fpw_module.self_attn.in_proj_weight", "_fpw_module.encoder.layers.8._fsdp_wrapped_module._fpw_module.self_attn.in_proj_bias", "_fpw_module.encoder.layers.9._fsdp_wrapped_module._fpw_module.layer_norms.0.weight", "_fpw_module.encoder.layers.9._fsdp_wrapped_module._fpw_module.layer_norms.0.bias", "_fpw_module.encoder.layers.9._fsdp_wrapped_module._fpw_module.layer_norms.1.weight", "_fpw_module.encoder.layers.9._fsdp_wrapped_module._fpw_module.layer_norms.1.bias", "_fpw_module.encoder.layers.9._fsdp_wrapped_module._fpw_module.self_attn.in_proj_weight", "_fpw_module.encoder.layers.9._fsdp_wrapped_module._fpw_module.self_attn.in_proj_bias", "_fpw_module.encoder.layers.10._fsdp_wrapped_module._fpw_module.layer_norms.0.weight", "_fpw_module.encoder.layers.10._fsdp_wrapped_module._fpw_module.layer_norms.0.bias", "_fpw_module.encoder.layers.10._fsdp_wrapped_module._fpw_module.layer_norms.1.weight", "_fpw_module.encoder.layers.10._fsdp_wrapped_module._fpw_module.layer_norms.1.bias", "_fpw_module.encoder.layers.10._fsdp_wrapped_module._fpw_module.self_attn.in_proj_weight", "_fpw_module.encoder.layers.10._fsdp_wrapped_module._fpw_module.self_attn.in_proj_bias", "_fpw_module.encoder.layers.11._fsdp_wrapped_module._fpw_module.layer_norms.0.weight", "_fpw_module.encoder.layers.11._fsdp_wrapped_module._fpw_module.layer_norms.0.bias", "_fpw_module.encoder.layers.11._fsdp_wrapped_module._fpw_module.layer_norms.1.weight", "_fpw_module.encoder.layers.11._fsdp_wrapped_module._fpw_module.layer_norms.1.bias", "_fpw_module.encoder.layers.11._fsdp_wrapped_module._fpw_module.self_attn.in_proj_weight", "_fpw_module.encoder.layers.11._fsdp_wrapped_module._fpw_module.self_attn.in_proj_bias", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.self_attn.in_proj_weight", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.self_attn.in_proj_bias", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.encoder_attn.in_proj_weight", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.encoder_attn.in_proj_bias", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.self_attn.in_proj_weight", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.self_attn.in_proj_bias", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.encoder_attn.in_proj_weight", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.encoder_attn.in_proj_bias", "_fpw_module.decoder.layers.2._fsdp_wrapped_module._fpw_module.self_attn.in_proj_weight", "_fpw_module.decoder.layers.2._fsdp_wrapped_module._fpw_module.self_attn.in_proj_bias", "_fpw_module.decoder.layers.2._fsdp_wrapped_module._fpw_module.encoder_attn.in_proj_weight", "_fpw_module.decoder.layers.2._fsdp_wrapped_module._fpw_module.encoder_attn.in_proj_bias", "_fpw_module.decoder.layers.3._fsdp_wrapped_module._fpw_module.self_attn.in_proj_weight", "_fpw_module.decoder.layers.3._fsdp_wrapped_module._fpw_module.self_attn.in_proj_bias", "_fpw_module.decoder.layers.3._fsdp_wrapped_module._fpw_module.encoder_attn.in_proj_weight", "_fpw_module.decoder.layers.3._fsdp_wrapped_module._fpw_module.encoder_attn.in_proj_bias", "_fpw_module.decoder.layers.4._fsdp_wrapped_module._fpw_module.self_attn.in_proj_weight", "_fpw_module.decoder.layers.4._fsdp_wrapped_module._fpw_module.self_attn.in_proj_bias", "_fpw_module.decoder.layers.4._fsdp_wrapped_module._fpw_module.encoder_attn.in_proj_weight", "_fpw_module.decoder.layers.4._fsdp_wrapped_module._fpw_module.encoder_attn.in_proj_bias", "_fpw_module.decoder.layers.5._fsdp_wrapped_module._fpw_module.self_attn.in_proj_weight", "_fpw_module.decoder.layers.5._fsdp_wrapped_module._fpw_module.self_attn.in_proj_bias", "_fpw_module.decoder.layers.5._fsdp_wrapped_module._fpw_module.encoder_attn.in_proj_weight", "_fpw_module.decoder.layers.5._fsdp_wrapped_module._fpw_module.encoder_attn.in_proj_bias", "_fpw_module.decoder.layers.6._fsdp_wrapped_module._fpw_module.self_attn.in_proj_weight", "_fpw_module.decoder.layers.6._fsdp_wrapped_module._fpw_module.self_attn.in_proj_bias", "_fpw_module.decoder.layers.6._fsdp_wrapped_module._fpw_module.encoder_attn.in_proj_weight", "_fpw_module.decoder.layers.6._fsdp_wrapped_module._fpw_module.encoder_attn.in_proj_bias", "_fpw_module.decoder.layers.7._fsdp_wrapped_module._fpw_module.self_attn.in_proj_weight", "_fpw_module.decoder.layers.7._fsdp_wrapped_module._fpw_module.self_attn.in_proj_bias", "_fpw_module.decoder.layers.7._fsdp_wrapped_module._fpw_module.encoder_attn.in_proj_weight", "_fpw_module.decoder.layers.7._fsdp_wrapped_module._fpw_module.encoder_attn.in_proj_bias", "_fpw_module.decoder.layers.8._fsdp_wrapped_module._fpw_module.self_attn.in_proj_weight", "_fpw_module.decoder.layers.8._fsdp_wrapped_module._fpw_module.self_attn.in_proj_bias", "_fpw_module.decoder.layers.8._fsdp_wrapped_module._fpw_module.encoder_attn.in_proj_weight", "_fpw_module.decoder.layers.8._fsdp_wrapped_module._fpw_module.encoder_attn.in_proj_bias", "_fpw_module.decoder.layers.9._fsdp_wrapped_module._fpw_module.self_attn.in_proj_weight", "_fpw_module.decoder.layers.9._fsdp_wrapped_module._fpw_module.self_attn.in_proj_bias", "_fpw_module.decoder.layers.9._fsdp_wrapped_module._fpw_module.encoder_attn.in_proj_weight", "_fpw_module.decoder.layers.9._fsdp_wrapped_module._fpw_module.encoder_attn.in_proj_bias", "_fpw_module.decoder.layers.10._fsdp_wrapped_module._fpw_module.self_attn.in_proj_weight", "_fpw_module.decoder.layers.10._fsdp_wrapped_module._fpw_module.self_attn.in_proj_bias", "_fpw_module.decoder.layers.10._fsdp_wrapped_module._fpw_module.encoder_attn.in_proj_weight", "_fpw_module.decoder.layers.10._fsdp_wrapped_module._fpw_module.encoder_attn.in_proj_bias", "_fpw_module.decoder.layers.11._fsdp_wrapped_module._fpw_module.self_attn.in_proj_weight", "_fpw_module.decoder.layers.11._fsdp_wrapped_module._fpw_module.self_attn.in_proj_bias", "_fpw_module.decoder.layers.11._fsdp_wrapped_module._fpw_module.encoder_attn.in_proj_weight", "_fpw_module.decoder.layers.11._fsdp_wrapped_module._fpw_module.encoder_attn.in_proj_bias". size mismatch for _fpw_module.encoder.embed_tokens.weight: copying a param with shape torch.Size([250027, 1024]) from checkpoint, the shape in current model is torch.Size([250026, 1024]). size mismatch for _fpw_module.decoder.embed_tokens.weight: copying a param with shape torch.Size([250027, 1024]) from checkpoint, the shape in current model is torch.Size([250026, 1024]).

During handling of the above exception, another exception occurred:

Traceback (most recent call last): File "/opt/conda/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap fn(i, *args) File "/workspace/fairseq/fairseq/distributed/utils.py", line 328, in distributed_main main(cfg, **kwargs) File "/workspace/fairseq/fairseq_cli/train.py", line 145, in main extra_state, epoch_itr = checkpoint_utils.load_checkpoint( File "/workspace/fairseq/fairseq/checkpoint_utils.py", line 204, in load_checkpoint extra_state = trainer.load_checkpoint( File "/workspace/fairseq/fairseq/trainer.py", line 465, in load_checkpoint raise Exception( Exception: Cannot load model parameters from checkpoint models/mbart.cc25/model.pt; please ensure that the architectures match.


### Expected behavior

It is expected that `translation_multi_simple_epoch` works on mBART in fully sharded data parallel (FSDP), as if in no_c10d.

### Environment

 - fairseq Version (e.g., 1.0 or master): master, ee0d5a0f65a25e5f5372776402aac5cb9c4adbf1
 - PyTorch Version (e.g., 1.0) 1.8.0a0+52ea372
 - OS (e.g., Linux): Linux
 - How you installed fairseq (`pip`, source): Source
 - Build command you used (if compiling from source):

conda install gcc_linux-64 gxx_linux-64 git clone https://github.com/pytorch/fairseq.git git clone https://github.com/facebookresearch/fairscale cd fairseq pip install opencc nni tensorboardX pyarrow pip install -U numpy cython apt update apt-get install -y screen llvm-9 DS_BUILD_CPU_ADAM=1 DS_BUILD_UTILS=1 pip install deepspeed --global-option="build_ext" --global-option="-j8" pip install --editable . python setup.py build_ext --inplace cd ../fairscale pip install -r requirements.txt pip install -e .


 - Python version: 3.8.5
 - CUDA/cuDNN version: 11.0
 - GPU models and configuration: V100
 - fairscale version: 0.3.3, e969397608e69222ebd6a034c5fbd958e0e9689d
ZeguanXiao commented 2 years ago

@thpun Do you address this bug? I encounter the same error when I fine-tuning mBART with translation_from_pretrained_bart task. When I try to train a model from scratch, the FSDP is fine.

thpun commented 2 years ago

No. I just didnt use FSDP for finetuning mBART.

ZeguanXiao commented 2 years ago

Oh, it's so regrettable.