ashkamath / mdetr


How to run fine-tuning on VQA2 dataset? #43

Closed · TopCoder2K closed this 2 years ago

TopCoder2K commented 3 years ago

Experiments with the VQA v2 dataset are described in Appendix E of the paper, but it's not clear from main.py and run_with_submitit.py how to run the fine-tuning (I've tried writing the same command that is used for fine-tuning on CLEVR). I've also found vqa_coco_format.py, but it seems to be data preparation, not the fine-tuning itself. Also, I've looked at the build_dataset function in main.py, and I don't see VQA v2 handled there :( Could you please explain how to do this?

UPD 1 (09.27.21): I've downloaded COCO and VQA v2 datasets and ran

python scripts/fine-tuning/vqa_coco_format.py --data_path VQA_v2_dataset/ --img_path COCO_dataset/images/ --coco_path COCO_dataset/

And the processing finished correctly. Now I'm thinking about how to write the VQA v2 dataset script...
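To make the plan concrete, here is a minimal sketch of the dataset class and build_dataset branch I have in mind, modeled on the existing GQA path. Everything here is my own guess: the class name, the annotation keys, and the args fields are assumptions, not upstream code.

```python
# Hypothetical datasets/vqa2.py -- a sketch of the plan, not upstream code.
# Assumes vqa_coco_format.py wrote COCO-style json where each image entry
# carries the question text and an answer id (the key names are my guesses).
from pathlib import Path

import torchvision


class VQAv2Detection(torchvision.datasets.CocoDetection):
    def __init__(self, img_folder, ann_file, transforms=None):
        super().__init__(img_folder, ann_file)
        self._transforms = transforms

    def __getitem__(self, idx):
        # Discard the COCO detection annotations: in a --no_detection QA
        # setup only the question/answer pair is needed.
        img, _ = super().__getitem__(idx)
        image_id = self.ids[idx]
        coco_img = self.coco.loadImgs(image_id)[0]
        target = {
            "image_id": image_id,
            "caption": coco_img["caption"],    # the question text
            "answer": coco_img.get("answer"),  # ground-truth answer id
        }
        if self._transforms is not None:
            img, target = self._transforms(img, target)
        return img, target


def build(image_set, args):
    # Would be dispatched from build_dataset in main.py;
    # args.vqa2_ann_path and args.coco_img_path are hypothetical fields.
    ann_file = Path(args.vqa2_ann_path) / f"finetune_vqa2_{image_set}.json"
    return VQAv2Detection(args.coco_img_path, ann_file)
```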

UPD 2 (10.03.21): It seems I've managed to implement all the necessary classes and fix the code. I'm currently running an eval -> train on VQA2 -> eval experiment. As soon as it finishes successfully, I'll push the code to my fork of the repo.

UPD 3 (10.03.21): Yeah, it works! Here is the link: https://github.com/TopCoder2K/mdetr. I haven't written any documentation because I'm not sure fine-tuning on VQA is useful to anybody.)) If you have any questions, please ask here :)

TopCoder2K commented 3 years ago

Good afternoon, @ashkamath!

Thank you for the great MDETR! I've been trying to train MDETR on the VQA2 dataset. As I wrote above, I managed to implement the code needed to run training on VQA2. Then I decided to reproduce your experiment from Appendix E of the paper: I fine-tuned on GQA balanced with the --no_detection option for 10 epochs and then fine-tuned on VQA2 for 25 epochs. But the results are quite strange. It seems the model hasn't learned anything: the loss on GQA has increased, and on VQA it has barely changed. Evaluation during training was performed on the val and minival splits of GQA and VQA2 respectively.

(screenshots: GQA and VQA2 loss curves)

Here are the commands I've used. Fine-tuning on the GQA:

python run_with_submitit.py --dataset_config configs/gqa.json --ngpus 1 --nodes 1 --ema --epochs 10 --epoch_chunks 25 --do_qa --split_qa_heads --lr_drop 150 --resume https://zenodo.org/record/4721981/files/pretrained_resnet101_checkpoint.pth --batch_size 4 --no_detection --qa_loss_coef 25 --lr 1.4e-4 --lr_backbone 1.4e-5 --text_encoder_lr 7e-5 --output-dir ~/MDETR/mdetr/checkpoint

Fine-tuning on the VQA2:

python run_with_submitit.py --dataset_config configs/vqa2.json --ngpus 1 --nodes 1  --epochs 25 --epoch_chunks 25 --do_qa --split_qa_heads --lr_drop 150 --backbone resnet101 --load ~/MDETR/mdetr/checkpoint/pchelintsev/experiments/19311/BEST_checkpoint.pth --batch_size 4 --no_aux_loss --no_contrastive_align_loss --no_detection --qa_loss_coef 25 --lr 1.4e-4 --lr_backbone 1.4e-5 --text_encoder_lr 7e-5 --output-dir ~/MDETR/mdetr/checkpoint --do_qa_with_qa_fine-tuned

Evaluation on the VQA2:

python main.py --dataset_config configs/vqa2.json --eval --do_qa --split_qa_heads --no_contrastive_align_loss --no_aux_loss --no_detection --backbone resnet101 --qa_loss_coef 25 --resume ~/MDETR/mdetr/checkpoint/pchelintsev/experiments/5063/BEST_checkpoint.pth

The only significant difference is that I did 10 epochs on GQA balanced while you did 5 epochs on GQA all. But can that have such an impact? I suspect my hyperparameters may be wrong... Could you please share which options and hyperparameter values you used? Also, I'm curious how you loaded the model from BEST_checkpoint.pth after fine-tuning on GQA. I used the --load option, and mismatching heads (for example, the head for the question types) are deleted before running load_state_dict().
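Concretely, the head-dropping logic in my fork looks roughly like this (a simplified sketch, not the exact code):

```python
# Sketch: drop checkpoint weights whose shapes don't match the current model
# (e.g. the question-type head differs between the GQA and VQA2 vocabularies).
import torch

checkpoint = torch.load("BEST_checkpoint.pth", map_location="cpu")
ckpt_state = checkpoint["model"]
model_state = model.state_dict()  # `model` is the freshly built MDETR model

filtered = {
    k: v for k, v in ckpt_state.items()
    if k in model_state and v.shape == model_state[k].shape
}
# strict=False leaves the mismatching heads randomly initialized
model.load_state_dict(filtered, strict=False)
```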

TopCoder2K commented 2 years ago

An interesting thing I've noticed is that when running GQA for 5 epochs, the graphs look different and the training actually progresses! But I still get no good results on VQA2. Compared to my previous comment, I've improved the dataset processing, fixed a small bug, and used --ema. So I have no idea why the training doesn't go well...

Here are the graphs of the total loss and some other metrics and losses:

(screenshot: GQA, 5 epochs, with --no_detection, torch.set_deterministic(True))
(screenshot: VQA2, 10 epochs, torch.set_deterministic(True), after fine-tuning on GQA balanced for 5 epochs with --no_detection)

And here are the commands I used (running on GQA and VQA respectively):

python run_with_submitit.py --dataset_config configs/gqa.json --ngpus 1 --nodes 1 --ema --epochs 5 --epoch_chunks 25 --do_qa --split_qa_heads --lr_drop 150 --load pretrained_resnet101_checkpoint.pth --batch_size 4 --no_detection --qa_loss_coef 25 --lr 1.4e-4 --lr_backbone 1.4e-5 --text_encoder_lr 7e-5 --output-dir ~/MDETR/mdetr/checkpoint
python run_with_submitit.py --dataset_config configs/vqa2.json --ngpus 1 --nodes 1 --ema --epochs 10 --epoch_chunks 25 --do_qa --split_qa_heads --lr_drop 150 --load ~/MDETR/mdetr/checkpoint/pchelintsev/experiments/26220/BEST_checkpoint.pth --batch_size 4 --no_aux_loss --no_contrastive_align_loss --no_detection --qa_loss_coef 25 --lr 1.4e-4 --lr_backbone 1.4e-5 --text_encoder_lr 7e-5 --output-dir ~/MDETR/mdetr/checkpoint --do_qa_with_qa_fine-tuned

And I still have the same request: could you please share which options and hyperparameter values you used?

ashkamath commented 2 years ago

Hi! For GQA, there was a typo in the paper and in the appendix; it was fixed in the main paper, but I seem to have forgotten to update the appendix. After pre-training on modulated detection, we fine-tune with the QA queries on GQA all for 125 epochs, using the command:

python run_with_submitit.py --dataset_config configs/gqa.json --ngpus 8 --ema --epochs 125 --epoch_chunks 25 --do_qa --split_qa_heads --lr_drop 150 --load https://zenodo.org/record/4721981/files/pretrained_resnet101_checkpoint.pth --nodes 4 --batch_size 4 --no_aux_loss --qa_loss_coef 25 --lr 1.4e-4 --lr_backbone 1.4e-5 --text_encoder_lr 7e-5

Here, it is important to have this QA loss coefficient that puts more weight on the QA losses than on the detection losses.
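Schematically, the coefficient just rescales the QA terms in the DETR-style loss weighting, something like this (the key names and detection weights below are illustrative, not the exact mdetr code):

```python
# Illustrative sketch of how --qa_loss_coef 25 reweights the objective.
def total_loss(loss_dict, qa_loss_coef=25.0):
    weight_dict = {
        "loss_ce": 1.0,    # detection losses keep their usual weights
        "loss_bbox": 5.0,
        "loss_giou": 2.0,
        "loss_answer_type": qa_loss_coef,  # QA terms are scaled up 25x
        "loss_answer_attr": qa_loss_coef,
    }
    return sum(loss_dict[k] * weight_dict[k] for k in loss_dict if k in weight_dict)
```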

Using this model, we then fine-tune on VQA for 25 epochs (we used a BCE loss on the QA head):

python run_with_submitit.py --backbone "resnet101" --dataset_config configs/vqa2.json --num_queries 100 --batch_size 4 --num_workers 5 --schedule linear_with_warmup --text_encoder_type roberta-base --ngpus 8 --nodes 4 --ema --do_qa --load path/to/gqa/model --no_aux_loss --no_detection --lr 7e-5 --lr_backbone 1.4e-5 --text_encoder_lr 7e-5 --bce_qa --epochs 25
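For reference, the BCE objective on the answer head is the standard VQA-style binary cross-entropy over soft answer scores; a minimal sketch (my illustration, not the exact mdetr code):

```python
import torch.nn.functional as F


def bce_qa_loss(answer_logits, soft_scores):
    """VQA-style BCE over the answer vocabulary.

    answer_logits: (batch, num_answers) raw scores from the QA head
    soft_scores:   (batch, num_answers) soft targets in [0, 1], e.g. the
                   usual min(#annotators_agreeing / 3, 1) VQA accuracy score
    """
    return F.binary_cross_entropy_with_logits(answer_logits, soft_scores)
```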

So in short, if you're training on GQA all, don't use --no_detection; then initialize from the best model after 125 epochs and fine-tune on VQA for 25.

Hope this helps! Feel free to get back with questions.

Best, Aish

TopCoder2K commented 2 years ago

Thank you for the reply!

Ah, so you fine-tuned on GQA all for 125 epochs and didn't use --no_detection, okay! By the way, there is no --bce_qa flag in main.py, so for now I'll patch it in myself (see the sketch below). Anyway, I don't have enough resources to train on GQA all for 125 epochs, so I'll try the second option: fine-tuning the pre-trained model directly. Could you provide the command you used to fine-tune the pre-trained model on VQA2?
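Presumably the flag is just an argparse switch plus the corresponding change in the criterion; here is my guess for the switch (the flag's help text is my assumption, since it's not in the repo):

```python
# My guess at the missing flag in main.py's argument parser; not upstream code.
parser.add_argument(
    "--bce_qa",
    action="store_true",
    help="use binary cross-entropy (soft VQA scores) instead of "
         "softmax cross-entropy on the answer head",
)
```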

TopCoder2K commented 2 years ago

I've tried the second option by running the following command:

python run_with_submitit.py --dataset_config configs/vqa2.json --ngpus 1 --nodes 1 --ema --epochs 10 --epoch_chunks 25 --do_qa --split_qa_heads --load pretrained_resnet101_checkpoint.pth --batch_size 4 --no_aux_loss --no_detection --lr 7e-5 --lr_backbone 1.4e-5 --text_encoder_lr 7e-5 --output-dir ~/MDETR/mdetr/checkpoint

Here I've tried to adjust it according to your command for fine-tuning after the GQA fine-tuning:

python run_with_submitit.py --backbone "resnet101" --dataset_config configs/vqa2.json --num_queries 100 --batch_size 4 --num_workers 5 --schedule linear_with_warmup --text_encoder_type roberta-base --ngpus 8 --nodes 4 --ema --do_qa --load path/to/gqa/model --no_aux_loss --no_detection --lr 7e-5 --lr_backbone 1.4e-5 --text_encoder_lr 7e-5 --bce_qa --epochs 25 

So, I dropped the QA loss coefficient (--qa_loss_coef), changed the lr, removed --lr_drop, and turned the contrastive align loss back on. Unfortunately, the results are not what I'd hoped for.

(screenshot: loss curves)

Could you please post the exact command you used?

UPD1: Oh, I've just realized that I can use your gqa_resnet101_checkpoint.pth! I've already set up the experiments with this checkpoint! Looking forward to the results!

UPD2: Yeah, I've got good results, thank you!