facebookresearch / chameleon

Repository for Meta Chameleon, a mixed-modal early-fusion foundation model from FAIR.
https://arxiv.org/abs/2405.09818

Finetuning input / target formatting #70

Open: AetherPrior opened this issue 4 days ago

AetherPrior commented 4 days ago

Hi, I have a question about data collation for fine-tuning. I have some input questions and some targets, and I'd like to know whether I need to include the inputs as part of my labels during causal fine-tuning. Specifically, I've defined my collation function as follows:

def chameleon_collate_fn(batch):
    # Extract the images and questions
    images = [ex['image'] for ex in batch]
    questions = ["<image>"+ex['question'] for ex in batch]
    labels = ["<image>"+ex['question'] + " " + ex['answer'] for ex in batch]

    # Process the batch using the processor
    batch_inputs = processor(images=images, text=questions, return_tensors="pt", padding=True)

    labels = processor(images=images, text=labels, return_tensors="pt", padding=True).input_ids # feels like labels should be the inputs + answer themselves? 

    # mask out pad tokens
    labels = labels.masked_fill(labels == processor.tokenizer.pad_token_id, -100)
    # mask the input from the labels
    labels[:, :len(batch_inputs["input_ids"])] = -100

    batch_inputs["labels"] = labels

    # Move inputs and labels to the appropriate device
    batch_inputs = {key: val.to('cuda') for key, val in batch_inputs.items()}

    return batch_inputs
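A note on the masking line in this first version: `len(batch_inputs["input_ids"])` is applied to a 2-D tensor, and `len()` returns the size of the first dimension, i.e. the batch size (2 here), not the number of prompt tokens. So `labels[:, :len(...)] = -100` only hides the first two positions of every row. The sketch below reproduces this with plain Python lists standing in for tensors; the token values are made up for illustration.

```python
# A padded batch of token ids with shape [batch_size=2, seq_len=6]
# (nested lists standing in for a [2, 6] tensor; values are illustrative only).
input_ids = [
    [5, 6, 7, 8, 9, 10],
    [5, 6, 7, 11, 0, 0],
]

# len() on a 2-D batch gives the size of the first dimension (the batch size),
# not the number of tokens per sequence.
cut = len(input_ids)
print(cut)  # 2

# Equivalent of labels[:, :cut] = -100: only the first `cut` positions of each row are masked.
labels = [[-100 if pos < cut else tok for pos, tok in enumerate(row)] for row in input_ids]
print(labels[0])  # [-100, -100, 7, 8, 9, 10]
```

With a real tensor, the sequence length would be `input_ids.shape[1]`, but a single cut-off for the whole batch still ignores the fact that each example's prompt has a different length.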

However, when I pass these to the model call in the training loop:


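from tqdm import tqdm
import wandb

model.to('cuda')  # move the model to the GPU once, before the loop
model.train()     # models loaded with from_pretrained() start in eval mode

for epoch in range(num_epochs):
    step_counter = 0
    for inputs in tqdm(train_loader):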
        # Forward pass
        outputs = model(**inputs)
        loss = outputs.loss

        # Backward pass
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Log step counter and loss in one call so wandb records them at the same step
        wandb.log({"global_step": step_counter, "loss": loss.item()})
        step_counter += args.batch_size

    wandb.log({"epoch": epoch + 1})

    print(f"Epoch {epoch + 1}, Loss: {loss.item()}")

I get a ValueError:

ValueError: Expected input batch_size (2068) to match target batch_size (2070).

My batch size is 2; my inputs have shape [2, 1035] and my targets [2, 1036] (one extra token for the numerical answer), so I'm not sure what the issue is. Could someone help? Thanks!
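The two numbers in the error can be reproduced by hand, assuming the model computes its loss the way many Hugging Face causal LMs do: it drops the last logit position (`logits[..., :-1, :]`) and the first label position (`labels[..., 1:]`), flattens both over batch and sequence, and passes them to cross-entropy. Under that assumption, labels must have exactly the same length as `input_ids`; the model does the one-token shift itself. The helper name below is hypothetical and only does the shape arithmetic.

```python
def flattened_loss_sizes(batch_size, input_len, label_len):
    """Number of rows the cross-entropy call sees after the usual causal-LM shift.

    Assumes shift_logits = logits[..., :-1, :] and shift_labels = labels[..., 1:],
    both flattened over (batch, sequence) before the loss is computed.
    """
    # logits have one position per input token; dropping the last leaves input_len - 1
    n_logit_rows = batch_size * (input_len - 1)
    # labels drop their first position instead, leaving label_len - 1
    n_label_rows = batch_size * (label_len - 1)
    return n_logit_rows, n_label_rows


print(flattened_loss_sizes(2, 1035, 1036))  # (2068, 2070): the mismatch in the error
print(flattened_loss_sizes(2, 1036, 1036))  # (2070, 2070): equal lengths line up
```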

AetherPrior commented 4 days ago

I fixed the error by changing the data collation pipeline to use the same text (question + answer) for both the inputs and the labels:

def chameleon_collate_fn(batch):
    # Extract the images and questions
    images = [ex['image'] for ex in batch]
    labels = ["<image>"+ex['question'] + " " + ex['answer'] for ex in batch]

    # Process the batch using the processor
    batch_inputs = processor(images=images, text=labels, return_tensors="pt", padding=True)

    # labels start as a copy of the inputs (prompt + answer), reusing the ids tokenized above
    labels = batch_inputs["input_ids"].clone()

    # mask out pad tokens
    labels = labels.masked_fill(labels == processor.tokenizer.pad_token_id, -100)
    # mask the input from the labels
    labels[:, :len(batch_inputs["input_ids"])] = -100

    batch_inputs["labels"] = labels

    # Move inputs and labels to the appropriate device
    batch_inputs = {key: val.to('cuda') for key, val in batch_inputs.items()}

    return batch_inputs
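In this version, inputs and labels have the same length, which is what the shifted loss needs. The masking line, however, still uses `len(batch_inputs["input_ids"])`, which is the batch size, so apart from the first two positions the question text and image tokens remain in the loss. Training on the full sequence is a legitimate choice; if the goal is to train only on the answer, one option is to mask each example's prompt individually. The framework-free sketch below assumes hypothetical names (`mask_prompt_and_padding`, `IGNORE_INDEX`) and that `prompt_lens` is obtained separately, e.g. by running the processor on `"<image>" + question` alone and summing each row's `attention_mask`. That in turn assumes the prompt tokenizes identically on its own and as a prefix of the full text, which is worth checking on a few examples because of the `" "` before the answer.

```python
IGNORE_INDEX = -100  # default ignore_index of PyTorch's cross-entropy loss


def mask_prompt_and_padding(input_ids, attention_mask, prompt_lens):
    """Build labels from a padded batch of full (prompt + answer) sequences.

    input_ids, attention_mask: equal-length rows, one per example.
    prompt_lens: number of prompt tokens in each example (padding excluded).
    Real tokens are located via attention_mask (1 = token, 0 = padding),
    so this works for both left and right padding.
    """
    labels = []
    for ids, mask, n_prompt in zip(input_ids, attention_mask, prompt_lens):
        start = mask.index(1)  # first real token: 0 with right padding, > 0 with left padding
        row = []
        for pos, (tok, m) in enumerate(zip(ids, mask)):
            if m == 0 or pos < start + n_prompt:
                row.append(IGNORE_INDEX)  # padding or prompt: excluded from the loss
            else:
                row.append(tok)  # answer token: contributes to the loss
        labels.append(row)
    return labels


# Right-padded toy batch (pad id 0); values are illustrative only.
batch_ids = [
    [1, 10, 11, 12, 20, 21],
    [1, 10, 13, 20, 0, 0],
]
batch_mask = [
    [1, 1, 1, 1, 1, 1],
    [1, 1, 1, 1, 0, 0],
]
print(mask_prompt_and_padding(batch_ids, batch_mask, prompt_lens=[4, 3]))
# [[-100, -100, -100, -100, 20, 21], [-100, -100, -100, 20, -100, -100]]
```

With real tensors, the resulting nested list can be turned back into a tensor with `torch.tensor(labels)` before assigning it to `batch_inputs["labels"]`.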

I'm still not sure whether this is correct, so any comments on it would be highly appreciated!
Thank you so much.