Open neverland7D opened 5 years ago
The best fine-tuning result I get (fine-tuning from the pretrained model you published) is 56.62, 83.96, 90.56, which is still 1.6 lower than your reported result. Furthermore, the zero-shot evaluation result from your public conceptual pretrained model is only 26.83, 56.43, 68.92. Is that the right one? I would like to know more training details, such as the command and out file for each pretrained and fine-tuned model. I also have some questions:
- Do you freeze parameters during fine-tuning, or only during pretraining?
- How do you compute the hard negatives? What kind of image features do you use to compute the similarity (ROI features, or full-image features from a network such as ResNet or DenseNet)?
- How do you set the LR decay epochs for fine-tuning? From the out file, it looks like 0.2 with [11, 13, 15, 17].
- The pretraining code's output log shows LR=0. Is that normal? (I ask because I see that you set different decay weights for different params.)
Thanks a lot!
By the way, I finetuned image retrieval on flickr 30k with these params: Namespace(baseline=False, bert_model='bert-base-uncased', compact=False, config_file='config/bert_base_6layer_6conect.json', do_lower_case=True, evaluation_interval=1, fp16=False, freeze=-1, from_pretrained='/conceptual_pretrained_bert_base_6_layer_6_connect_freeze_0/pytorch_model_9.bin', gradient_accumulation_steps=1, in_memory=False, learning_rate=2e-05, local_rank=0, loss_scale=0, lr_scheduler='mannul', no_cuda=False, num_train_epochs=20, num_workers=9, optimizer='BertAdam', output_dir='/model/vilbert/', save_name='finetune_retrieval_2', seed=0, tasks='3', use_chunk=0, vision_scratch=False, warmup_proportion=0.1)
{ "attention_probs_dropout_prob": 0.1, "bi_attention_type": 1, "bi_hidden_size": 1024, "bi_intermediate_size": 1024, "bi_num_attention_heads": 8, "fast_mode": false, "fixed_t_layer": 0, "fixed_v_layer": 0, "fusion_method": "mul", "hidden_act": "gelu", "hidden_dropout_prob": 0.1, "hidden_size": 768, "in_batch_pairs": false, "initializer_range": 0.02, "intermediate_size": 3072, "intra_gate": false, "max_position_embeddings": 512, "num_attention_heads": 12, "num_hidden_layers": 12, "pooling_method": "mul", "predict_feature": false, "t_biattention_id": [ 6, 7, 8, 9, 10, 11 ], "type_vocab_size": 2, "v_attention_probs_dropout_prob": 0.1, "v_biattention_id": [ 0, 1, 2, 3, 4, 5 ], "v_feature_size": 2048, "v_hidden_act": "gelu", "v_hidden_dropout_prob": 0.1, "v_hidden_size": 1024, "v_initializer_range": 0.02, "v_intermediate_size": 1024, "v_num_attention_heads": 8, "v_num_hidden_layers": 6, "v_target_size": 1601, "vocab_size": 30522, "with_coattention": true }
Do you use an ensemble method? Maybe that is why the result in the paper is slightly better than yours.
Hi, while fine-tuning on Flickr30k, did you encounter this error?
RuntimeError: expand(torch.cuda.FloatTensor{[64, 1, 4, 100, 2048]}, size=[64, 4, 4, 2048]): the number of sizes provided (4) must be greater or equal to the number of dimensions in the tensor (5)
@yangapku were you able to find a solution to this? I am also facing the same error.
Thanks!!
Error: RuntimeError: expand(torch.cuda.FloatTensor{[64, 1, 4, 100, 2048]}, size=[64, 4, 4, 2048]): the number of sizes provided (4) must be greater or equal to the number of dimensions in the tensor (5)
Fix:
In vilbert/task_utils.py, make these changes under ForwardModelsVal and ForwardModelsTrain:
Change max_num_bbox = features.size(1)
to max_num_bbox = features.size(2)
and remove the unsqueeze in the features definition, changing
features = features.unsqueeze(1).expand(batch_size, num_options, max_num_bbox, 2048).contiguous().view(-1, max_num_bbox, 2048)
to
features = features.expand(batch_size, num_options, max_num_bbox, 2048).contiguous().view(-1, max_num_bbox, 2048)
Similarly, remove the unsqueeze for spatials and image_mask.
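To make the shape reasoning behind this fix concrete, here is a minimal sketch. It assumes, as the error message suggests, that features already arrive as a 4-D tensor [batch_size, num_options, max_num_bbox, feature_dim] before the unsqueeze. The helper names flatten_original and flatten_fixed are hypothetical, the sizes are small stand-ins for [64, 4, 100, 2048], and features.size(-1) replaces the hard-coded 2048 only so the toy tensor stays small.

```python
import torch

def flatten_original(features, num_options):
    # Original logic: size(1) is num_options for a 4-D input, and
    # unsqueeze(1) makes the tensor 5-D, so a 4-size expand fails.
    batch_size = features.size(0)
    max_num_bbox = features.size(1)
    feat_dim = features.size(-1)
    return (features.unsqueeze(1)
            .expand(batch_size, num_options, max_num_bbox, feat_dim)
            .contiguous()
            .view(-1, max_num_bbox, feat_dim))

def flatten_fixed(features, num_options):
    # Fixed logic: read the box count from dim 2 and skip the unsqueeze.
    batch_size = features.size(0)
    max_num_bbox = features.size(2)
    feat_dim = features.size(-1)
    return (features.expand(batch_size, num_options, max_num_bbox, feat_dim)
            .contiguous()
            .view(-1, max_num_bbox, feat_dim))

# Toy stand-in for a [64, 4, 100, 2048] batch.
features = torch.zeros(2, 4, 5, 8)

try:
    flatten_original(features, num_options=4)
    original_error = None
except RuntimeError as err:
    original_error = str(err)

fixed = flatten_fixed(features, num_options=4)
print(original_error)       # the "number of sizes provided" RuntimeError
print(tuple(fixed.shape))   # (8, 5, 8): batch_size * num_options rows
```

If your data loader instead yields 3-D features [batch_size, max_num_bbox, feature_dim], the original unsqueeze path is the one that applies, so it is worth printing features.shape before changing the code.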