hu-my closed this issue 7 months ago.
Hi, thanks for pointing this out. Could you pull the latest changes and check whether it works?
Thanks for the reply! I have pulled the latest changes, but the error still occurs:
I found that x_qry at line 164 of meta_trainer.py is a list rather than a tensor, which may be what causes this error.
Could the argument order in the call meta.finetuning(x_spt, y_spt, y_spt_mask, id_spt, x_qry, y_qry, y_qry_mask, id_qry) (line 114 of main_train.py) be out of step with the parameter order of finetuning(x_spt, y_spt, y_spt_mask, x_qry, y_qry, y_qry_mask, qry_answer, q_img_id) (line 145 of meta_trainer.py)? It seems that meta.finetuning passes id_spt to the x_qry parameter of finetuning.
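To illustrate the suspected mismatch, here is a minimal sketch using a hypothetical stub that only copies the parameter list of finetuning() quoted above; it is not the repository's code, and the batch values are placeholder strings. Called positionally, the extra id_spt shifts every later argument by one slot, so x_qry receives id_spt (a list). Passing the arguments by keyword avoids this.

```python
# Hypothetical stub mirroring the parameter list of finetuning() at line 145
# of meta_trainer.py; it only reports what some parameters received.
def finetuning(x_spt, y_spt, y_spt_mask, x_qry, y_qry, y_qry_mask,
               qry_answer, q_img_id):
    return {"x_qry": x_qry, "y_qry": y_qry, "q_img_id": q_img_id}


# Placeholder values standing in for a real batch.
batch = dict(
    x_spt="x_spt", y_spt="y_spt", y_spt_mask="y_spt_mask", id_spt=["img_0", "img_1"],
    x_qry="x_qry", y_qry="y_qry", y_qry_mask="y_qry_mask", id_qry=["img_2"],
)

# Positional call in the order used at line 114 of main_train.py:
# id_spt lands in the x_qry slot and every later argument shifts by one.
shifted = finetuning(batch["x_spt"], batch["y_spt"], batch["y_spt_mask"], batch["id_spt"],
                     batch["x_qry"], batch["y_qry"], batch["y_qry_mask"], batch["id_qry"])
print(shifted["x_qry"])  # ['img_0', 'img_1'] (a list, not the query images)

# Keyword call: each value reaches the intended parameter.
# Filling qry_answer with y_qry is only a placeholder assumption here.
fixed = finetuning(x_spt=batch["x_spt"], y_spt=batch["y_spt"], y_spt_mask=batch["y_spt_mask"],
                   x_qry=batch["x_qry"], y_qry=batch["y_qry"], y_qry_mask=batch["y_qry_mask"],
                   qry_answer=batch["y_qry"], q_img_id=batch["id_qry"])
print(fixed["x_qry"])  # x_qry
```

Keyword arguments make the call robust to the two signatures drifting apart, at the cost of a longer call site.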
Hi, could you pull the latest changes and run it again?
Sorry for the late response. I pulled the latest changes and ran the code again. I resolved the previous errors with two additional modifications to main_train.py:
1. At line 107, replace
   for x_spt, y_spt, y_spt_mask, x_qry, y_qry, y_qry_mask, y_qry_answer, qry_img_id in db_test:
   with
   for x_spt, y_spt, y_spt_mask, id_spt, x_qry, y_qry, y_qry_mask, qry_img_id in db_test:
2. At line 116, add
   y_qry_answer = y_qry
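The first change matters because tuple unpacking is purely positional: if the loop names do not follow the order in which db_test yields its fields, values are silently mislabelled with no error. A minimal sketch, assuming (from the fix itself, not verified against the loader) that each db_test item is an 8-tuple ordered as in the modified loop header; db_test here is a hypothetical in-memory stand-in.

```python
# Hypothetical stand-in for db_test. The field order below is ASSUMED to be
# the one the real loader yields (inferred from the fix, not verified).
db_test = [
    ("x_spt0", "y_spt0", "m_spt0", ["s0"], "x_qry0", "y_qry0", "m_qry0", ["q0"]),
]

# Original loop header: the names are shifted relative to the data, so
# x_qry silently receives id_spt (a list) and no error is raised here.
for x_spt, y_spt, y_spt_mask, x_qry, y_qry, y_qry_mask, y_qry_answer, qry_img_id in db_test:
    old_x_qry = x_qry

# Modified loop header, plus the y_qry_answer workaround.
for x_spt, y_spt, y_spt_mask, id_spt, x_qry, y_qry, y_qry_mask, qry_img_id in db_test:
    y_qry_answer = y_qry  # reuse the query tokens as the answer
    new_x_qry = x_qry

print(old_x_qry)  # ['s0']
print(new_x_qry)  # x_qry0
```

Under this assumption, the old header would also explain why x_qry showed up as a list at line 164.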
Hi, thanks for this interesting work. When I try to train the multimodal meta-learner with the provided code, I encounter an error at step 400 during the finetuning process:
The error occurs at line 164 of meta_trainer.py:
logits_q, pred_tokens = model(x_qry, y_qry, y_qry_mask, list(model.mapper_net.parameters()), is_finetuning=True)
because forward() (line 47 of meta_learner.py) does not declare an is_finetuning parameter:
def forward(self, image, tokens: torch.Tensor, mask: Optional[torch.Tensor] = None, fast_weights=None, labels: Optional[torch.Tensor] = None, get_pred_tokens=True):
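Python raises a TypeError whenever a call passes a keyword that the callee's signature does not declare, which is what happens with is_finetuning here. A minimal sketch with hypothetical stub classes (not the repository's model; type hints dropped, bodies replaced) shows the failure and one possible fix: declaring is_finetuning on forward() with a default so existing callers keep working. What forward() should actually do differently during finetuning is unknown, so the stub only records the flag.

```python
class StubLearner:
    """Hypothetical stand-in for the meta-learner; not the repository's class."""

    def __call__(self, *args, **kwargs):
        # Like nn.Module, calling the model dispatches to forward().
        return self.forward(*args, **kwargs)

    # Same parameter names as forward() at line 47 of meta_learner.py (type hints dropped).
    def forward(self, image, tokens, mask=None, fast_weights=None,
                labels=None, get_pred_tokens=True):
        return "logits", "pred_tokens"


class PatchedLearner(StubLearner):
    """One possible fix (an assumption): declare is_finetuning with a default."""

    def forward(self, image, tokens, mask=None, fast_weights=None,
                labels=None, get_pred_tokens=True, is_finetuning=False):
        self.last_is_finetuning = is_finetuning  # record the flag; real use is unknown
        return super().forward(image, tokens, mask, fast_weights, labels, get_pred_tokens)


caught = ""
try:
    # Mirrors the call at line 164 of meta_trainer.py with placeholder values.
    StubLearner()("x_qry", "y_qry", "y_qry_mask", [], is_finetuning=True)
except TypeError as err:
    caught = str(err)
print(caught)  # ...got an unexpected keyword argument 'is_finetuning'

model = PatchedLearner()
logits_q, pred_tokens = model("x_qry", "y_qry", "y_qry_mask", [], is_finetuning=True)
print(model.last_is_finetuning)  # True
```

The alternative, dropping is_finetuning at the call site, removes this TypeError but leaves the argument-order problem in finetuning() untouched.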
In addition, the arguments passed to meta.finetuning() (line 114 of main_train.py) do not match the parameters defined for finetuning() (line 145 of meta_trainer.py); after I simply remove the is_finetuning argument, this mismatch causes another error:
So, how can I fix these errors?
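To see exactly which positions disagree, one option is to compare the argument names used at the call site with the callee's declared parameters via the standard inspect module. The stub below is hypothetical and only copies the parameter list of finetuning() from line 145; caller_order is transcribed by hand from line 114. Purely cosmetic name differences (id_qry vs q_img_id) are flagged too, so the output still needs a human read.

```python
import inspect


# Hypothetical stub with the parameter list of finetuning() (line 145, meta_trainer.py).
# For a real method, inspect the bound method (e.g. meta.finetuning) so that
# self is excluded from the signature.
def finetuning(x_spt, y_spt, y_spt_mask, x_qry, y_qry, y_qry_mask, qry_answer, q_img_id):
    pass


# Argument names in the order meta.finetuning() is called at line 114 of main_train.py.
caller_order = ["x_spt", "y_spt", "y_spt_mask", "id_spt",
                "x_qry", "y_qry", "y_qry_mask", "id_qry"]

declared = list(inspect.signature(finetuning).parameters)

if len(caller_order) != len(declared):
    print(f"count differs: {len(caller_order)} passed vs {len(declared)} declared")

# Report every position where the passed name differs from the declared parameter.
mismatches = [(pos, passed, expected)
              for pos, (passed, expected) in enumerate(zip(caller_order, declared))
              if passed != expected]

for pos, passed, expected in mismatches:
    print(f"position {pos}: passes {passed!r} into parameter {expected!r}")
# first line printed: position 3: passes 'id_spt' into parameter 'x_qry'
```

Because both lists have eight entries, Python itself raises nothing at the call; the first mismatch at position 3 and the shift after it are what point to id_spt as the extra argument.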