penghao-wu / vstar

PyTorch Implementation of "V* : Guided Visual Search as a Core Mechanism in Multimodal LLMs"
https://vstar-seal.github.io/
MIT License

Issue with the inference method in the VSMForCausalLM class #17

Open MMoshtaghi opened 2 months ago

MMoshtaghi commented 2 months ago

I was reading your code and noticed something strange in the VSMForCausalLM class, maybe a small bug, but I'm not sure since I haven't tested your code yet. Why is images_clip (preprocessed by CLIPProcessor) given as images (preprocessed by OwlViTProcessor) to the forward method (through self.generate()), when, looking at the model_forward() method, it actually needs both of them!?

def model_forward(
        self,
        images: torch.FloatTensor,
        images_clip: torch.FloatTensor,
        input_ids: torch.LongTensor,
        labels: torch.LongTensor,
        attention_masks: torch.LongTensor,
        offset: torch.LongTensor,
        masks_list: List[torch.FloatTensor],
        label_list: List[torch.Tensor],
        bboxes_labels_list: List[torch.FloatTensor],
        bboxes_valid_list: torch.Tensor,
        masks_valid_list: List[torch.Tensor],
        resize_list: List[tuple],
        inference: bool = False,
        **kwargs,
    ):
...

def inference(
        self,
        images_clip,
        images,
        input_ids,
        resize_list,
        original_size_list,
        max_new_tokens=32,
        tokenizer=None,
        mode = 'vqa'
    ):
        assert mode in ['vqa', 'segmentation', 'detection']
        with torch.no_grad():
            outputs = self.generate(
                images=images_clip, # ????
                input_ids=input_ids,
                max_new_tokens=max_new_tokens,
                num_beams=1,
                output_hidden_states=True,
                return_dict_in_generate=True,
            )
            output_hidden_states = outputs.hidden_states[-1]
            output_ids = outputs.sequences
penghao-wu commented 2 months ago

In model_forward, if it is in inference mode, only images_clip is used (see https://github.com/penghao-wu/vstar/blob/4ede6647959cfb59eeabd09286adf6a5f9478da0/VisualSearch/model/VSM.py#L236)
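
For illustration, here is a minimal sketch of the branch being referred to (not the actual VSM.py code; llava_forward and owlvit_encode are placeholder names, not the repo's API):

import torch

def model_forward_sketch(llava_forward, owlvit_encode,
                         images: torch.FloatTensor,
                         images_clip: torch.FloatTensor,
                         input_ids: torch.LongTensor,
                         inference: bool = False):
    if inference:
        # Inference branch: only the CLIP-preprocessed images_clip reaches the
        # LLaVA backbone; the OwlViT-preprocessed `images` tensor is not used here.
        return llava_forward(images=images_clip, input_ids=input_ids,
                             output_hidden_states=True)
    # Training branch: both tensors are needed, images_clip for the LLaVA
    # forward pass and `images` for the OwlViT embeddings / detection losses.
    owlvit_feats = owlvit_encode(images)
    lm_out = llava_forward(images=images_clip, input_ids=input_ids,
                           output_hidden_states=True)
    return lm_out, owlvit_feats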

MMoshtaghi commented 2 months ago

Thanks for your quick reply! If I understood correctly, you use "images_clip" in model_forward() to pass to super().forward() for LLaVA, and use "images" to get the OwlViT image embeddings: https://github.com/penghao-wu/vstar/blob/4ede6647959cfb59eeabd09286adf6a5f9478da0/VisualSearch/model/VSM.py#L201-L219 https://github.com/penghao-wu/vstar/blob/4ede6647959cfb59eeabd09286adf6a5f9478da0/VisualSearch/model/VSM.py#L236-L250

But that's actually why I'm asking: why is "images_clip" given as the "images" argument to self.generate() and self.model_forward(), instead of passing "images" itself (with "images_clip" passed separately)?

https://github.com/penghao-wu/vstar/blob/4ede6647959cfb59eeabd09286adf6a5f9478da0/VisualSearch/model/VSM.py#L450-L458

Shouldn't it be this:

        with torch.no_grad():
            outputs = self.generate(
                images=images,
                images_clip=images_clip,
                input_ids=input_ids,
                max_new_tokens=max_new_tokens,
                num_beams=1,
                output_hidden_states=True,
                return_dict_in_generate=True,
            )

Am I missing something here? Thanks.

penghao-wu commented 2 months ago

During inference, the forward function called by generate always goes to super().forward() because "past_key_values" is always provided.

https://github.com/penghao-wu/vstar/blob/4ede6647959cfb59eeabd09286adf6a5f9478da0/VisualSearch/model/llava/model/language_model/llava_llama.py#L137-L163
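
In other words, a minimal sketch of the dispatch pattern (class and method names are simplified stand-ins; see the linked llava_llama.py for the real code):

class LlavaBaseSketch:
    def forward(self, **kwargs):
        # Stand-in for the plain LLaVA forward, which expects the
        # CLIP-preprocessed tensor in its `images` argument.
        pass

class VSMForCausalLMSketch(LlavaBaseSketch):
    def forward(self, **kwargs):
        if "past_key_values" in kwargs:
            # generate() supplies past_key_values on every decoding step, so
            # calls made during generation always take this branch and run the
            # plain LLaVA path; that is why inference() passes
            # images=images_clip to self.generate().
            return super().forward(**kwargs)
        # Calls without a KV cache go through the full V* model_forward, which
        # also needs the OwlViT-preprocessed `images`.
        return self.model_forward(**kwargs)

    def model_forward(self, **kwargs):
        pass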