salesforce / LAVIS

LAVIS - A One-stop Library for Language-Vision Intelligence
BSD 3-Clause "New" or "Revised" License

About the labels in the VQA task using OPT models #661

Open dszpr opened 6 months ago

dszpr commented 6 months ago

Hi! I noticed that in the forward function in blip2_opt.py, only the questions in the VQA dataset are used. Both the text_input and the target labels are derived from opt_tokens:

    text = [t + "\n" for t in samples["text_input"]]

    # Tokenize the full text (prompt + question) with the OPT tokenizer.
    opt_tokens = self.opt_tokenizer(
        text,
        return_tensors="pt",
        padding="longest",
        truncation=True,
        max_length=self.max_txt_len,
    ).to(image.device)

    # Copy the input ids as labels, masking pad tokens with -100 so the
    # language-model loss ignores them.
    targets = opt_tokens.input_ids.masked_fill(
        opt_tokens.input_ids == self.opt_tokenizer.pad_token_id, -100
    )
    if self.prompt:
        targets[:, : self.prompt_length] = -100  # do not apply loss to the prompt

    # One -100 label per visual query token, keeping the labels aligned
    # with the [inputs_opt, inputs_embeds] concatenation below.
    empty_targets = (
        torch.ones(atts_opt.size(), dtype=torch.long).to(image.device).fill_(-100)
    )
    targets = torch.cat([empty_targets, targets], dim=1)

    # Prepend the Q-Former outputs (inputs_opt) to the text embeddings.
    inputs_embeds = self.opt_model.model.decoder.embed_tokens(opt_tokens.input_ids)
    inputs_embeds = torch.cat([inputs_opt, inputs_embeds], dim=1)
    attention_mask = torch.cat([atts_opt, opt_tokens.attention_mask], dim=1)

    with self.maybe_autocast():
        outputs = self.opt_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            return_dict=True,
            labels=targets,
        )

In the VQA task, aren't the target labels supposed to be the answers in the VQA dataset? But the answers in the VQA dataset are not used in blip2_opt.py, whereas they are used as target labels in blip2_t5.py. This is really confusing. Have you made any changes to blip2_opt.py?
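For context, the label masking that the snippet above performs can be sketched in plain Python (hypothetical `build_targets` helper, `PAD_ID` assumed to be OPT's pad token id; not the actual LAVIS API):

```python
PAD_ID = 1      # assumed pad_token_id of the OPT tokenizer
IGNORE = -100   # ignore_index used by PyTorch's cross-entropy loss

def build_targets(input_ids, prompt_length):
    """Copy input_ids as labels, masking pad positions and the
    prompt prefix with -100 so the loss skips them."""
    targets = [tok if tok != PAD_ID else IGNORE for tok in input_ids]
    targets[:prompt_length] = [IGNORE] * prompt_length
    return targets

# Example: a 2-token prompt, 3 answer tokens, then 1 pad token.
print(build_targets([5, 7, 42, 43, 44, 1], prompt_length=2))
# -> [-100, -100, 42, 43, 44, -100]
```

So loss is only computed on the non-prompt, non-pad tokens of text_input, which is why the question of where the answers come from matters here.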

chengyuehuang511 commented 2 months ago

I have the same question. @dxli94 Could you please take a look? Besides, I don't quite get what this is for: `empty_targets = (torch.ones(atts_opt.size(), dtype=torch.long).to(image.device).fill_(-100)); targets = torch.cat([empty_targets, targets], dim=1)`

chengyuehuang511 commented 2 months ago

> I have the same question. @dxli94 Could you please take a look? Besides, I don't quite get what this is for: `empty_targets = (torch.ones(atts_opt.size(), dtype=torch.long).to(image.device).fill_(-100)); targets = torch.cat([empty_targets, targets], dim=1)`

I attached a revised version of fine-tuning blip2_opt for VQA tasks here: https://github.com/salesforce/LAVIS/issues/125#issuecomment-2200960668. Can you help me check whether it's correct? Thanks!
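For what it's worth, the empty_targets lines appear to exist because inputs_opt (the Q-Former's visual query embeddings) is prepended to the text embeddings, so the label tensor must grow by the same number of positions to stay aligned; those visual positions get -100 so no loss is computed on them. A plain-Python sketch of this alignment (hypothetical helper, list-based shapes for readability; not the LAVIS API):

```python
IGNORE = -100  # labels with this value are skipped by the LM loss

def pad_targets_for_query_tokens(targets, num_query_tokens):
    """Prepend one -100 label per visual query token so each row of
    targets lines up with the [inputs_opt, inputs_embeds] sequence."""
    return [[IGNORE] * num_query_tokens + row for row in targets]

# BLIP-2 uses 32 query tokens; use 3 here for readability.
# Row: 2 masked prompt tokens followed by 2 answer tokens.
targets = [[-100, -100, 42, 43]]
print(pad_targets_for_query_tokens(targets, 3))
# -> [[-100, -100, -100, -100, -100, 42, 43]]
```

This is the list-space analogue of `torch.cat([empty_targets, targets], dim=1)`: it changes nothing about which text tokens are supervised, it only shifts them to the right positions in the concatenated sequence.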