NielsRogge / Transformers-Tutorials

This repository contains demos I made with the Transformers library by HuggingFace.
MIT License

Fine-tune GIT on a VQA dataset (TextVQA) #287

Closed: marine-tk closed this issue 1 year ago

marine-tk commented 1 year ago

Hi!

I know there is already a fine-tuned version of GIT on TextVQA on HuggingFace, but I am trying to fine-tune the GIT model for a VQA task on TextVQA myself (and later on VQAv2) to understand how the model works and exactly what inputs it expects.

I have read the paper introducing GIT, and I understand that the model on HuggingFace might differ, since its model card was not released by the paper's authors. Still, I have some questions about the HuggingFace model, because I haven't succeeded in fine-tuning GIT on TextVQA:

Regarding how the paper describes VQA fine-tuning: does it mean that during training I should pass the tokenized question+answer concatenation as input_ids and only the answer as labels, or is there something else to do so that the LM loss is only applied on the answer and the [EOS] tokens? I was not sure how to interpret the paper's sentence, and none of the inputs I tried worked:

The loss decreases in both cases, but I'm not sure I'm giving the model the right inputs, since I can't obtain correct answers the way microsoft/git-base-textvqa does.

To add a bit more context: for training and preprocessing the data, I took inspiration from your tutorial on fine-tuning GIT for image captioning and tried to adapt it to a VQA task (thank you for all the tutorials!!!). Here is a snippet of my code:

Pre-processing

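import torch
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM, AutoProcessor

processor = AutoProcessor.from_pretrained("microsoft/git-base")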

class TextVQADataset(torch.utils.data.Dataset):

    def __init__(self, question, text, images, processor, answer):
        self.images = images
        self.text = text        ### text is the q+a concatenated
        self.processor = processor
        self.question = question
        self.answer = answer

    def __len__(self):
        return len(self.text)

    def __getitem__(self, idx):
        text = self.text[idx]
        image = self.images[idx] 
        question = self.question[idx]
        answer = self.answer[idx]

        encoded_question = self.processor(text=question, padding="max_length", truncation=True, max_length=66, return_tensors="pt")
        encoded_image = self.processor(images=image, return_tensors="pt")
        encoded_text = self.processor(text=text, padding="max_length", truncation=True, max_length=66, return_tensors="pt")
        encoded_answer = self.processor(text=answer, padding="max_length", truncation=True, max_length=66, return_tensors="pt")

        # Remove batch dimension
        for k,v in encoded_question.items():
          encoded_question[k] = v.squeeze()
        for k,v in encoded_image.items():
          encoded_image[k] = v.squeeze()
        for k,v in encoded_text.items():
          encoded_text[k] = v.squeeze()
        for k,v in encoded_answer.items():
          encoded_answer[k] = v.squeeze()

        encoded_question["pixel_values"] = encoded_image["pixel_values"]
        encoded_question["input_ids_labels"] = encoded_text["input_ids"] 
        encoded_question["input_ids_answer"] = encoded_answer["input_ids"]
        encoded_question["attention_mask"] = encoded_text["attention_mask"]

        return encoded_question

# train_ds / validation_ds (with "question", "text", "image" and "answers" columns) and batch_size are assumed to be defined earlier
train_dataset = TextVQADataset(question=train_ds["question"], text=train_ds["text"], images=train_ds["image"], processor=processor, answer=train_ds["answers"])
validation_dataset = TextVQADataset(question=validation_ds["question"], text=validation_ds["text"], images=validation_ds["image"], processor=processor, answer=validation_ds["answers"])

train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
val_dataloader = DataLoader(validation_dataset, shuffle=False, batch_size=batch_size)  # no need to shuffle for evaluation

Model

device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained("microsoft/git-base")
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
model.to(device)

num_epochs=15
num_training_steps = num_epochs * len(train_dataloader)
progress_bar = tqdm(range(num_training_steps))

model.train()
for epoch in range(1,num_epochs+1):
  for idx, batch in enumerate(train_dataloader):
    input_ids_labels = batch.pop("input_ids_labels").to(device)  ### it's the tokenized and concatenated q+a
    pixel_values = batch.pop("pixel_values").to(device)
    attention_mask = batch.pop("attention_mask").to(device)  ### attention_mask associated with input_ids_labels
    input_ids_answer = batch.pop("input_ids_answer").to(device) ### encoded answers

    outputs = model(input_ids=input_ids_labels,
                    pixel_values=pixel_values,
                    attention_mask=attention_mask,  # was `answer_mask`, which is undefined
                    labels=input_ids_answer)

    loss = outputs.loss
    stats = [epoch, num_epochs, loss.item()]  # renamed to avoid shadowing the built-in `list`

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # Progress bar
    progress_bar.update(1)
    progress_bar.set_description("Epoch [{}/{}] Loss: {:.4f}".format(*stats))

Thank you in advance and have a good day!

NielsRogge commented 1 year ago

Hi,

Thanks for your interest in GIT :)

As the GIT paper puts it: "For visual question answering, the question and the ground-truth answer are concatenated as a new special caption during the fine-tuning, but the LM loss is only applied on the answer and the [EOS] tokens."

This is a common trick when fine-tuning Transformer decoders. It can be achieved by setting the labels to -100 for every token on which you don't want to incur a loss (-100 is the default ignore_index of PyTorch's cross-entropy loss).
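A minimal sketch of why -100 works, using toy logits rather than GIT itself: `torch.nn.functional.cross_entropy` skips every target equal to its `ignore_index` (default -100), and with the default `reduction="mean"` it averages over the remaining positions only. Masking the question positions therefore gives the same loss as computing it on the answer positions alone. The tensor sizes and the 3-token "question" are arbitrary toy values.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy setup: 6 positions, vocabulary of 10 tokens.
logits = torch.randn(6, 10)
targets = torch.tensor([3, 1, 4, 1, 5, 9])

# Pretend the first 3 positions are the question (prompt):
# set their labels to -100 so they contribute no loss.
masked_targets = targets.clone()
masked_targets[:3] = -100

# Positions with target == ignore_index (-100 by default) are skipped,
# and the mean is taken over the remaining positions only.
loss_masked = F.cross_entropy(logits, masked_targets)

# The same value, computed explicitly on the answer positions only.
loss_answer_only = F.cross_entropy(logits[3:], targets[3:])

print(loss_masked.item(), loss_answer_only.item())
```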

An example of how this can be done is in the Donut repo (Donut is a model similar to GIT that can also do VQA on images). There, the labels are a copy of the input_ids, and the model is kept from having to predict the prompt (the question, in the case of VQA) by replacing the prompt tokens with the ignore_id, which is set to -100. Note that each training example has a different number of prompt tokens (questions vary in length), so the Donut authors check where the prompt_end_token_id occurs in the sequence to know where the question ends and the answer starts.
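A minimal sketch of the same idea for a batch of question+answer sequences. It assumes the labels are passed as `labels=` with the same shape as `input_ids` (as far as I know, `GitForCausalLM` shifts them internally). GIT's tokenizer has no dedicated prompt-end token, so instead of Donut's `prompt_end_token_id` search this sketch masks a per-example prompt length. `build_vqa_labels` and the toy token ids are hypothetical. With GIT's BERT-style tokenizer, the prompt length would be the length of the question tokenized on its own minus its trailing [SEP], assuming the question tokenizes identically as a prefix of the question+answer text.

```python
import torch

IGNORE_ID = -100  # default ignore_index of PyTorch's cross-entropy loss


def build_vqa_labels(input_ids, attention_mask, prompt_lens):
    """Hypothetical helper: copy input_ids into labels, then hide the
    question prefix and the padding from the loss.

    input_ids, attention_mask: (batch, seq_len) tensors for "question + answer"
    prompt_lens: (batch,) number of leading prompt tokens ([CLS] + question)
    """
    labels = input_ids.clone()
    positions = torch.arange(input_ids.size(1)).unsqueeze(0)  # (1, seq_len)
    is_prompt = positions < prompt_lens.unsqueeze(1)          # (batch, seq_len)
    labels[is_prompt] = IGNORE_ID             # no loss on the question
    labels[attention_mask == 0] = IGNORE_ID   # no loss on padding
    return labels


# Toy ids: 101 = [CLS], 102 = [SEP] (plays the role of [EOS]), 0 = padding.
# Example 1: question tokens 7 8 9, answer tokens 20 21.
# Example 2: question tokens 5 6, answer token 30.
input_ids = torch.tensor([
    [101, 7, 8, 9, 20, 21, 102, 0],
    [101, 5, 6, 30, 102, 0, 0, 0],
])
attention_mask = (input_ids != 0).long()  # toy: padding id is 0

# Each question tokenized alone is [CLS] q... [SEP]; drop the [SEP]
# so the prompt is [CLS] + question tokens.
question_lens_with_sep = torch.tensor([5, 4])
prompt_lens = question_lens_with_sep - 1

labels = build_vqa_labels(input_ids, attention_mask, prompt_lens)
print(labels.tolist())
# [[-100, -100, -100, -100, 20, 21, 102, -100], [-100, -100, -100, 30, 102, -100, -100, -100]]
```

The resulting tensor would be passed as `labels=` next to the unmasked question+answer ids as `input_ids=`, together with `attention_mask=` and `pixel_values=`, so the loss covers only the answer tokens and the final [SEP]/[EOS].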

marine-tk commented 1 year ago

Thank you for your answer! Wow, I understand much better now. I will try this trick, thank you so much again!

Update: I added this modification to my preprocessing and it seems to work fine now (I tried it on a smaller sample of the TextVQA dataset)!