dandelin / ViLT

Code for the ICML 2021 (long talk) paper: "ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision"

Unable to reproduce the 100k results #12

Closed · JACKHAHA363 closed 3 years ago

JACKHAHA363 commented 3 years ago

Dear authors, thanks for open-sourcing the code. I tried pretraining for 100k steps and fine-tuning on VQAv2, but my test-dev score is about 65, unlike the 70.8 reported in the paper.

Here are my pretraining and fine-tuning commands:

python run.py with data_root=vilt_dataset/ \
    num_gpus=8 num_nodes=8 task_mlm_itm whole_word_masking=True step100k \
    per_gpu_batchsize=64 exp_name=pretrain 
python run.py with data_root=vilt_dataset/ \
    num_gpus=8 num_nodes=1 task_finetune_vqa_randaug \
    per_gpu_batchsize=32 load_path="result/pretrain_seed0_from_/version_0/checkpoints/last.ckpt" \
    exp_name=vqa_finetune

I generate the test JSON with:

python run.py with data_root=vilt_dataset/ \
    num_gpus=4 num_nodes=1 task_finetune_vqa \
    per_gpu_batchsize=256 load_path="result/vqa_finetune_seed0_from_last/version_0/checkpoints/last.ckpt" \
    test_only=True  exp_name="test_vqa"

Here are my pretraining and fine-tuning TensorBoard logs.

[Screenshots of the pretraining and fine-tuning TensorBoard curves, 2021-06-10]
JACKHAHA363 commented 3 years ago

These are the hyperparameters:

config:
  batch_size: 4096
  data_root: vilt_dataset/
  datasets:
  - coco
  - vg
  - sbu
  - gcc
  decay_power: 1
  draw_false_image: 1
  draw_false_text: 0
  drop_rate: 0.1
  end_lr: 0
  exp_name: pretrain
  fast_dev_run: false
  get_recall_metric: false
  hidden_size: 768
  image_only: false
  image_size: 384
  learning_rate: 0.0001
  load_path: ''
  log_dir: result
  loss_names:
    irtr: 0
    itm: 1
    mlm: 1
    mpp: 0
    nlvr2: 0
    vqa: 0
  lr_mult: 1
  max_epoch: 100
  max_image_len: 200
  max_steps: 100000
  max_text_len: 40
  mlm_prob: 0.15
  mlp_ratio: 4
  num_gpus: 8
  num_heads: 12
  num_layers: 12
  num_nodes: 8
  num_workers: 8
  optim_type: adamw
  patch_size: 32
  per_gpu_batchsize: 64
  precision: 16
  resume_from: null
  seed: 0
  test_only: false
  tokenizer: bert-base-uncased
  train_transform_keys:
  - pixelbert
  val_check_interval: 1.0
  val_transform_keys:
  - pixelbert
  vit: vit_base_patch32_384
  vocab_size: 30522
  vqav2_label_size: 3129
  warmup_steps: 2500
  weight_decay: 0.01
  whole_word_masking: true
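
A quick sanity check on the effective batch size implied by this config (an illustrative calculation only; it assumes the trainer accumulates gradients up to batch_size, as the repo's default config suggests):

# Effective batch size implied by the config above (illustrative check).
per_gpu_batchsize = 64
num_gpus = 8
num_nodes = 8
target_batch_size = 4096

per_step_batch = per_gpu_batchsize * num_gpus * num_nodes        # 64 * 8 * 8 = 4096
grad_accum_steps = max(target_batch_size // per_step_batch, 1)   # 1, so no accumulation is needed

print(per_step_batch, grad_accum_steps)  # 4096 1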
JACKHAHA363 commented 3 years ago

Using the official VQA checkpoint, I am able to get the reported result of ~71.

dandelin commented 3 years ago

@JACKHAHA363 Thank you for your report. After carefully comparing the published (cleaned) version of the source code with our internal version, we found that the internal version trains the pretraining losses jointly, whereas the cleaned version alternates between them.

I patched the code to do joint training (https://github.com/dandelin/ViLT/commit/98a51e6058b1bcdd98ee6628ceacdd1c7325525f); please try this version. Sorry for our mistake; the alternating training needs more iterations to converge.
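
For anyone comparing the two schedules, here is a minimal sketch (stand-in loss functions, not the actual ViLT training loop) of the difference between joint and alternating optimization of the pretraining losses:

def joint_step(model, batch, optimizer, mlm_loss_fn, itm_loss_fn):
    # Joint schedule: both pretraining objectives contribute to every update.
    loss = mlm_loss_fn(model, batch) + itm_loss_fn(model, batch)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

def alternating_step(model, batch, optimizer, mlm_loss_fn, itm_loss_fn, step):
    # Alternating schedule: only one objective is optimized per iteration,
    # so within the same 100k-step budget each loss receives fewer updates
    # and convergence is slower.
    loss_fn = mlm_loss_fn if step % 2 == 0 else itm_loss_fn
    loss = loss_fn(model, batch)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()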

dandelin commented 3 years ago

@JACKHAHA363 https://tensorboard.dev/experiment/mNHxDM08R6eHKeU0JHn5vg FYI, I have uploaded the ViLT pretraining log (mlm_itm + whole-word masking, 100k steps) to the tensorboard.dev link above.

JACKHAHA363 commented 3 years ago

@dandelin Thanks for the swift response! Also, it seems that in the current implementation, computing the ITM loss and computing the MLM loss involve the same forward pass? Do you think removing that redundancy could potentially speed up training?

dandelin commented 3 years ago

@JACKHAHA363 Those two need different inputs. For ITM, we use unmasked inputs (and also a misaligned image-text pair). So one iteration requires running the transformer three times: aligned masked text + image for MLM, and aligned unmasked text + image plus misaligned unmasked text + image for ITM.
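
To make the three passes concrete, here is a rough sketch (hypothetical helper names, not the ViLT API) of one pretraining iteration as described above:

def pretrain_iteration(model, image, masked_text, text, mismatched_image):
    # 1) MLM: aligned image + masked text
    mlm_out = model(image, masked_text)

    # 2) ITM positive: aligned image + unmasked text
    itm_pos_out = model(image, text)

    # 3) ITM negative: misaligned pair with unmasked text
    #    (e.g., a randomly drawn "false" image, per draw_false_image in the config)
    itm_neg_out = model(mismatched_image, text)

    # The MLM loss comes from (1); the ITM loss from (2) and (3). Because the
    # text inputs differ (masked vs. unmasked), the passes cannot simply be merged.
    return mlm_out, itm_pos_out, itm_neg_out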

JACKHAHA363 commented 3 years ago

@dandelin I see. Thanks for your help, and really impressive work!

JACKHAHA363 commented 3 years ago

The loss curves are normal now, and my VQA test-dev score is close to the paper's. I will close this issue.