airsplay / lxmert

PyTorch code for EMNLP 2019 paper "LXMERT: Learning Cross-Modality Encoder Representations from Transformers".
MIT License

Pretrained model that wasn't pretrained on GQA #85

Open yonatanbitton opened 3 years ago

yonatanbitton commented 3 years ago

Hello! First, thanks for this amazing repo.

I need a pretrained model that I can later fine-tune on the GQA dataset, but it has to be a version that was not pretrained on GQA itself.

I've tried two different approaches, and both failed.

1) Finding a pretrained version that I can use: I found this post: https://github.com/airsplay/lxmert/issues/7, where you provide a model that was pretrained on VQA. I tried to use that version with the gqa_finetune.bash script, simply replacing the model path in --loadLXMERTQA snap/pretrained/model with the provided model (--loadLXMERTQA snap/pretrained/pretrained_vqa_model), but I receive the following error when the GQA class is created:

Traceback (most recent call last):
 File "src/tasks/gqa.py", line 181, in <module>
  gqa = GQA()
 File "src/tasks/gqa.py", line 57, in __init__
  label2ans=self.train_tuple.dataset.label2ans)
 File "/data/users/yonatab/lxmert/src/pretrain/qa_answer_table.py", line 132, in load_lxmert_qa
  ans_weight = answer_state_dict['logit_fc.3.weight']
KeyError: 'logit_fc.3.weight'
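
For reference, here is a small diagnostic I used to see whether the downloaded checkpoint actually contains the answer head that the loader is looking for. This is only a sketch; the "_LXRT.pth" suffix and the file path are my assumptions, based on the command above and the repo's naming conventions.

```python
import torch

# Diagnostic sketch: check whether the VQA-pretrained checkpoint contains the
# answer-head weights (logit_fc.*) that load_lxmert_qa expects.
# Path and "_LXRT.pth" suffix are assumptions, not confirmed repo behavior.
state_dict = torch.load("snap/pretrained/pretrained_vqa_model_LXRT.pth",
                        map_location="cpu")

qa_head_keys = sorted(k for k in state_dict if k.startswith("logit_fc"))
print("answer-head keys:", qa_head_keys)   # empty list -> KeyError like the one above
print("total keys:", len(state_dict))
```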

2) Pretraining it myself: I took lxmert_pretrain.bash and removed vgnococo from --train, so that it pretrains on the remaining data that is not GQA.

2a) When I try to pretrain on a single GPU with this command: bash run/lxmert_pretrain_no_gqa.bash 2, I get an out-of-memory error.

2b) When I try to pretrain on more than one GPU, for example two GPUs with this command: bash run/lxmert_pretrain_no_gqa.bash 4,3 --multiGPU, I receive this weird stack trace:

Total Iters: 378020
Warm up Iters: 18901
  0%|                                                                                                                 | 0/18901 [00:03<?, ?it/s]
Traceback (most recent call last):
  File "src/pretrain/lxmert_pretrain.py", line 435, in <module>
    lxmert.train(train_tuple, valid_tuple)
  File "src/pretrain/lxmert_pretrain.py", line 332, in train
    loss, losses, logit = self.train_batch(optim, batch)
  File "src/pretrain/lxmert_pretrain.py", line 291, in train_batch
    loss, losses, ans_logit = self.forward(batch)
  File "src/pretrain/lxmert_pretrain.py", line 285, in forward
    feats, pos, obj_labels, matched_labels, ans
  File "/data/users/yonatab/lxmert/venv/lib/python3.6/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/data/users/yonatab/lxmert/venv/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 155, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/data/users/yonatab/lxmert/venv/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 165, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/data/users/yonatab/lxmert/venv/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 85, in parallel_apply
    output.reraise()
  File "/data/users/yonatab/lxmert/venv/lib/python3.6/site-packages/torch/_utils.py", line 395, in reraise
    raise self.exc_type(msg)
StopIteration: Caught StopIteration in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/data/users/yonatab/lxmert/venv/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py", line 60, in _worker
    output = module(*input, **kwargs)
  File "/data/users/yonatab/lxmert/venv/lib/python3.6/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/data/users/yonatab/lxmert/src/lxrt/modeling.py", line 926, in forward
    visual_feats=(visual_feats, pos),
  File "/data/users/yonatab/lxmert/venv/lib/python3.6/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/data/users/yonatab/lxmert/src/lxrt/modeling.py", line 864, in forward
    extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
StopIteration
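
Looking at the trace, the failure comes from next(self.parameters()) inside forward: with newer PyTorch versions, nn.DataParallel replicas apparently no longer expose their parameters through self.parameters() during forward, so that call raises StopIteration. A possible workaround (an untested sketch; the attribute name is mine) would be to cache the dtype once instead of querying the parameters inside forward:

```python
# Untested workaround sketch for src/lxrt/modeling.py; the attribute name is hypothetical.
#
# In LXRTModel.__init__ (after all submodules have been created), cache the dtype once:
#     self.mask_dtype = next(self.parameters()).dtype
#
# Then in forward, replace
#     extended_attention_mask = extended_attention_mask.to(
#         dtype=next(self.parameters()).dtype)  # fp16 compatibility
# with
#     extended_attention_mask = extended_attention_mask.to(dtype=self.mask_dtype)
```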

Do you have an idea how I can solve this? Or is there any chance that you have a pretrained model that has not seen the GQA data that I could use?

yonatanbitton commented 3 years ago

Update: Downgrading torch to 1.4.0 seems to solve the problem in 2b).

However, I'm afraid that after pre-training I will still run into the same issue as in 1), i.e., the model architecture won't match and I'll receive the same error:

  ans_weight = answer_state_dict['logit_fc.3.weight']
KeyError: 'logit_fc.3.weight'
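
If that happens, one direction I'm considering (untested; the flag names are taken from the repo's other fine-tuning scripts) is to pass the self-pretrained checkpoint via --loadLXMERT instead of --loadLXMERTQA, so that only the encoder weights are loaded and no logit_fc.* keys are required. A quick check to decide which flag applies, along the same lines as the inspection above:

```python
import torch

# Sketch: decide how to load a checkpoint that may or may not contain the QA head.
# The "_LXRT.pth" suffix and the key name come from the error above; the path is a
# placeholder for the output of my own pretraining run.
path = "snap/pretrained/my_no_gqa_pretrain/Epoch20"  # hypothetical checkpoint prefix
state_dict = torch.load(path + "_LXRT.pth", map_location="cpu")

if "logit_fc.3.weight" in state_dict:
    print("Answer head present -> use --loadLXMERTQA", path)
else:
    print("No answer head -> use --loadLXMERT", path, "(encoder weights only)")
```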

In any case, if you have a suitable model (one that wasn't pre-trained on GQA) that I could use, that would be best. Many thanks.