Closed shijian2001 closed 1 month ago
I have some questions about the `LLavaDataCollator` in vsft_llava.py: https://github.com/huggingface/trl/blob/main/examples/scripts/vsft_llava.py

I noticed that you copy the input_ids (the image and question concatenated with the answer) to the labels, and then only set the labels of the pad tokens to -100 (so no loss is calculated on them). However, as far as I understand SFT, only the loss of the answer part should be calculated, which means that we should also set the labels of all the question parts to -100?

Looking forward to your reply!
@qgallouedec When performing SFT on a VLM, it may be better to compute the loss only on the response part. Does trl provide a direct implementation of this? Could you give an example? Thanks!
@qgallouedec Sorry to bother you, but I would like to ask whether SFTTrainer can directly compute the loss on only the response part, and whether you have plans to implement a vsft script that computes only the response loss. Thank you!
Hi, sorry for the delay; I'm addressing the issues in order, and there are a lot of them these days.
> only the loss of the answer part should be calculated

Can you justify this? In general, the loss is calculated over the entire text input, including both the prompt and the answer.
> When performing SFT on a VLM, it may be better to compute the loss only on the response part.

I'm not sure about this. Have you tried it? It would be good to have some results that confirm or refute this claim.
> Does trl provide a direct implementation of this?

A small modification of the data collator should be enough: just set the labels of the prompt part to -100.
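As a rough sketch of what that modification could look like (purely illustrative; it assumes you already know the tokenized prompt length for each example):

```python
import torch

# Illustrative sketch: given the tokenized prompt+response sequence and the
# length of the tokenized prompt, mask the prompt so that only the response
# contributes to the loss (-100 is ignored by PyTorch's cross-entropy).
def mask_prompt(input_ids: torch.Tensor, prompt_len: int) -> torch.Tensor:
    labels = input_ids.clone()
    labels[:prompt_len] = -100
    return labels
```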
In the LLaVA repository's implementation, the padding tokens and the instruction tokens are all set to -100. For reference, see the `preprocess_v1` function in https://github.com/haotian-liu/LLaVA/blob/main/llava/train/train.py:
```python
if has_image:
    round_len = len(tokenizer_image_token(rou, tokenizer))
    instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
else:
    round_len = len(tokenizer(rou).input_ids)
    instruction_len = len(tokenizer(parts[0]).input_ids) - 2

if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14:
    round_len -= 1
    instruction_len -= 1

target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
```
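For context, `-100` works as a mask because it is the default `ignore_index` of PyTorch's cross-entropy loss, so those positions are simply skipped:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(5, 32000)                 # 5 token positions, vocab size 32000
labels = torch.tensor([-100, -100, 7, 42, 9])  # the first two positions are masked
# Positions labeled -100 are excluded from the loss average entirely.
loss = F.cross_entropy(logits, labels, ignore_index=-100)
```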
I roughly implemented this idea as follows:
```python
import torch


class DataCollator:
    def __init__(self, processor, enable_mask_instructions: bool = False):
        self.processor = processor
        self.processor.tokenizer.model_max_length = 2048
        self.enable_mask_instructions = enable_mask_instructions
        self.IGNORE_INDEX = -100

    def _mask_padding_tokens(self, labels: torch.Tensor):
        """Only mask padding tokens"""
        pad_token_id = self.processor.tokenizer.pad_token_id
        labels[labels == pad_token_id] = self.IGNORE_INDEX
        return labels

    def _prepare_vsft_labels(self, labels: torch.Tensor):
        """Mask instruction and padding tokens"""
        # [Note] The EOS token and assistant token may differ between chat templates
        eos_token_id = self.processor.tokenizer.convert_tokens_to_ids("</s>")
        assistant_token_ids = self.processor.tokenizer.encode("ASSISTANT:", add_special_tokens=False)
        batch_size, _ = labels.shape
        for i in range(batch_size):
            # Get the positions of all EOS tokens (one per conversation round)
            eos_positions = (labels[i] == eos_token_id).nonzero(as_tuple=True)[0]
            # Prepend 0 to eos_positions; helpful for the following loop
            eos_positions = torch.cat([torch.tensor([0], device=labels.device), eos_positions])
            # Account for the first special token <s>
            cur_len = 1
            labels[i, :cur_len] = self.IGNORE_INDEX
            for j in range(len(eos_positions) - 1):
                start = eos_positions[j].item()
                end = eos_positions[j + 1].item()
                # Scan this round for the "ASSISTANT:" marker
                assistant_pos = None
                for k in range(start, end - len(assistant_token_ids) + 1):
                    if torch.equal(
                        labels[i, k : k + len(assistant_token_ids)],
                        torch.tensor(assistant_token_ids, device=labels.device),
                    ):
                        assistant_pos = k
                        break
                if assistant_pos is not None:
                    # Mask everything from the start of the round up to and
                    # including the "ASSISTANT:" marker
                    labels[i, cur_len : assistant_pos + len(assistant_token_ids)] = self.IGNORE_INDEX
                cur_len = end + 1
        masked_labels = self._mask_padding_tokens(labels)
        return masked_labels

    def __call__(self, examples):
        texts = []
        images = []
        for example in examples:
            image = example["images"][0]
            messages = example["messages"]
            text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
            texts.append(text.strip())
            images.append(image)
        # Truncate to model_max_length and pad to the longest sequence in the batch
        batch = self.processor(text=texts, images=images, return_tensors="pt", truncation=True, padding=True)
        labels = batch["input_ids"].clone()
        if self.enable_mask_instructions:
            # Mask instruction and padding tokens
            mask_labels = self._prepare_vsft_labels(labels)
        else:
            # Only mask padding tokens
            mask_labels = self._mask_padding_tokens(labels)
        batch["labels"] = mask_labels
        return batch
```
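To sanity-check the masking, something like this can be used (the checkpoint name and the `examples` list here are just placeholders):

```python
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
collator = DataCollator(processor, enable_mask_instructions=True)

# `examples` is a list of dataset rows with "images" and "messages" keys.
batch = collator(examples)
print((batch["labels"] == -100).float().mean())  # fraction of loss-masked tokens
```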
Thanks for the reference and for the piece of code, which can certainly be useful. My position is to keep the SFT example for VLMs as it is (i.e., not mask the instructions). If at some point we manage to show that, in the general case, instruction masking gives faster convergence or better results, then we'll modify the example along those lines. Feel free to feed this conversation if you find interesting results.
Closing as this conversation has not received an update recently.