haotian-liu / LLaVA

[NeurIPS'23 Oral] Visual Instruction Tuning (LLaVA) built towards GPT-4V level capabilities and beyond.
https://llava.hliu.cc
Apache License 2.0

[Usage] Batch inference with Llava 1.5 #709

Open kimihailv opened 1 year ago

kimihailv commented 1 year ago

Describe the issue

Currently, only inference with batch_size=1 is possible. If I understood correctly, the following changes are needed to enable batch inference:

  1. position_ids should be shifted to account for the left padding
  2. The attention mask should be passed in and transformed for the multimodal forward pass

Maybe someone has managed to adapt the code?
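
For what it's worth, here is a minimal, self-contained sketch of those two changes outside of the LLaVA code (left padding, a matching attention mask, and position_ids derived from that mask). It only illustrates the idea and is not the repo's implementation:

import torch

def left_pad_batch(seqs, pad_id):
    """Left-pad a list of token-id lists and build the matching mask/position ids."""
    max_len = max(len(s) for s in seqs)
    input_ids = torch.full((len(seqs), max_len), pad_id, dtype=torch.long)
    attention_mask = torch.zeros((len(seqs), max_len), dtype=torch.long)
    for i, s in enumerate(seqs):
        input_ids[i, max_len - len(s):] = torch.as_tensor(s, dtype=torch.long)
        attention_mask[i, max_len - len(s):] = 1
    # shift positions so the left padding does not consume positions 0, 1, ...
    position_ids = (attention_mask.cumsum(dim=-1) - 1).clamp(min=0)
    return input_ids, attention_mask, position_ids

ids, mask, pos = left_pad_batch([[5, 6, 7], [8, 9]], pad_id=0)
# ids  -> [[5, 6, 7], [0, 8, 9]]
# mask -> [[1, 1, 1], [0, 1, 1]]
# pos  -> [[0, 1, 2], [0, 0, 1]]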

rabiulcste commented 1 year ago

Here's a processor that I wrote to make it work.

from typing import List, Union

import torch
from PIL import Image

from LLaVA.llava.constants import (
    DEFAULT_IM_END_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IMAGE_PATCH_TOKEN,
    DEFAULT_IMAGE_TOKEN,
    IMAGE_TOKEN_INDEX,
)
from LLaVA.llava.conversation import conv_templates
from LLaVA.llava.mm_utils import tokenizer_image_token

class LlaVaProcessor:
    def __init__(self, tokenizer, image_processor, mm_use_im_start_end):
        self.mm_use_im_start_end = mm_use_im_start_end
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.conv_mode = "llava_v1"

    @staticmethod
    def load_demo_images(image_files: Union[List[str], str]):
        if type(image_files) is list:
            out = []
            for image_file in image_files:
                image = Image.open(image_file).convert("RGB")
                out.append(image)
        else:
            out = Image.open(image_files).convert("RGB")
        return out

    # TODO: refactor this, not working
    def get_processed_tokens_demo(self, text: str, image_files: Union[List[str], str], image_token_len: int):
        qs = text  # start from the raw question, then append the image placeholder tokens
        if self.mm_use_im_start_end:
            qs = (
                qs
                + "\n"
                + DEFAULT_IM_START_TOKEN
                + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
                + DEFAULT_IM_END_TOKEN
                + "\n"
                + DEFAULT_IM_START_TOKEN
                + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
                + DEFAULT_IM_END_TOKEN
            )
        else:
            qs = (
                qs
                + "\n"
                + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
                + "\n"
                + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
            )

        conv = conv_templates[self.conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        images = self.load_demo_images(image_files)
        image_tensor = torch.stack(
            [self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0] for image in images]
        )

        input_ids = (
            tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).cuda()
        )

        return image_tensor, input_ids

    def format_text(self, text: str):
        if self.mm_use_im_start_end:
            text = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + "\n" + text
        else:
            text = DEFAULT_IMAGE_TOKEN + "\n" + text

        conv = conv_templates[self.conv_mode].copy()
        conv.append_message(conv.roles[0], text)
        conv.append_message(conv.roles[1], None)
        text = conv.get_prompt()

        return text

    def load_image(self, image_path: str):
        return Image.open(image_path).convert("RGB")

    @staticmethod
    def pad_sequence_to_max_length(sequence, max_length, padding_value=0):
        """Pad a sequence to the desired max length."""
        if len(sequence) >= max_length:
            return sequence
        return torch.cat([torch.full((max_length - len(sequence),), padding_value, dtype=sequence.dtype), sequence])

    def get_processed_tokens(self, text: str, image_path: str):
        prompt = self.format_text(text)
        image = self.load_image(image_path)

        input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0)
        image_tensor = self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]

        return image_tensor, input_ids

    def get_processed_tokens_batch(self, batch_text: List[str], image_paths: List[str]):
        prompts = [self.format_text(text) for text in batch_text]
        images = [self.load_image(image_path) for image_path in image_paths]

        batch_input_ids = [
            tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt") for prompt in prompts
        ]

        # Determine the maximum length of input_ids in the batch
        max_len = max([len(seq) for seq in batch_input_ids])
        # Pad each sequence in input_ids to the max_len
        padded_input_ids = [self.pad_sequence_to_max_length(seq.squeeze(), max_len) for seq in batch_input_ids]
        batch_input_ids = torch.stack(padded_input_ids)

        batch_image_tensor = self.image_processor(images, return_tensors="pt")["pixel_values"]

        return batch_image_tensor, batch_input_ids

You can now run inference:

                from LLaVA.llava.conversation import (SeparatorStyle,
                                                      conv_templates)
                from LLaVA.llava.mm_utils import KeywordsStoppingCriteria

                input_ids = batch["input_ids"].cuda()
                image_tensor = batch["image_tensors"]

                conv = conv_templates[processor.conv_mode].copy()
                stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
                keywords = [stop_str]
                stopping_criteria = (
                    [KeywordsStoppingCriteria(keywords, processor.tokenizer, input_ids)]
                    if conv.version == "v0"
                    else None
                )

                output_ids = model.generate(
                    input_ids,
                    images=image_tensor.half().cuda(),
                    num_beams=self.args.num_beams,
                    max_new_tokens=self.args.max_length,
                    length_penalty=self.args.length_penalty,
                    use_cache=True,
                    stopping_criteria=stopping_criteria,
                    do_sample=self.args.do_sample,
                    temperature=self.args.temperature,
                    num_return_sequences=self.args.num_return_sequences,
                )
                generated_outputs = processor.tokenizer.batch_decode(
                    output_ids[:, input_ids.shape[1] :], skip_special_tokens=True
                )
                generated_outputs = [out.strip() for out in generated_outputs]
                generated_outputs = [
                    out[: -len(stop_str)] if out.endswith(stop_str) else out for out in generated_outputs
                ]
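
For completeness, the processor and batch used above still have to be constructed somewhere. A hypothetical way to wire that up (the checkpoint name, questions, and image paths are placeholders; load_pretrained_model is the loader from the LLaVA repo):

from LLaVA.llava.model.builder import load_pretrained_model

# load tokenizer, model and vision processor from a LLaVA-1.5 checkpoint (placeholder path)
tokenizer, model, image_processor, _ = load_pretrained_model(
    model_path="liuhaotian/llava-v1.5-7b", model_base=None, model_name="llava-v1.5-7b"
)
processor = LlaVaProcessor(tokenizer, image_processor, model.config.mm_use_im_start_end)

# build one batch of (image tensors, left-padded input ids)
image_tensors, input_ids = processor.get_processed_tokens_batch(
    ["What is in this image?", "Describe the scene in one sentence."],
    ["img1.jpg", "img2.jpg"],
)
batch = {"input_ids": input_ids, "image_tensors": image_tensors}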

You can also check my vqa-prompting codebase for full support!

NielsRogge commented 11 months ago

Hi,

Batched inference with LLaVa is supported in Hugging Face Transformers. See here for an example: https://github.com/huggingface/transformers/blob/a49f4acab3c1eea82907e12f82eafbd4673deb39/tests/models/llava/test_modeling_llava.py#L245.

david-vectorflow commented 6 months ago

In case anyone else finds this, here is a sample of working batch inference code based on the link above:

        import base64
        from io import BytesIO

        from PIL import Image

        # prompt template containing the <image> placeholder expected by the processor
        prompt_temp = "<|im_start|>system\nAnswer the questions.<|im_end|><|im_start|>user\n<image>\n{}<|im_end|><|im_start|>assistant\n"

        prompts=[]
        images = []

        for user_question, base64_image in zip(user_questions, images_data):
            prompt = prompt_temp.format(user_question)
            prompts.append(prompt)

            image_data = base64.b64decode(base64_image)
            image = Image.open(BytesIO(image_data))
            images.append(image)

        # Perform batch inference
        inputs = self.processor(prompts, images=images, return_tensors="pt", padding=True).to("cuda:0")
        output = self.model.generate(**inputs, max_new_tokens=4000)

        answer = self.processor.batch_decode(output, skip_special_tokens=True)

g8a9 commented 6 months ago

Hi,

Batched inference with LLaVa is supported in Hugging Face Transformers. See here for an example: https://github.com/huggingface/transformers/blob/a49f4acab3c1eea82907e12f82eafbd4673deb39/tests/models/llava/test_modeling_llava.py#L245.

Hey @NielsRogge, I've stumbled upon this issue today. It seems that the same code does not work for LlavaNextForConditionalGeneration. Is batched inference for LlavaNext models supported in some other way?

For reference, it crashes when trying to stack new_image_features:

File ~/miniconda3/envs/vlm_safety_eval/lib/python3.10/site-packages/transformers/models/llava_next/modeling_llava_next.py:553, in LlavaNextForConditionalGeneration.forward(self, input_ids, pixel_values, image_sizes, attention_mask, position_ids, past_key_values, inputs_embeds, vision_feature_layer, vision_feature_select_strategy, labels, use_cache, output_attentions, output_hidden_states, return_dict)
    551         image_feature = torch.cat((image_feature, self.image_newline[None]), dim=0)
    552     new_image_features.append(image_feature)
--> 553 image_features = torch.stack(new_image_features, dim=0)
    555 inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
    556     image_features, inputs_embeds, input_ids, attention_mask, labels
    557 )
    558 if labels is None:

RuntimeError: stack expects each tensor to be equal size, but got [2144, 4096] at entry 0 and [2340, 4096] at entry 1

NielsRogge commented 6 months ago

Yes, I'm aware of that; this is being addressed in https://github.com/huggingface/transformers/pull/29850

It will be part of the next Transformers release!

hxhcreate commented 2 months ago

        if "llava" in self.model_name.lower() and "hf" in self.model_name.lower():
            questions = [question.replace("<image>", "") for question in questions]
            images = None
            if image_paths:
                images = self.load_images(image_paths)
                conversations = [
                    [
                        {
                            "role": "user",
                            "content": [
                                {"type": "image"},
                                {"type": "text", "text": question},
                            ],
                        },
                    ]
                    for question in questions
                ]
            prompts = [self.processor.apply_chat_template(conv, add_generation_prompt=True) for conv in conversations]
            inputs = self.processor(text=prompts, images=images, padding=True, truncation=True, max_length=self.args.max_length, return_tensors='pt').to(self.device, torch.float16)
            with torch.no_grad():
                generated_ids = self.model.generate(**inputs, **self.generation_config)
                generated_ids_trimmed = [
                    out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
                ]
                outputs = self.processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)
            outputs = [output.strip() for output in outputs]

When I run the above batch inference code, I get the following error, which I don't understand:

RuntimeError: The size of tensor a (2955) must match the size of tensor b (28) at non-singleton dimension 0

NielsRogge commented 2 months ago

Update here: we have now added batch inference to the docs.
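
For reference, batched multi-image generation with LLaVa-NeXT in Transformers looks roughly like the sketch below. This is not the exact docs snippet; the checkpoint name, prompt template, and image paths are assumptions based on the llava-hf/llava-v1.6-mistral-7b-hf model card:

import torch
from PIL import Image
from transformers import LlavaNextForConditionalGeneration, LlavaNextProcessor

model_id = "llava-hf/llava-v1.6-mistral-7b-hf"  # assumed checkpoint
processor = LlavaNextProcessor.from_pretrained(model_id)
processor.tokenizer.padding_side = "left"  # decoder-only models generate with left padding
model = LlavaNextForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map="auto"
)

images = [Image.open("cat.jpg"), Image.open("dog.jpg")]  # placeholder files
prompts = [
    "[INST] <image>\nWhat is shown in this image? [/INST]",
    "[INST] <image>\nDescribe the animal in one sentence. [/INST]",
]

inputs = processor(images=images, text=prompts, padding=True, return_tensors="pt").to(model.device)
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=64)
print(processor.batch_decode(output_ids, skip_special_tokens=True))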

hxhcreate commented 2 months ago

Thanks for your reply. But when I run the sample code here (https://huggingface.co/docs/transformers/main/en/model_doc/llava_next#multi-image-inference), I still get the following error:

[screenshot of the error]

I'm using the latest Transformers, 4.45.dev.

NielsRogge commented 2 months ago

Thanks for flagging. Will ping @zucchini-nlp here.

Note: all doc code snippets of Transformers get tested, but some are excluded due to their size; LLaVa-NeXT is such an example. We could look into how we can also ensure that code snippets of larger models are automatically tested.

zucchini-nlp commented 2 months ago

@hxhcreate @NielsRogge yes, that is a known issue and I merged a fix a few days ago. Unfortunately, the refactoring broke some things; let me know if updating to the latest main solves the issue.

copperwiring commented 1 month ago

Does anyone know how to do this without the Hugging Face model? This is what I have so far:

    for prompt in prompts_batch:
        # Set args.query to the specific prompt in the batch
        args.query = prompt

        # Generate the prompt for each input in the batch, with the correct image handling
        qs = get_prompt(args, model)

        # Create a new conversation template for each prompt in the batch
        conv = conv_templates[args.conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)

        # Add the complete prompt for this instance to the batch
        batched_prompts.append(conv.get_prompt())

    # max length for padding
    max_len = max([len(tokenizer.encode(prompt)) for prompt in batched_prompts])

    tokenizer.padding_side = "left"
    tokenizer.model_max_length = max_len

    # Tokenize the batch of prompts
    tokenized_prompts = [
        tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0)
        for prompt in batched_prompts
    ]

    input_ids = torch.cat(tokenized_prompts, dim=0).cuda()

    # Process images if provided (batch image loading and processing)
    if img_files_batch:
        # For each batch, parse image files, load them, and process
        image_files_batch = [image_parser(img_files, args.sep) for img_files in img_files_batch]
        images = [load_images(image_files) for image_files in image_files_batch]
        flat_images = [item for sublist in images for item in sublist]
        images_tensor = process_images(flat_images, image_processor, model.config).to(model.device, dtype=torch.float16)
        image_sizes = [img.size for img in flat_images]
    else:
        images_tensor = None
        image_sizes = None

    attention_mask = torch.ones_like(input_ids)

    with torch.inference_mode(), torch.cuda.amp.autocast():
        outputs = model.forward(
            input_ids=input_ids, 
            images=None if images_tensor is None else images_tensor,
            image_sizes=image_sizes,
            attention_mask=attention_mask
            )

    logits = outputs.logits[:, -1, :]  # Get the logits for the last token position
    probabilities = F.softmax(logits, dim=-1).squeeze()

However, it doesn't add any padding tokens when I print tokenized_prompts.
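
That is expected: tokenizer_image_token only tokenizes and never pads, and setting tokenizer.model_max_length does not change that, so torch.cat simply requires equal-length prompts. One way to left-pad the tokenized prompts by hand before stacking (a sketch reusing the names from the snippet above; the fallback pad id of 0 matches the padding_value=0 used earlier in this thread):

# left-pad each tokenized prompt to the batch max length and build the matching mask
pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
seqs = [t.squeeze(0) for t in tokenized_prompts]  # each tensor has shape (seq_len,)
max_len = max(seq.shape[0] for seq in seqs)

padded, masks = [], []
for seq in seqs:
    pad_len = max_len - seq.shape[0]
    padded.append(torch.cat([torch.full((pad_len,), pad_id, dtype=seq.dtype), seq]))
    masks.append(torch.cat([torch.zeros(pad_len, dtype=torch.long), torch.ones(seq.shape[0], dtype=torch.long)]))

input_ids = torch.stack(padded).cuda()
attention_mask = torch.stack(masks).cuda()  # pass this instead of torch.ones_like(input_ids)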

KevinXu-01 commented 1 month ago

Here's a processor that I wrote to make it work. […] You can also check my vqa-prompting codebase for full support!

Thank you for sharing. It solved my problem!!!