dandelin / ViLT

Code for the ICML 2021 (long talk) paper: "ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision"
Apache License 2.0

Finetuning VQA from checkpoint has unmatched keys #8

Closed JACKHAHA363 closed 3 years ago

JACKHAHA363 commented 3 years ago

Command

$PYTHONBIN run.py with data_root=dataset  \
        num_gpus=1 num_nodes=1 task_finetune_vqa \
        per_gpu_batchsize=64 load_path="weights/vilt_200k_mlm_itm.ckpt"

The error is raised inside trainer.fit():

  File "/data/home/lyuchen/miniconda/envs/vilt/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py", line 184, in setup_training
    self.trainer.checkpoint_connector.restore_weights(model)
  File "/data/home/lyuchen/miniconda/envs/vilt/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py", line 63, in restore_weights
    self.hpc_load(checkpoint_path, self.trainer.on_gpu)
  File "/data/home/lyuchen/miniconda/envs/vilt/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py", line 336, in hpc_load
    self.restore_model_state(model, checkpoint)
  File "/data/home/lyuchen/miniconda/envs/vilt/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py", line 119, in restore_model_state
    model.load_state_dict(checkpoint['state_dict'])
  File "/data/home/lyuchen/miniconda/envs/vilt/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for ViLTransformerSS:
        Missing key(s) in state_dict: "vqa_classifier.0.weight", "vqa_classifier.0.bias", "vqa_classifier.1.weight", "vqa_classifier.1.bias", "vqa_classifier.3.weight", "vqa_classifier.3.bias".
        Unexpected key(s) in state_dict: "mlm_score.bias", "mlm_score.transform.dense.weight", "mlm_score.transform.dense.bias", "mlm_score.transform.LayerNorm.weight", "mlm_score.transform.LayerNorm.bias", "mlm_score.decoder.weight", "itm_score.fc.weight", "itm_score.fc.bias".
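
For anyone reading the two key lists in this error: they amount to a set difference between the parameter names the model expects and the names stored in the checkpoint. Here is a rough sketch in plain Python. `diff_keys` is a hypothetical helper, not part of PyTorch or this repo, and the key lists are abbreviated, with `shared.weight` as a made-up stand-in for the backbone parameters both sides have in common.

```python
def diff_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) in the sense of a strict state_dict load.

    missing    -> names the model has but the checkpoint lacks
    unexpected -> names the checkpoint has but the model lacks
    """
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected


# Abbreviated, illustrative key lists ("shared.weight" is invented for the example).
model_keys = ["shared.weight", "vqa_classifier.0.weight", "vqa_classifier.0.bias"]
checkpoint_keys = ["shared.weight", "mlm_score.bias", "itm_score.fc.weight"]

missing, unexpected = diff_keys(model_keys, checkpoint_keys)
print("Missing key(s):", missing)        # VQA head, absent from the pretraining checkpoint
print("Unexpected key(s):", unexpected)  # MLM/ITM heads, absent from the VQA model
```

A mismatch of this shape suggests a pretraining checkpoint is being loaded strictly into a model built for a different task. PyTorch's `load_state_dict` does accept `strict=False` to tolerate such differences, but whether that is appropriate depends on which code path is doing the loading.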
JACKHAHA363 commented 3 years ago

The problem was that Lightning automatically detected an hpc_ckpt file left over from an auto-resubmit and tried to restore it. After I removed that file, fine-tuning runs fine.
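
In case it helps others hitting the same error, here is a small sketch for locating leftover HPC checkpoints before a run. It assumes the files follow a `hpc_ckpt_<n>.ckpt` naming pattern, which may differ between Lightning versions. `find_hpc_checkpoints` is a hypothetical helper name, and the directory to scan is whatever your run writes its logs and weights to.

```python
from pathlib import Path


def find_hpc_checkpoints(root):
    """List files under `root` whose names look like Lightning HPC checkpoints.

    Assumption: the files are named like "hpc_ckpt_<n>.ckpt". Adjust the
    pattern if your Lightning version names them differently.
    """
    return sorted(Path(root).rglob("hpc_ckpt_*.ckpt"))
```

Calling `find_hpc_checkpoints("path/to/log_dir")` returns the matching paths so they can be inspected and then moved or deleted by hand. The sketch deliberately does not delete anything itself.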