[Closed] JACKHAHA363 closed this issue 3 years ago
These are the hparams:
```yaml
config:
  batch_size: 4096
  data_root: vilt_dataset/
  datasets:
    - coco
    - vg
    - sbu
    - gcc
  decay_power: 1
  draw_false_image: 1
  draw_false_text: 0
  drop_rate: 0.1
  end_lr: 0
  exp_name: pretrain
  fast_dev_run: false
  get_recall_metric: false
  hidden_size: 768
  image_only: false
  image_size: 384
  learning_rate: 0.0001
  load_path: ''
  log_dir: result
  loss_names:
    irtr: 0
    itm: 1
    mlm: 1
    mpp: 0
    nlvr2: 0
    vqa: 0
  lr_mult: 1
  max_epoch: 100
  max_image_len: 200
  max_steps: 100000
  max_text_len: 40
  mlm_prob: 0.15
  mlp_ratio: 4
  num_gpus: 8
  num_heads: 12
  num_layers: 12
  num_nodes: 8
  num_workers: 8
  optim_type: adamw
  patch_size: 32
  per_gpu_batchsize: 64
  precision: 16
  resume_from: null
  seed: 0
  test_only: false
  tokenizer: bert-base-uncased
  train_transform_keys:
    - pixelbert
  val_check_interval: 1.0
  val_transform_keys:
    - pixelbert
  vit: vit_base_patch32_384
  vocab_size: 30522
  vqav2_label_size: 3129
  warmup_steps: 2500
  weight_decay: 0.01
  whole_word_masking: true
```
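As a side note, the `batch_size` of 4096 in this config is consistent with the per-GPU batch size and the GPU/node counts. A quick sanity check (illustrative arithmetic only, not ViLT code):

```python
# Values copied from the config above.
per_gpu_batchsize = 64
num_gpus = 8    # GPUs per node
num_nodes = 8

effective_batch_size = per_gpu_batchsize * num_gpus * num_nodes
assert effective_batch_size == 4096  # matches batch_size in the config
# If batch_size exceeded this product, gradient accumulation would be
# needed to reach the target; here no accumulation is required.
```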
Using the official VQA checkpoint, I am able to reproduce the reported result of ~71.
@JACKHAHA363 Thank you for your report. After carefully comparing the published (cleaned) version with our internal version of the source code, we found that the internal version trains the pretraining losses jointly, whereas the cleaned version alternates between them.
I patched the code to do joint training (https://github.com/dandelin/ViLT/commit/98a51e6058b1bcdd98ee6628ceacdd1c7325525f); please try this version. Sorry for our mistake; the alternating training needs more iterations to converge.
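For readers unfamiliar with the distinction, here is a minimal sketch of the two schemes. The names `model`, `compute_mlm`, and `compute_itm` are placeholders for illustration, not ViLT's actual API:

```python
import random

def training_step_joint(model, batch, optimizer):
    """Joint training: sum all active losses and take one optimizer step."""
    loss = model.compute_mlm(batch) + model.compute_itm(batch)  # placeholder loss fns
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

def training_step_alternating(model, batch, optimizer):
    """Alternating training: optimize one objective per step.
    Each objective sees only a fraction of the iterations,
    so convergence takes more steps than joint training."""
    loss_fn = random.choice([model.compute_mlm, model.compute_itm])
    loss = loss_fn(batch)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```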
@JACKHAHA363
https://tensorboard.dev/experiment/mNHxDM08R6eHKeU0JHn5vg
FYI, I have uploaded the ViLT pre-training log for mlm_itm + WWM (100k steps) to the tensorboard.dev link above.
@dandelin Thanks for the swift response! It also seems that, in the current implementation, computing the ITM loss and computing the MLM loss involve the same forward pass. Do you think removing that redundancy could potentially speed up training?
@JACKHAHA363 Those two need different inputs. For ITM, we use unmasked inputs (and also a misaligned image-text pair). So one iteration requires running the transformer three times: aligned masked text + image for MLM, and aligned unmasked text + image plus misaligned unmasked text + image for ITM.
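Concretely, one MLM+ITM iteration looks roughly like the sketch below. The helper names (`transformer`, `mlm_head`, `itm_head`) and batch keys are illustrative placeholders, not the actual ViLT identifiers:

```python
import torch
import torch.nn.functional as F

def mlm_itm_step(transformer, mlm_head, itm_head, batch):
    """One MLM+ITM iteration: three transformer passes with different inputs."""
    # Pass 1: masked text + aligned image -> MLM loss on masked positions.
    feats_masked = transformer(batch["masked_text"], batch["image"])
    mlm_loss = F.cross_entropy(
        mlm_head(feats_masked).transpose(1, 2),  # (B, vocab, L)
        batch["mlm_labels"],                     # (B, L), -100 on unmasked tokens
        ignore_index=-100,
    )

    # Pass 2: unmasked text + aligned image -> ITM positive pair.
    feats_pos = transformer(batch["text"], batch["image"])
    # Pass 3: unmasked text + misaligned (randomly drawn) image -> ITM negative pair.
    feats_neg = transformer(batch["text"], batch["false_image"])

    # Binary matched/mismatched classification on the [CLS] features.
    itm_logits = itm_head(torch.cat([feats_pos[:, 0], feats_neg[:, 0]]))
    itm_labels = torch.cat([
        torch.ones(feats_pos.size(0), dtype=torch.long),
        torch.zeros(feats_neg.size(0), dtype=torch.long),
    ]).to(itm_logits.device)
    itm_loss = F.cross_entropy(itm_logits, itm_labels)

    return mlm_loss + itm_loss  # joint training, per the patched version above
```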
@dandelin I see. Thanks for your help, and really impressive work!
The loss curves look normal now, and my VQA test-dev score is close to the paper's. I will close this issue.
Dear authors, thanks for open-sourcing the code. I pretrained for 100k steps and fine-tuned on VQAv2, but my test-dev score is about 65, not the 70.8 reported in the paper.
Here are my pretraining and finetuning commands:
Generate JSON with
Here are my pretraining and finetuning TensorBoard logs: