salesforce / LAVIS

LAVIS - A One-stop Library for Language-Vision Intelligence

[BLIP2]: Low accuracy of zeroshot VQA of BLIP2-opt-2.7b #239

Open YuanLiuuuuuu opened 1 year ago

YuanLiuuuuuu commented 1 year ago

Hi,

Thank you for your great work on BLIP-2. I found there is no zero-shot VQA evaluation code for BLIP2-OPT, so I created one by referring to the FLAN-T5 code. However, the accuracy is very low. I would be very grateful if you could provide the zero-shot VQA evaluation code for BLIP2-OPT. Thanks in advance!

YuanLiuuuuuu commented 1 year ago

Here is my evaluation code for OPT, referring to that of FLAN-T5:

def predict_answers(self,
                        samples,
                        num_beams=5,
                        max_len=10,
                        min_len=1,
                        prompt="",
                        length_penalty=-1,
                        **kwargs):

        image = samples["image"]
        with self.maybe_autocast():
            image_embeds = self.ln_vision(self.visual_encoder(image))
        image_embeds = image_embeds.float()
        image_atts = torch.ones(image_embeds.size()[:-1],
                                dtype=torch.long).to(image.device)

        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=True,
        )

        inputs_opt = self.opt_proj(query_output.last_hidden_state)
        atts_opt = torch.ones(inputs_opt.size()[:-1],
                              dtype=torch.long).to(image.device)

        if isinstance(samples["text_input"], str):
            samples["text_input"] = [samples["text_input"]]
        if prompt:
            text_input = [
                prompt.format(question) for question in samples["text_input"]
            ]
        else:
            text_input = samples["text_input"]

        input_tokens = self.opt_tokenizer(text_input,
                                          padding="longest",
                                          return_tensors="pt").to(image.device)

        encoder_atts = torch.cat([atts_opt, input_tokens.attention_mask],
                                 dim=1)

        with self.maybe_autocast(dtype=torch.bfloat16):
            inputs_embeds = self.opt_model.model.decoder.embed_tokens(
                input_tokens.input_ids)
            inputs_embeds = torch.cat([inputs_opt, inputs_embeds], dim=1)

            outputs = self.opt_model.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=encoder_atts,
                do_sample=False,
                num_beams=num_beams,
                max_new_tokens=max_len,
                min_length=min_len,
                length_penalty=length_penalty,
            )
            output_text = self.opt_tokenizer.batch_decode(
                outputs, skip_special_tokens=True)

        if self._apply_lemmatizer:
            output_text = self._lemmatize(output_text)

        return output_text

    def _lemmatize(self, answers):
        def apply(answer):
            doc = self.lemmatizer(answer)

            words = []
            for token in doc:
                if token.pos_ in ["NOUN", "VERB"]:
                    words.append(token.lemma_)
                else:
                    words.append(token.text)
            answer = " ".join(words)

            return answer

        return [apply(answer) for answer in answers]

    @property
    def lemmatizer(self):
        if self._lemmatizer is None:
            try:
                import spacy

                self._lemmatizer = spacy.load("en_core_web_sm")
            except ImportError:
                logging.error("""
                    Please install spacy and en_core_web_sm model to apply lemmatization.
                    python -m spacy download en_core_web_sm
                    OR
                    import spacy.cli
                    spacy.cli.download("en_core_web_sm")
                    """)
                exit(1)

        return self._lemmatizer
YuanLiuuuuuu commented 1 year ago

In addition, I also tried to figure out why the zero-shot VQAv2 accuracy is so low for OPT, and found that OPT always outputs a full sentence, whereas the target answer is usually a single word and the existing VQA evaluation metric requires an exact match between the output and the target answer. So I wonder whether the evaluation metric for OPT is different from that of FLAN-T5?
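
For reference, the standard VQAv2 metric credits a prediction according to how many of the ten annotators gave exactly that (normalized) answer, which is why a full-sentence output scores zero against single-word ground truths. A simplified sketch of the metric (the official script additionally normalizes punctuation, articles, and number words, and averages over annotator subsets):

    def vqa_accuracy(prediction: str, gt_answers: list[str]) -> float:
        # Simplified VQAv2 accuracy: min(# annotators who gave this answer / 3, 1).
        matches = sum(ans == prediction for ans in gt_answers)
        return min(matches / 3.0, 1.0)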

mactavish91 commented 1 year ago

@YuanLiuuuuuu Can the FLAN-T5 model achieve a normal result?

ThreeSR commented 1 year ago

@YuanLiuuuuuu If you directly use the generate function to produce answers with OPT, I think the overall accuracy won't be too low. But even when using the generate function, I still cannot match the performance reported in the paper (my 35.48 vs. 54.3 in the paper; model: pretrained OPT-6.7B, VQAv2 val split). @dxli94 @LiJunnan1992 I hope the authors can provide more details on how zero-shot VQA is evaluated. In particular, I would like to know whether generation is constrained to the 3,129 most frequent answers. Thanks.

LiJunnan1992 commented 1 year ago

The zero-shot VQA evaluation with OPT follows the same protocol as FlanT5. We have not released the code because it conflicts with transformers>4.25.

For transformers<=4.25, please modify the generate function into predict_answers and implement these two key operations:

Thank you.

ThreeSR commented 1 year ago

@LiJunnan1992 Thanks a lot for your quick reply. I will try that.

Daniel-van-Dijk commented 1 year ago

@LiJunnan1992 Have you been able to solve it? I'm struggling with the same issue.

YuanLiuuuuuu commented 1 year ago

@LiJunnan1992 Have you been able to solve it? I'm struggling with the same issue.

Setting self.opt_tokenizer.padding_side = "left" solves this problem.
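
For reference, a minimal sketch of where this fix could go in the predict_answers code posted above (setting padding_side once in the model's __init__ should work equally well; the exact placement is an assumption, not official LAVIS code):

    # Assumed placement: right before tokenizing the questions in predict_answers.
    # Left padding lets a decoder-only model like OPT start generating
    # immediately after the prompt instead of after a run of pad tokens.
    self.opt_tokenizer.padding_side = "left"
    input_tokens = self.opt_tokenizer(text_input,
                                      padding="longest",
                                      return_tensors="pt").to(image.device)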

YuanLiuuuuuu commented 1 year ago

@YuanLiuuuuuu Can the FLAN-T5 model achieve a normal result?

Yes

Daniel-van-Dijk commented 1 year ago

@YuanLiuuuuuu I tried your code but ended up with: ValueError: If inputs_embeds is passed as model-specific keyword input then model has to be an encoder-decoder and not a OPTForCausalLM, raised at outputs = self.opt_model.generate(inputs_embeds=inputs_embeds, ...). Have you encountered the same error? Thanks!

LiJunnan1992 commented 1 year ago

We have made an update to the BLIP-2 OPT models so that they work with the latest transformers (version >= 4.27). You can now directly run run_scripts/blip2/eval/validate_vqa_zeroshot_opt.sh for VQA evaluation. The results will differ slightly from those reported in the paper due to implementation differences. Thank you.
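
For a quick sanity check outside the evaluation script, something like the following should work with the updated models. This is a minimal sketch: the model_type string, image path, and prompt template are illustrative and may differ from what the evaluation script configures; predict_answers follows the signature discussed earlier in this thread.

    import torch
    from PIL import Image
    from lavis.models import load_model_and_preprocess

    # Requires transformers >= 4.27, per the note above.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, vis_processors, _ = load_model_and_preprocess(
        name="blip2_opt", model_type="pretrain_opt2.7b", is_eval=True, device=device)

    raw_image = Image.open("example.jpg").convert("RGB")  # placeholder image path
    image = vis_processors["eval"](raw_image).unsqueeze(0).to(device)

    answers = model.predict_answers(
        samples={"image": image, "text_input": ["What is in the picture?"]},
        prompt="Question: {} Answer:",  # prompt template is an assumption
    )
    print(answers)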

dszpr commented 6 months ago

Hi! @LiJunnan1992 I noticed that in the forward function in blip2_opt.py, only the questions from the VQA dataset are used. Both the text_input and the target labels are derived from opt_tokens:

    text = [t + "\n" for t in samples["text_input"]]

    opt_tokens = self.opt_tokenizer(
        text,
        return_tensors="pt",
        padding="longest",
        truncation=True,
        max_length=self.max_txt_len,
    ).to(image.device)

    targets = opt_tokens.input_ids.masked_fill(
        opt_tokens.input_ids == self.opt_tokenizer.pad_token_id, -100
    )
    if self.prompt:
        targets[:, : self.prompt_length] = -100  # do not apply loss to the prompt

    empty_targets = (
        torch.ones(atts_opt.size(), dtype=torch.long).to(image.device).fill_(-100)
    )
    targets = torch.cat([empty_targets, targets], dim=1)

    inputs_embeds = self.opt_model.model.decoder.embed_tokens(opt_tokens.input_ids)
    inputs_embeds = torch.cat([inputs_opt, inputs_embeds], dim=1)
    attention_mask = torch.cat([atts_opt, opt_tokens.attention_mask], dim=1)

    with self.maybe_autocast():
        outputs = self.opt_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            return_dict=True,
            labels=targets,
        )

In the VQA task, aren't the target labels supposed to be the answers from the VQA dataset? But the answers are not used in blip2_opt.py, while they are used as target labels in blip2_t5.py. This really confuses me. Have you made any changes to blip2_opt.py?
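
For reference, a common pattern for VQA-style fine-tuning of a decoder-only model is to concatenate the question prompt and the answer into one sequence and mask the question tokens in the labels, so the loss only covers the answer tokens. The sketch below is a generic illustration of that pattern using a hypothetical helper, not the actual LAVIS implementation; whether the LAVIS dataset/task code pre-concatenates question and answer into text_input for the OPT model would need to be confirmed separately.

    import torch

    def build_causal_lm_targets(tokenizer, question: str, answer: str, device="cpu"):
        # Hypothetical helper, not part of LAVIS: supervise only the answer
        # tokens of a decoder-only LM by masking the prompt portion with -100.
        prompt = f"Question: {question} Answer: "
        prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
        answer_ids = tokenizer(answer + "\n", add_special_tokens=False,
                               return_tensors="pt").input_ids.to(device)

        input_ids = torch.cat([prompt_ids, answer_ids], dim=1)
        targets = input_ids.clone()
        targets[:, : prompt_ids.size(1)] = -100  # no loss on the question/prompt
        return input_ids, targets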