huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

Why are instructions not masked when performing VSFT for LLaVa? #1880

Closed shijian2001 closed 1 month ago

shijian2001 commented 3 months ago

I have some questions about the LLavaDataCollator in vsft_llava.py:

https://github.com/huggingface/trl/blob/main/examples/scripts/vsft_llava.py

class LLavaDataCollator:
    def __init__(self, processor):
        self.processor = processor

    def __call__(self, examples):
        texts = []
        images = []
        for example in examples:
            if len(example["images"]) > 1:
                raise ValueError("This collator only supports one image per example")
            messages = example["messages"]
            text = self.processor.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=False
            )
            texts.append(text)
            images.append(example["images"][0])

        batch = self.processor(texts, images, return_tensors="pt", padding=True)

        labels = batch["input_ids"].clone()
        if self.processor.tokenizer.pad_token_id is not None:
            labels[labels == self.processor.tokenizer.pad_token_id] = -100
        batch["labels"] = labels

        return batch

I noticed that you copy the input_ids (image tokens, question, and answer concatenated) into the labels and only set the labels of pad tokens to -100 (so no loss is computed on them). However, as far as I understand SFT, the loss should be computed only on the answer part, which means the labels of all question tokens should also be set to -100?
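
To illustrate with a toy example (the token ids and pad id below are made up):

import torch

pad_id = 32001                                   # hypothetical pad token id
# [<s>, USER, ...question..., ASSISTANT:, answer, </s>, <pad>]
input_ids = torch.tensor([[1, 3148, 1001, 9047, 13566, 264, 2, pad_id]])
labels = input_ids.clone()
labels[labels == pad_id] = -100                  # what the current collator does
# additionally proposed: also mask the prompt span (indices 0..4 here)
labels[0, :5] = -100                             # loss now covers only the answer and </s>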

Looking forward to your reply!

shijian2001 commented 3 months ago

@qgallouedec When performing SFT on a VLM, it may be better to compute the loss only on the response part. Does trl provide a direct implementation for this? Could you give an example? Thanks!

shijian2001 commented 3 months ago

@qgallouedec Sorry to bother you. I would like to ask whether SFTTrainer can directly compute the loss on only the response part, and whether you have plans to implement a VSFT script that computes only the response loss. Thank you!

qgallouedec commented 3 months ago

Hi, sorry for the delay, I'm addressing the issues in order, and there are a lot these days.

the loss should be computed only on the answer part

Can you justify this? In general, loss is calculated over the entire text input, including the prompt and the answer.

When performing SFT on a VLM, it may be better to compute the loss only on the response part.

I'm not sure about this. Have you tried it? It would be good to have some results to confirm or refute this statement.

Does trl provide a direct implementation for this?

A small modification of the data collator should be enough: just set the labels of the prompt part to -100.
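
A minimal sketch of such a modification (assuming the LLaVA-1.5 chat template, where each answer follows the literal marker "ASSISTANT:"; multi-turn conversations would need per-round handling):

import torch

def mask_prompt_tokens(labels: torch.Tensor, tokenizer) -> torch.Tensor:
    # Ids of the marker that precedes the answer (template-specific assumption)
    marker = torch.tensor(tokenizer.encode("ASSISTANT:", add_special_tokens=False))
    for i in range(labels.size(0)):
        for k in range(labels.size(1) - len(marker) + 1):
            if torch.equal(labels[i, k : k + len(marker)], marker):
                # Ignore everything up to and including the marker,
                # so the loss covers only the answer tokens
                labels[i, : k + len(marker)] = -100
                break
    return labels

In the collator above, this would run on the labels right after the clone, alongside the pad-token masking. For text-only SFT, trl's DataCollatorForCompletionOnlyLM already implements the same idea via a response_template.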

shijian2001 commented 3 months ago

In the implementation of the LLaVA repository, both the padding tokens and the instruction tokens are set to -100. For reference, see the preprocess_v1 function in https://github.com/haotian-liu/LLaVA/blob/main/llava/train/train.py:

if has_image:
    # tokenizer_image_token accounts for the special <image> placeholder
    round_len = len(tokenizer_image_token(rou, tokenizer))
    instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
else:
    round_len = len(tokenizer(rou).input_ids)
    instruction_len = len(tokenizer(parts[0]).input_ids) - 2

# Offset correction for the tokenizers >= 0.14 behavior change
if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14:
    round_len -= 1
    instruction_len -= 1

# Mask the instruction span of each round so only the answer contributes to the loss
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

shijian2001 commented 3 months ago

I roughly implemented this idea as follows:

import torch

class DataCollator:
    def __init__(self, processor, enable_mask_instructions: bool=False):
        self.processor = processor
        self.processor.tokenizer.model_max_length = 2048
        self.enable_mask_instructions = enable_mask_instructions
        self.IGNORE_INDEX = -100

    def _mask_padding_tokens(self, labels: torch.Tensor):
        """Only mask padding tokens"""
        pad_token_id = self.processor.tokenizer.pad_token_id
        labels[labels == pad_token_id] = self.IGNORE_INDEX
        return labels

    def _prepare_vsft_labels(self, labels: torch.Tensor):
        """Mask instructions and padding tokens"""

        # [Note] EOS token and assistant_token may be different for different chat_templates
        eos_token_id = self.processor.tokenizer.convert_tokens_to_ids("</s>")
        assistant_token_id = self.processor.tokenizer.encode("ASSISTANT:", add_special_tokens=False)

        batch_size, _ = labels.shape

        for i in range(batch_size):

            # Get positions of all eos tokens
            eos_positions = (labels[i] == eos_token_id).nonzero(as_tuple=True)[0]
            # Add 0 to eos_positions; Helpful for following loop
            eos_positions = torch.cat([torch.tensor([0], device=labels.device), eos_positions])

            # Consider the first special token <s>
            cur_len = 1
            labels[i, :cur_len] = self.IGNORE_INDEX

            for j in range(len(eos_positions) - 1):
                start = eos_positions[j]
                end = eos_positions[j + 1]

                # Search this round for the "ASSISTANT:" marker
                assistant_pos = None
                for k in range(start, end - len(assistant_token_id) + 1):
                    if torch.equal(labels[i, k:k + len(assistant_token_id)], torch.tensor(assistant_token_id, device=labels.device)):
                        assistant_pos = k
                        break

                if assistant_pos is not None:
                    # Mask everything up to and including "ASSISTANT:" so that
                    # only the answer tokens of this round keep their labels
                    labels[i, cur_len:assistant_pos + len(assistant_token_id)] = self.IGNORE_INDEX
                # Advance past this round's EOS even when no marker is found,
                # so a markerless round cannot cause the next answer to be masked
                cur_len = end + 1

        masked_labels = self._mask_padding_tokens(labels)

        return masked_labels

    def __call__(self, examples):
        texts = []
        images = []
        for example in examples:
            image = example["images"][0]
            messages = example["messages"]
            text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
            texts.append(text.strip())
            images.append(image)

        batch = self.processor(text=texts, images=images, return_tensors="pt", truncation=True, padding=True)  # truncate to model_max_length

        labels = batch["input_ids"].clone()
        if self.enable_mask_instructions:
            # Mask instructions and padding tokens
            mask_labels = self._prepare_vsft_labels(labels)
        else:
            # Only mask padding tokens
            mask_labels = self._mask_padding_tokens(labels)

        batch["labels"] = mask_labels

        return batch
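
A hypothetical usage, following the vsft_llava.py setup (the checkpoint name is an assumption):

from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
collator = DataCollator(processor, enable_mask_instructions=True)
# then pass data_collator=collator to SFTTrainer, as in vsft_llava.py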

qgallouedec commented 3 months ago

Thanks for the reference and for the piece of code, which can certainly be useful. My position is to keep the SFT example for VLM as it is (i.e., without masking the instructions). If at some point we manage to show that, in the general case, instruction masking gives faster convergence or better results, then we'll modify the example along those lines. Feel free to add to this conversation if you find interesting results.

qgallouedec commented 1 month ago

Closing as this conversation has not received an update recently.