Luodian / Otter

🦦 Otter, a multi-modal model based on OpenFlamingo (open-sourced version of DeepMind's Flamingo), trained on MIMIC-IT and showcasing improved instruction-following and in-context learning ability.
https://otter-ntu.github.io/
MIT License

how to generate labels for FUYU #305

Open aamir-gmail opened 8 months ago

aamir-gmail commented 8 months ago

In models/fuyu/processing_fuyu.py, what is the purpose of special_token_id in the get_labels method, and how do I get it? For example, my input looks like this: "Extract text from this image". Using the Fuyu processor, I pass in the image and text and get input_ids, but I am not sure how to get labels from those input_ids using this method.

Luodian commented 5 months ago

The special_token_id comes from Fuyu's design: it is the token ID of \x04, which is used to separate questions and answers.

(if I remember correctly) Fuyu's template is: "{question}\n\x04{answer}\x04".

Our template is "User:{question} Assistant:\x04{answer}\x04".
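As a rough illustration of this template, the sketch below formats one question/answer pair and then recovers the answer as the text between the two \x04 delimiters. The helper names format_example and answer_span are hypothetical and not part of Otter. During training the same split happens on token IDs rather than characters, which assumes \x04 is encoded as a single token.

```python
BOA = "\x04"  # delimiter between question and answer, per the template above


def format_example(question: str, answer: str) -> str:
    """Build 'User:{question} Assistant:<BOA>{answer}<BOA>' (hypothetical helper)."""
    return f"User:{question} Assistant:{BOA}{answer}{BOA}"


def answer_span(text: str) -> str:
    """Return the text between the first two BOA delimiters (hypothetical helper)."""
    start = text.index(BOA)
    # str.index raises ValueError if the closing delimiter is missing
    end = text.index(BOA, start + 1)
    return text[start + 1 : end]


prompt = format_example("Extract text from this image", "Hello world")
print(repr(prompt))         # 'User:Extract text from this image Assistant:\x04Hello world\x04'
print(answer_span(prompt))  # Hello world
```

A string that contains only the question, such as "Extract text from this image", has no \x04 in it at all, so there is no answer span to locate; the answer has to be part of the training text.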

We also use it to locate the answer's position, because during training only the {answer} tokens are kept as labels and everything else is masked out.

The code is here~

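# src/otter_ai/models/fuyu/processing_fuyu.py
import torch


def get_labels(self, input_ids, special_token_id, masking_number=-100):
    # Start with every position masked (-100 is ignored by PyTorch's cross-entropy loss by default)
    labels = torch.full_like(input_ids, masking_number)

    # Iterate through each sequence in the batch
    for i in range(input_ids.shape[0]):
        seq = input_ids[i]
        # 1-D tensor of positions where special_token_id (\x04) occurs;
        # flatten() (not squeeze()) keeps it 1-D even when there is exactly one match
        indices = (seq == special_token_id).nonzero(as_tuple=False).flatten()
        # Drop a trailing unpaired delimiter (e.g. a truncated answer) so reshape(-1, 2) cannot fail
        if indices.numel() % 2 == 1:
            indices = indices[:-1]
        # Pair the indices as (open, close) and unmask the tokens after each opening
        # delimiter, up to and including the closing one
        paired_indices = indices.reshape(-1, 2)
        for start, end in paired_indices.tolist():
            labels[i, start + 1 : end + 1] = seq[start + 1 : end + 1]

    return labels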
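To see which positions get_labels keeps, here is a plain-Python mirror of the same logic that works on lists instead of tensors. The token IDs are made up for illustration: 4 stands in for the ID of \x04, whereas in practice special_token_id is whatever ID your Fuyu tokenizer assigns to that delimiter. The name get_labels_list is hypothetical.

```python
from typing import List

IGNORE_INDEX = -100  # same default as masking_number in get_labels


def get_labels_list(batch: List[List[int]], special_token_id: int,
                    masking_number: int = IGNORE_INDEX) -> List[List[int]]:
    """List-based mirror of get_labels: keep the tokens after each opening delimiter
    up to and including the closing one; mask everything else."""
    all_labels = []
    for seq in batch:
        labels = [masking_number] * len(seq)
        positions = [p for p, tok in enumerate(seq) if tok == special_token_id]
        # zip pairs (1st, 2nd), (3rd, 4th), ...; an unpaired trailing delimiter is ignored
        for start, end in zip(positions[0::2], positions[1::2]):
            labels[start + 1 : end + 1] = seq[start + 1 : end + 1]
        all_labels.append(labels)
    return all_labels


BOA_ID = 4  # hypothetical stand-in for the token ID of "\x04"
batch = [
    [10, 11, 12, BOA_ID, 20, 21, BOA_ID],  # "User:... Assistant:\x04{answer}\x04"
    [10, 11, 12, 13, 14, 15, 16],          # question only, no delimiters
]
for row in get_labels_list(batch, BOA_ID):
    print(row)
# [-100, -100, -100, -100, 20, 21, 4]
# [-100, -100, -100, -100, -100, -100, -100]
```

The second row is the situation described in the question: input_ids built from the prompt alone contain no \x04, so every label stays at -100 and that sequence contributes nothing to the loss. The -100 default matches the default ignore_index of PyTorch's CrossEntropyLoss.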