haotian-liu / LLaVA

[NeurIPS'23 Oral] Visual Instruction Tuning (LLaVA) built towards GPT-4V level capabilities and beyond.
https://llava.hliu.cc
Apache License 2.0

process multi images #1292

Open PzWHU opened 5 months ago

PzWHU commented 5 months ago

Describe the issue

Issue:

How can I input multiple images?

SakuraTroyChen commented 5 months ago

Concatenate the image paths with ','.
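
For example, a minimal sketch of passing two comma-joined paths through eval_model, following the same args pattern that appears later in this thread (the model path, image paths, and query are placeholders; as discussed below, the released checkpoints are not trained on multi-image data, so results may be poor):

from llava.eval.run_llava import eval_model
from llava.mm_utils import get_model_name_from_path

model_path = "liuhaotian/llava-v1.5-7b"  # placeholder checkpoint
args = type('Args', (), {
    "model_path": model_path,
    "model_base": None,
    "model_name": get_model_name_from_path(model_path),
    "query": "Describe each of the two images.",
    "conv_mode": None,
    "image_file": "image1.jpg,image2.jpg",  # image paths joined with ','
    "sep": ",",                             # separator used to split image_file
    "temperature": 0,
    "top_p": None,
    "num_beams": 1,
    "max_new_tokens": 512
})()
eval_model(args)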

Dinghaoxuan commented 5 months ago

Concatenate the image paths with ','.

What about the prompt? How should the prompt be edited when multiple images are input? Thank you for your response.

Sprinter1999 commented 5 months ago

I'm also wondering.

SakuraTroyChen commented 5 months ago

Concatenate the image paths with ','.

What about the prompt? How should the prompt be edited when multiple images are input? Thank you for your response.

That's a task-specific problem. Maybe you can add some captions and instruct LLaVA in the appropriate order?
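
For instance, a hypothetical human turn in the training-data format shown later in this thread, with one <image> tag per image and a short caption before each so the intended order is explicit (the wording here is illustrative only, not from the LLaVA repo):

# hypothetical example of a multi-image instruction with interleaved captions
human_value = (
    "<image>\nThe first picture shows french toast.\n"
    "<image>\nThe second picture shows onion rings.\n"
    "Which of the two dishes is deep-fried?"
)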

yuejunpeng commented 3 months ago

@SakuraTroyChen I tried it but failed. In detail, I passed two images to '--image-file', concatenated with ',', and the task was to output a caption for each of the two images. However, the model only answered with a caption for the first image; the second image was almost invisible to the model. I think it may be necessary to separate the two image tokens with some special token.

anas-zafar commented 2 months ago

Hi @yuejunpeng, were you able to solve it? Thanks

SakuraTroyChen commented 2 months ago

@SakuraTroyChen I tried it but failed. In detail, I passed two images to '--image-file', concatenated with ',', and the task was to output a caption for each of the two images. However, the model only answered with a caption for the first image; the second image was almost invisible to the model. I think it may be necessary to separate the two image tokens with some special token.

Since the models are not trained on multi-image datasets, the actual performance is unpredictable. I think you may need to train the model on specific multi-image datasets before using multi-image inputs.

anas-zafar commented 2 months ago

@SakuraTroyChen can you please guide me on how to train on multi-image inputs? Thanks

SakuraTroyChen commented 2 months ago

@SakuraTroyChen can you please guide me on how to train on multi-image inputs? Thanks

You can take a look at this tutorial (it is written in Chinese, though).

SakuraTroyChen commented 2 months ago

Dataset format:

{
        "id": "1",
        "image": [
            "images/french_toast/278421.jpg",
            "images/onion_rings/1074382.jpg",
            "images/spring_rolls/419151.jpg",
            "images/carrot_cake/263995.jpg",
            "images/eggs_benedict/44135.jpg"
        ],
        "conversations": [
            {
                "from": "human",
                "value": "<image>\n<image>\n<image>\n<image>\n<image>\nWhat food are these five pictures about?"
            },
            {
                "from": "gpt",
                "value": "french toast, onion rings, spring rolls, carrot cake, eggs benedict"
            }
        ]
    }
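
When building such a file, a small sanity check can help: each sample should contain exactly one <image> tag per entry in its "image" list, because the modified dataloader below substitutes one image per tag, in order. A minimal sketch (this helper is not part of the LLaVA codebase; the file name is hypothetical):

import json

def check_sample(sample):
    # Accept either a single path or a list of paths, as in the sample above.
    images = sample.get("image", [])
    if isinstance(images, str):
        images = [images]
    n_tags = sum(turn["value"].count("<image>") for turn in sample["conversations"])
    assert n_tags == len(images), (
        f"sample {sample['id']}: {n_tags} <image> tags but {len(images)} images"
    )

with open("multi_image_data.json") as f:  # hypothetical dataset file
    for sample in json.load(f):
        check_sample(sample)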

Step 1: Modify __getitem__ in LazySupervisedDataset.

def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        sources = self.list_data_dict[i]
        if isinstance(i, int):
            sources = [sources]
        assert len(sources) == 1, "Don't know why it is wrapped to a list"  # FIXME
        if 'image' in sources[0]:
            image_file = self.list_data_dict[i]['image']
            image_folder = self.data_args.image_folder
            processor = self.data_args.image_processor
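            # Multi-image samples store a list of paths under 'image';
            # load each path so that 'image' becomes a list of PIL images.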
​
            if isinstance(image_file, list):
                image = []
                for img_file in image_file:
                    image.append(Image.open(os.path.join(image_folder, img_file)).convert('RGB'))
            else:
                image = Image.open(os.path.join(image_folder, image_file)).convert('RGB')

            if self.data_args.image_aspect_ratio == 'pad':
                def expand2square(pil_img, background_color):
                    width, height = pil_img.size
                    if width == height:
                        return pil_img
                    elif width > height:
                        result = Image.new(pil_img.mode, (width, width), background_color)
                        result.paste(pil_img, (0, (width - height) // 2))
                        return result
                    else:
                        result = Image.new(pil_img.mode, (height, height), background_color)
                        result.paste(pil_img, ((height - width) // 2, 0))
                        return result
                if isinstance(image, list):
                    image = [expand2square(img, tuple(int(x*255) for x in processor.image_mean)) for img in image]
                    image = [processor.preprocess(img, return_tensors='pt')['pixel_values'][0] for img in image]
                else:
                    image = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
                    image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
            else:
                if isinstance(image, list):
                    image = [processor.preprocess(img, return_tensors='pt')['pixel_values'][0] for img in image]
                else:
                    image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]

            sources = preprocess_multimodal(
                copy.deepcopy([e["conversations"] for e in sources]),
                self.data_args)
        else:
            sources = copy.deepcopy([e["conversations"] for e in sources])
        data_dict = preprocess(
            sources,
            self.tokenizer,
            has_image=('image' in self.list_data_dict[i]))
        if isinstance(i, int):
            data_dict = dict(input_ids=data_dict["input_ids"][0],
                             labels=data_dict["labels"][0])
​
        # image exist in the data
        if 'image' in self.list_data_dict[i]:
            data_dict['image'] = image
        elif self.data_args.is_multimodal:
            # image does not exist in the data, but the model is multimodal
            crop_size = self.data_args.image_processor.crop_size
            data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])
        return data_dict
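
With this change, data_dict['image'] is a list of per-image tensors for multi-image samples and a single tensor otherwise; the collator in Step 2 handles both cases.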

Step 2: Modify DataCollatorForSupervisedDataset.

@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""
​
    tokenizer: transformers.PreTrainedTokenizer
​
    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key] for instance in instances]
                                  for key in ("input_ids", "labels"))
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids,
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id)
        labels = torch.nn.utils.rnn.pad_sequence(labels,
                                                 batch_first=True,
                                                 padding_value=IGNORE_INDEX)
        input_ids = input_ids[:, :self.tokenizer.model_max_length]
        labels = labels[:, :self.tokenizer.model_max_length]
        batch = dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )
​
        if 'image' in instances[0]:

            images = [instance['image'] for instance in instances]
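            # Multi-image samples arrive as lists of tensors: stack each list into
            # [N, C, H, W], then stack the batch into [B, N, C, H, W].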
            if isinstance(images[0], list):
                images = torch.stack([torch.stack(img, dim=0) for img in images], dim = 0)
                batch['images'] = images
            else:
                if all(x is not None and x.shape == images[0].shape for x in images):
                    batch['images'] = torch.stack(images)
                else:
                    batch['images'] = images
​
        return batch
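
Note that the multi-image branch assumes every sample in the batch has the same number of images (otherwise torch.stack fails), and it yields a 5-D batch of shape [B, N, C, H, W], which is what the images.ndim == 5 check in the Step 3 code further down relies on.
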
FengLi-ust commented 2 months ago

Hi, the LLaVA-Next-Interleave version is out, which naturally supports interleaved multi-image inputs. Please refer to this evaluation code for the input format. It can directly handle the input format you provided.

You can also try our model demo and see the details of the model in this blog.

anas-zafar commented 2 months ago

Thanks @FengLi-ust.

SakuraTroyChen commented 2 months ago

Thanks @FengLi-ust, will look into it. Thanks @SakuraTroyChen, can you please guide me on how to test this model after the LoRA weights have been merged?

I think it is enough to follow these two steps. In the tests I conducted, the original source code performed better than the version with Step 3. However, you can still use the Step 3 code as follows:

class PloyLlavaMetaForCausalLM(LlavaMetaForCausalLM, ABC):
    def __init__(self):
        super().__init__()
​
    def prepare_inputs_labels_for_multimodal(
        self, input_ids, position_ids, attention_mask, past_key_values, labels,
        images, image_sizes=None
    ):
        vision_tower = self.get_vision_tower()
        if vision_tower is None or images is None or input_ids.shape[1] == 1:
            return input_ids, position_ids, attention_mask, past_key_values, None, labels
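        # A 5-D images tensor has shape [B, N, C, H, W]: encode each sample's N images
        # separately so that image_features has shape [B, N, num_patches, hidden_size].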
        if images.ndim == 5:
            image_features = torch.stack([self.encode_images(img) for img in images], dim = 0)
        else:
            image_features = self.encode_images(images)
        # image_features = self.encode_images(images)
​
        # TODO: image start / end is not implemented here to support pretraining.
        if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
            raise NotImplementedError
​
        # Let's just add dummy tensors if they do not exist,
        # it is a headache to deal with None all the time.
        # But it is not ideal, and if you have a better idea,
        # please open an issue / submit a PR, thanks.
        _labels = labels
        _position_ids = position_ids
        _attention_mask = attention_mask
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        else:
            attention_mask = attention_mask.bool()
        if position_ids is None:
            position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
        if labels is None:
            labels = torch.full_like(input_ids, IGNORE_INDEX)
​
        # remove the padding using attention_mask -- FIXME
        _input_ids = input_ids
        input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
        labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
​
        new_input_embeds = []
        new_labels = []

        for batch_idx, cur_input_ids in enumerate(input_ids):
            cur_image_idx = 0
            num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
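            # When image features were stacked per sample above (ndim == 4), index this
            # sample's own features so its <image> tokens consume its own images.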
            if image_features.ndim == 4:
                batch_image_features = image_features[batch_idx]
            else:
                batch_image_features = image_features
            if num_images == 0:
                cur_image_features = batch_image_features[cur_image_idx]
                cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
                cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
                new_input_embeds.append(cur_input_embeds)
                new_labels.append(labels[batch_idx])
                cur_image_idx += 1
                continue
​
            image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
            cur_input_ids_noim = []
            cur_labels = labels[batch_idx]
            cur_labels_noim = []
            for i in range(len(image_token_indices) - 1):
                cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
                cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
            split_sizes = [x.shape[0] for x in cur_labels_noim]
            cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
            cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
            cur_new_input_embeds = []
            cur_new_labels = []
​
            for i in range(num_images + 1):
                cur_new_input_embeds.append(cur_input_embeds_no_im[i])
                cur_new_labels.append(cur_labels_noim[i])
                if i < num_images:
                    cur_image_features = batch_image_features[cur_image_idx]
                    cur_image_idx += 1
                    cur_new_input_embeds.append(cur_image_features)
                    cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
​
            cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]
​
            cur_new_input_embeds = torch.cat(cur_new_input_embeds)
            cur_new_labels = torch.cat(cur_new_labels)
​
            new_input_embeds.append(cur_new_input_embeds)
            new_labels.append(cur_new_labels)
​
        # Truncate sequences to max length as image embeddings can make the sequence longer
        tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
        if tokenizer_model_max_length is not None:
            new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
            new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
​
        # Combine them
        max_len = max(x.shape[0] for x in new_input_embeds)
        batch_size = len(new_input_embeds)
​
        new_input_embeds_padded = []
        new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
        attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
        position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
​
        for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
            cur_len = cur_new_embed.shape[0]
            if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
                new_input_embeds_padded.append(torch.cat((
                    torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
                    cur_new_embed
                ), dim=0))
                if cur_len > 0:
                    new_labels_padded[i, -cur_len:] = cur_new_labels
                    attention_mask[i, -cur_len:] = True
                    position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
            else:
                new_input_embeds_padded.append(torch.cat((
                    cur_new_embed,
                    torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
                ), dim=0))
                if cur_len > 0:
                    new_labels_padded[i, :cur_len] = cur_new_labels
                    attention_mask[i, :cur_len] = True
                    position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
​
        new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
​
        if _labels is None:
            new_labels = None
        else:
            new_labels = new_labels_padded
​
        if _attention_mask is None:
            attention_mask = None
        else:
            attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
​
        if _position_ids is None:
            position_ids = None
​
        return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
​
​
​
class PloyLlavaLlamaForCausalLM(LlamaForCausalLM, PloyLlavaMetaForCausalLM):
    config_class = LlavaConfig
​
    def __init__(self, config):
        super(LlamaForCausalLM, self).__init__(config)
        self.model = LlavaLlamaModel(config)
        self.pretraining_tp = config.pretraining_tp
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
​
        # Initialize weights and apply final processing
        self.post_init()
​
    def get_model(self):
        return self.model
​
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[List[List[int]]] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
​
        if inputs_embeds is None:
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                labels,
                images,
                image_sizes
            )
​
        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )
​
    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")
​
        if images is not None:
            (
                inputs,
                position_ids,
                attention_mask,
                _,
                inputs_embeds,
                _
            ) = self.prepare_inputs_labels_for_multimodal(
                inputs,
                position_ids,
                attention_mask,
                None,
                None,
                images,
                image_sizes=image_sizes
            )
        else:
            inputs_embeds = self.get_model().embed_tokens(inputs)
​
        return super().generate(
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            **kwargs
        )
​
    def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
                                      inputs_embeds=None, **kwargs):
        images = kwargs.pop("images", None)
        image_sizes = kwargs.pop("image_sizes", None)
        inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        if images is not None:
            inputs['images'] = images
        if image_sizes is not None:
            inputs['image_sizes'] = image_sizes
        return inputs
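
In short, this Step 3 variant encodes the 5-D image batch per sample and indexes image_features[batch_idx] inside the loop, so each sample's <image> tokens are filled from that sample's own images rather than from one flat list shared across the batch.
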
SakuraTroyChen commented 2 months ago

Got it. @SakuraTroyChen How do I test this model after merging the LoRA weights?

Load it and test it on your downstream task. It's task-specific. You can follow the evaluation tutorial in https://github.com/haotian-liu/LLaVA

anas-zafar commented 2 months ago

@SakuraTroyChen when I add more than two images in the prompt, I get the following error. Can you please guide me on how to solve this? Thanks

/content/LLaVA/llava/model/llava_arch.py in prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes)
    258                 cur_new_labels.append(cur_labels_noim[i])
    259                 if i < num_images:
--> 260                     cur_image_features = image_features[cur_image_idx]
    261                     cur_image_idx += 1
    262                     cur_new_input_embeds.append(cur_image_features)

IndexError: index 3 is out of bounds for dimension 0 with size 3
lqqyyy commented 1 month ago

@SakuraTroyChen Hello, after modifying only Steps 1 and 2 and fine-tuning the model with LoRA, the loss drops rapidly and reaches 0 within a few steps. I used roughly 20k samples. Did I miss any steps?

SakuraTroyChen commented 1 month ago

@SakuraTroyChen Hello, after modifying only Steps 1 and 2 and fine-tuning the model with LoRA, the loss drops rapidly and reaches 0 within a few steps. I used roughly 20k samples. Did I miss any steps?

Are you using LLaVA's original LoRA fine-tuning script? I didn't run into this; you may want to check hyperparameters such as your learning rate.

anas-zafar commented 1 month ago

Hi @SakuraTroyChen, sorry for pinging you again. I am trying to test the LoRA-merged model and I am stuck on this:

prompt = "USER: <image>\n The diagram represents a factory floor. <image> \n How many actions are depicted in the diagram? ASSISTANT:"

args = type('Args', (), {
    "model_path": fine_tuned_model_path,
    "model_base": None,
    "model_name": get_model_name_from_path(fine_tuned_model_path),
    "query": prompt,
    "conv_mode": None,
    "image_file": "test_6859.png,test_6843.png",
    "sep": ",",
    "temperature": 0,
    "top_p": None,
    "num_beams": 1,
    "max_new_tokens": 512
})()

IndexError                                Traceback (most recent call last)
<ipython-input-5-4a73c0203166> in <cell line: 38>()
     36 
     37 
---> 38 eval_model(args)
     39 # print("args.image_file before eval_model:", args.image_file)
     40 # eval_model(args)

3 frames
/content/LLaVA/llava/model/llava_arch.py in prepare_inputs_labels_for_multimodal(self, input_ids, position_ids, attention_mask, past_key_values, labels, images, image_sizes)
    258                 cur_new_labels.append(cur_labels_noim[i])
    259                 if i < num_images:
--> 260                     cur_image_features = image_features[cur_image_idx]
    261                     cur_image_idx += 1
    262                     cur_new_input_embeds.append(cur_image_features)

IndexError: index 2 is out of bounds for dimension 0 with size 2
SakuraTroyChen commented 1 month ago

@anas-zafar I didn't come across this situation. Maybe you need to use ipdb to print the tensors and debug there.