Closed: KevinH48264 closed this issue 2 months ago.
On a high level, I would imagine this would need an inference-dedicated collator that gets passed to the eval data loader, since, as you mentioned, the current collator uses ground-truth answers and does not encode any inference-time behavior.
I won't be available to work on this until late next week, but I'd be happy to discuss it and provide the necessary instructions.
Ah okay, looking at an inference-dedicated collator for LLaVAInterleave, are these the major things I should edit? Or are there other changes needed to get actual inference-time generation / predicted tokens? (Changes to the original LLaVAInterleaveDataCollator are commented with "EDIT".)
import re
from typing import Dict, List, Sequence, Union

import PIL.Image
import torch

# BaseDataCollator is the repo's existing base collator (it provides self.processor,
# self.tokenizer, self.mask_question_tokens, self.PAD_TOKEN_ID, self.IGNORE_TOKEN_ID).
class LLaVAInterleaveInferenceDataCollator(BaseDataCollator):
    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        # images
        vision_inputs = dict()
        flattened_images: List[PIL.Image.Image] = [x for instance in instances for x in instance["images"]]
        if len(flattened_images) > 0:
            vision_inputs.update(**self.processor.image_processor(flattened_images, return_tensors="pt"))
        num_images: List[int] = [len(instance["images"]) for instance in instances]

        # texts
        # the dataset implementation assumes conversations are [user, assistant, user, assistant, ...]
        system_prompts: List[Union[str, None]] = [instance["system_prompt"] for instance in instances]
        conversations: List[List] = [instance["conversations"] for instance in instances]
        max_len = self.tokenizer.model_max_length

        total_image_tokens = 0
        input_ids = []
        labels = []

        for cur_num_images, system_prompt, cur_convs in zip(num_images, system_prompts, conversations):
            cur_num_image_tokens = 0
            cur_input_ids = []
            cur_labels = []

            cur_text = []
            if system_prompt is not None:
                cur_text.append({
                    "role": "system",
                    "content": [{"type": "text", "text": system_prompt}]
                })

            for i, text in enumerate(cur_convs):
                if i % 2 == 0:
                    num_image_tokens = len([m.start() for m in re.finditer("<image>", text)])
                    cur_num_image_tokens += num_image_tokens
                    total_image_tokens += num_image_tokens

                    # .strip(): whitespaces and newlines are handled by chat_template
                    text = text.replace("<image>", "").strip()

                    cur_text.append({
                        "role": "user",
                        "content": [{"type": "text", "text": text}]
                            + [{"type": "image"}] * num_image_tokens
                    })
                elif i != len(cur_convs) - 1:  # EDIT - do not include the last assistant message, to allow for generation
                    cur_text.append({
                        "role": "assistant",
                        "content": [{"type": "text", "text": text}]
                    })

            assert cur_num_image_tokens == cur_num_images, "Not all images were used"

            temp = self.processor.apply_chat_template(
                cur_text,
                add_generation_prompt=True,  # EDIT - allow for generation
                tokenize=True,
                return_assistant_tokens_mask=False,  # EDIT - do not mask assistant tokens
                return_dict=True,
                return_tensors="pt",
                truncation=False  # the assistant tokens mask seems wrong when truncation is enabled
            )
            cur_input_ids = temp["input_ids"]

            # manual truncation
            if cur_input_ids.shape[1] > max_len:
                cur_input_ids = cur_input_ids[:, :max_len]

            cur_labels = cur_input_ids.clone()
            if self.mask_question_tokens:
                cur_assistant_masks = torch.tensor(temp["assistant_masks"][:max_len], dtype=torch.bool).unsqueeze(0)

                # EDIT: keep labels only for the last assistant message's response tokens
                assistant_token_indices = torch.where(cur_assistant_masks[0] == 1)[0]
                if len(assistant_token_indices) > 0:
                    last_message_end_idx = assistant_token_indices[-1]
                    last_message_start_idx = last_message_end_idx
                    for idx in reversed(assistant_token_indices[:-1]):
                        if idx == last_message_start_idx - 1:
                            last_message_start_idx = idx
                        else:
                            break

                    # EDIT: create a new mask covering just the last assistant message
                    last_message_mask = torch.zeros_like(cur_assistant_masks)
                    last_message_mask[0, last_message_start_idx:last_message_end_idx + 1] = 1
                    cur_assistant_masks = last_message_mask

                assert cur_labels.shape == cur_assistant_masks.shape, "Label and mask shapes do not match"
                cur_labels = torch.where(cur_assistant_masks, cur_labels, self.IGNORE_TOKEN_ID)

            assert cur_input_ids.shape == cur_labels.shape, "Input and label shapes do not match"

            # padding
            if cur_input_ids.shape[1] < max_len:
                cur_input_ids = torch.cat([
                    cur_input_ids,
                    torch.full(
                        (cur_input_ids.shape[0], max_len - cur_input_ids.shape[1]),
                        self.PAD_TOKEN_ID,
                        dtype=cur_input_ids.dtype,
                        device=cur_input_ids.device
                    )
                ], dim=1)
                cur_labels = torch.cat([
                    cur_labels,
                    torch.full(
                        (cur_labels.shape[0], max_len - cur_labels.shape[1]),
                        self.IGNORE_TOKEN_ID,
                        dtype=cur_labels.dtype,
                        device=cur_labels.device
                    )
                ], dim=1)

            input_ids.append(cur_input_ids)
            labels.append(cur_labels)

        # sanity check
        assert total_image_tokens == len(flattened_images), "Number of image tokens does not match the number of images"

        input_ids = torch.cat(input_ids)
        labels = torch.cat(labels)

        return dict(
            **vision_inputs,
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.PAD_TOKEN_ID),
        )
Yes it looks good to me.
During training, I'd like to use Hugging Face's compute_metrics in Trainer to evaluate things such as inference-time "Exact Match" accuracy, but it looks like the data collator only supports generation with teacher forcing for now.
Currently, without evaluating during training, I'd use the inference code with add_generation_prompt=True, like so:
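Roughly, a minimal sketch of what I mean (assuming a standard Hugging Face LLaVA-style model and processor are already loaded; the variable names here are placeholders rather than my actual script):

import torch
from PIL import Image

image = Image.open("example.jpg")  # placeholder image
conversation = [
    {"role": "user", "content": [{"type": "text", "text": "Describe the image."},
                                 {"type": "image"}]},
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
inputs = processor(images=[image], text=prompt, return_tensors="pt").to(model.device)

with torch.inference_mode():
    output_ids = model.generate(**inputs, max_new_tokens=64)

# decode only the newly generated tokens, i.e. just the assistant response
pred_str = processor.decode(
    output_ids[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True
)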
And the compute_metrics function I'd pass to Trainer is:
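In sketch form (the exact code matters less than the shape: decode pred_ids and label_ids with the tokenizer, then compare; tokenizer is assumed to be in scope and -100 is the ignore index):

import numpy as np

def compute_metrics(eval_pred):
    pred_ids, label_ids = eval_pred.predictions, eval_pred.label_ids

    # if predictions are raw logits, reduce them to token ids first
    if pred_ids.ndim == 3:
        pred_ids = np.argmax(pred_ids, axis=-1)

    # replace ignored label positions with the pad token before decoding
    label_ids = np.where(label_ids == -100, tokenizer.pad_token_id, label_ids)

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    exact_match = np.mean([p.strip() == l.strip() for p, l in zip(pred_str, label_str)])
    return {"exact_match": exact_match}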
The main problem is that pred_str is not the expected output of just the assistant response, as it would be at inference time; instead, pred_ids comes out with shape (10, 7640). The decoded pred (from printing pred_str above) is pasted below.
Effectively, I believe this is because the data collator is also used for evaluation predictions with teacher forcing, conditioning on the ground-truth previous tokens instead of the generated ones. That makes sense for computing the evaluation loss, but it doesn't reflect inference-time behavior.
My main question: are there any ideas for how I could change the data collator and train.py in this codebase so that evaluation doesn't use teacher forcing and the expected pred_ids get passed to compute_metrics? I think anyone else who uses compute_metrics would also want inference-time behavior, so maybe the collator could switch depending on whether compute_metrics is passed? I'm not too sure right now.
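For example, maybe something along these lines (just a sketch; the class and argument names are made up, and this only swaps the batch format for evaluation):

from torch.utils.data import DataLoader
from transformers import Trainer

class InferenceEvalTrainer(Trainer):
    """Hypothetical Trainer that uses a separate collator for evaluation only."""

    def __init__(self, *args, eval_collator=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.eval_collator = eval_collator

    def get_eval_dataloader(self, eval_dataset=None):
        # fall back to the default (teacher-forcing) collator if no
        # inference collator was provided
        if self.eval_collator is None:
            return super().get_eval_dataloader(eval_dataset)
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        return DataLoader(
            eval_dataset,
            batch_size=self.args.eval_batch_size,
            collate_fn=self.eval_collator,
            num_workers=self.args.dataloader_num_workers,
        )

# in train.py (hypothetical wiring): only build the inference collator
# when compute_metrics is actually passed
# eval_collator = LLaVAInterleaveInferenceDataCollator(...) if compute_metrics is not None else None
# trainer = InferenceEvalTrainer(..., eval_collator=eval_collator, compute_metrics=compute_metrics)

Getting actually generated tokens into pred_ids would probably still need something like Seq2SeqTrainer's predict_with_generate, or an overridden prediction_step that calls model.generate, on top of the collator swap, but I'm not sure what fits the codebase best.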
Thanks again for the great work!
Printed pred:
Label: <something short like a caption for the image, which is what I expected pred_str to be>