zjysteven / lmms-finetune

A minimal codebase for finetuning large multimodal models, supporting llava-1.5/1.6, llava-interleave, llava-next-video, llava-onevision, qwen-vl, qwen2-vl, phi3-v etc.
Apache License 2.0

Trying to understand the Data Collator for compute_metrics evaluation #33

Closed: KevinH48264 closed this issue 2 months ago

KevinH48264 commented 2 months ago

During training, I'd like to use HuggingFace's compute_metrics in Trainer to evaluate things such as inference-time "Exact Match" accuracy, but it looks like the data collator only supports generation with teacher forcing for now.
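
(For context, the way I'm wiring compute_metrics into the Trainer is roughly the following; the model / dataset / collator variables are whatever train.py already builds, and include_inputs_for_metrics=True is what populates eval_pred.inputs below.)

from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="./checkpoints",
    evaluation_strategy="steps",
    per_device_eval_batch_size=1,
    include_inputs_for_metrics=True,  # needed so eval_pred.inputs is populated
)

trainer = Trainer(
    model=model,                      # placeholders: built elsewhere in train.py
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,  # defined below
)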

Currently, without evaluating during training, I'd use inference code with add_generation_prompt=True like so:

# Prepare the conversation prompt
plan_conversation = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": instruction},
            {"type": "image"}, {"type": "image"}, {"type": "image"}, {"type": "image"}, {"type": "image"}
        ],
    }
]

# Create the prompt and process the input
prompt = processor.apply_chat_template(plan_conversation, add_generation_prompt=True)

# Run prompt through model and evaluate for exact match
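
(The "run through model" step itself is just the usual processor + generate call; roughly the following, where model, images, and reference are placeholders.)

inputs = processor(images=images, text=prompt, return_tensors="pt").to(model.device)

# greedy decoding; keep only the newly generated tokens before decoding
output_ids = model.generate(**inputs, max_new_tokens=128, do_sample=False)
new_tokens = output_ids[:, inputs["input_ids"].shape[1]:]
pred_str = processor.batch_decode(new_tokens, skip_special_tokens=True)[0]

exact_match = int(pred_str.strip() == reference.strip())  # reference: ground-truth answer string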

And the compute_metrics function I would pass to Trainer is:

from transformers import EvalPrediction

def compute_metrics(eval_pred: EvalPrediction):
    label_ids = eval_pred.label_ids
    pred_ids = eval_pred.predictions[0]
    input_ids = eval_pred.inputs if hasattr(eval_pred, 'inputs') else None

    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    if input_ids is not None:
        input_str = processor.tokenizer.batch_decode(input_ids, skip_special_tokens=True)
    else:
        input_str = [""] * len(pred_str)
    # ... <compute metrics on decoded string> ...
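
The elided metric computation is essentially just an exact-match rate over the decoded strings, along these lines (simplified illustration, not necessarily my final metric):

    exact_matches = [int(p.strip() == l.strip()) for p, l in zip(pred_str, label_str)]
    return {"exact_match": sum(exact_matches) / max(len(exact_matches), 1)}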

The main problem is that pred_str is not the expected output of just the assistant response, as it would be at inference time; instead, pred_ids seems to be of size (10, 7640). The decoded pred (from printing pred_str above) is pasted below.

Effectively, I believe this is because the data collator is also used for evaluation predictions, which are done with teacher forcing: the model conditions on the ground-truth previous tokens instead of its own generated tokens. That makes sense for calculating evaluation loss, but it doesn't reflect inference-time behavior.

My main question: are there any ideas for how I could change the data collator and train.py in the codebase so that evaluation doesn't use teacher forcing and passes the expected pred_ids to compute_metrics? I think if others end up using compute_metrics, they'd want inference-time behavior as well, so maybe the collator could change depending on whether compute_metrics is passed? I'm not too sure right now.
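
One rough, untested idea: keep the current collator for training and only swap in an inference-style collator for evaluation, e.g. by subclassing the Trainer and overriding get_eval_dataloader (the class name and eval_collator argument here are hypothetical, and this ignores details like accelerator preparation):

from torch.utils.data import DataLoader
from transformers import Trainer

class TrainerWithInferenceCollator(Trainer):
    def __init__(self, *args, eval_collator=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.eval_collator = eval_collator  # hypothetical inference-time collator

    def get_eval_dataloader(self, eval_dataset=None):
        # fall back to the default (teacher-forcing) collator if none is given
        if self.eval_collator is None:
            return super().get_eval_dataloader(eval_dataset)
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        return DataLoader(
            eval_dataset,
            batch_size=self.args.eval_batch_size,
            collate_fn=self.eval_collator,
        )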

Thanks again for the great work!

Printed pred:

[screenshot of the decoded pred_str output]

Label: <something short like a caption for the image, which is what I expected pred_str to be>

zjysteven commented 2 months ago

At a high level, I imagine this would need an inference-dedicated collator that is passed to the eval data loader, since, as you mentioned, the current collator uses ground-truth answers and doesn't encode any inference-time behavior.

I won't be available to work on this until late next week, but I'd be happy to discuss and provide the necessary instructions.

KevinH48264 commented 2 months ago

Ah okay. Looking at an inference-dedicated collator for LLaVAInterleave, are these the major things I should edit? Or are there other changes needed so that I actually get inference-time generation / predicted tokens? (Changes to the original LLaVAInterleaveDataCollator are commented with "EDIT".)

import re
from typing import Dict, List, Sequence, Union

import PIL.Image
import torch

from .base import BaseDataCollator  # assuming this lives alongside the repo's other collators


class LLaVAInterleaveInferenceDataCollator(BaseDataCollator):
    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        # images
        vision_inputs = dict()
        flattened_images: List[PIL.Image.Image] = [x for instance in instances for x in instance["images"]]
        if len(flattened_images) > 0:
            vision_inputs.update(**self.processor.image_processor(flattened_images, return_tensors="pt"))
        num_images: List[int] = [len(instance["images"]) for instance in instances]

        # texts
        # the dataset implementation assumes conversations are [user, assistant, user, assistant, ...]
        system_prompts: List[Union[str, None]] = [instance["system_prompt"] for instance in instances]
        conversations: List[List] = [instance["conversations"] for instance in instances]
        max_len = self.tokenizer.model_max_length

        total_image_tokens = 0
        input_ids = []
        labels = []

        for cur_num_images, system_prompt, cur_convs in zip(num_images, system_prompts, conversations):
            cur_num_image_tokens = 0
            cur_input_ids = []
            cur_labels = []

            cur_text = []
            if system_prompt is not None:
                cur_text.append({
                    "role": "system",
                    "content": [{"type": "text", "text": system_prompt}]
                })

            for i, text in enumerate(cur_convs):
                if i % 2 == 0:
                    num_image_tokens = len([m.start() for m in re.finditer("<image>", text)])
                    cur_num_image_tokens += num_image_tokens
                    total_image_tokens += num_image_tokens

                    # .strip(): whitespaces and newlines are handled by chat_template
                    text = text.replace("<image>", "").strip()

                    cur_text.append({
                        "role": "user",
                        "content": [{"type": "text", "text": text}] + \
                            [{"type": "image"}] * num_image_tokens
                    })
                elif i != len(cur_convs) - 1: # EDIT - do not include last assistant message to allow for generation
                    cur_text.append({
                        "role": "assistant",
                        "content": [{"type": "text", "text": text}]
                    })

            assert cur_num_image_tokens == cur_num_images, "Not all images were used"

            temp = self.processor.apply_chat_template(
                cur_text,
                add_generation_prompt=True, # EDIT - allow for generation
                tokenize=True,
                # EDIT - the mask is still needed below when mask_question_tokens is set
                return_assistant_tokens_mask=self.mask_question_tokens,
                return_dict=True,
                return_tensors="pt",
                truncation=False # the assistant tokens mask seems wrong when truncation is enabled
            )
            cur_input_ids = temp["input_ids"]
            # manual truncation
            if cur_input_ids.shape[1] > max_len:
                cur_input_ids = cur_input_ids[:, :max_len]
            cur_labels = cur_input_ids.clone()

            if self.mask_question_tokens:
                cur_assistant_masks = torch.tensor(temp["assistant_masks"][:max_len], dtype=torch.bool).unsqueeze(0)

                # EDIT: only mask last assistant's response tokens
                assistant_token_indices = torch.where(cur_assistant_masks[0] == 1)[0]
                if len(assistant_token_indices) > 0:
                    last_message_end_idx = assistant_token_indices[-1]
                    last_message_start_idx = last_message_end_idx
                    for idx in reversed(assistant_token_indices[:-1]):
                        if idx == last_message_start_idx - 1:
                            last_message_start_idx = idx
                        else:
                            break

                    # EDIT: Create a new mask for just the last assistant message
                    last_message_mask = torch.zeros_like(cur_assistant_masks)
                    last_message_mask[0, last_message_start_idx:last_message_end_idx + 1] = 1
                    cur_assistant_masks = last_message_mask

                    assert cur_labels.shape == cur_assistant_masks.shape, "Label and mask shapes do not match"
                    cur_labels = torch.where(cur_assistant_masks, cur_labels, self.IGNORE_TOKEN_ID)

            assert cur_input_ids.shape == cur_labels.shape, "Input and label shapes do not match"

            # padding
            if cur_input_ids.shape[1] < max_len:
                cur_input_ids = torch.cat([
                    cur_input_ids,
                    torch.full(
                        (cur_input_ids.shape[0], max_len - cur_input_ids.shape[1]),
                        self.PAD_TOKEN_ID,
                        dtype=cur_input_ids.dtype,
                        device=cur_input_ids.device
                    )
                ], dim=1)
                cur_labels = torch.cat([
                    cur_labels,
                    torch.full(
                        (cur_labels.shape[0], max_len - cur_labels.shape[1]),
                        self.IGNORE_TOKEN_ID,
                        dtype=cur_labels.dtype,
                        device=cur_labels.device
                    )
                ], dim=1)

            input_ids.append(cur_input_ids)
            labels.append(cur_labels)

        # sanity check
        assert total_image_tokens == len(flattened_images), "Number of image tokens does not match the number of images"

        input_ids = torch.cat(input_ids)
        labels = torch.cat(labels)

        return dict(
            **vision_inputs,
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.PAD_TOKEN_ID),
        )
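
For a quick sanity check, I'd then feed a collated batch straight into generate, roughly like this (eval_model and eval_instances are placeholders for the loaded model and a small list of dataset items):

batch = collator(eval_instances)  # collator: an instance of the class above

# drop labels so generate only sees the prompt-side inputs
gen_inputs = {k: v.to(eval_model.device) for k, v in batch.items() if k != "labels"}
output_ids = eval_model.generate(**gen_inputs, max_new_tokens=64)

# keep only the newly generated tokens before decoding
new_tokens = output_ids[:, batch["input_ids"].shape[1]:]
print(processor.batch_decode(new_tokens, skip_special_tokens=True))
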
zjysteven commented 2 months ago

Yes, it looks good to me.