MeetKai / functionary

Chat language model that can use tools and interpret the results
MIT License
1.36k stars, 107 forks

Explain About Packing Inputs Without Cross-Contamination Attention #265

Open · qibao77 opened this issue 1 week ago

qibao77 commented 1 week ago

Thanks for the great work! Why does this operation ("overwriting the function _get_unpad_data with a monkey-patched function") implement packing without cross-contamination attention? Could you explain it in more detail or point me to a reference? Thank you very much!

khai-meetkai commented 1 week ago

@qibao77 in our implementation, we changed 2 things:

qibao77 commented 1 week ago

> @qibao77 in our implementation, we changed 2 things:
>
>   • First, we extend the format of attention_mask to represent the packing, marking the start and end of each packed input. Assume that max_input_length is 10 and padding_token_id = 0. Without packing, we have 2 data points:
>
>     • input_ids1 = [1, 2, 3, 0, 0, 0, 0, 0, 0, 0]; attention_mask = [1, 1, 1, 0, 0, 0, 0, 0, 0, 0]
>     • input_ids2 = [4, 5, 6, 7, 8, 0, 0, 0, 0, 0]; attention_mask = [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]
>
>     When we pack the 2 data points into 1 data point:
>
>     • input_ids = [1, 2, 3, 4, 5, 6, 7, 8, 0, 0]; attention_mask = [1, 1, 1, 2, 2, 2, 2, 2, 0, 0]
>
>     Here the attention_mask marks the boundary of each data point: 1 for data point 1, 2 for data point 2, and 0 for padding (the same as without packing).
>   • Second, the original _get_unpad_data was written to accept only 0 and 1 in attention_mask, so it does not work with the extended format; we therefore overwrite _get_unpad_data with a function that accepts the extended attention_mask.

Thank you for your reply! I want to add this feature to my pretraining code (like Llama 3), but I found that the loss does not change compared to naive packing. Do you have any advice?
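To make the mechanism in the quoted explanation concrete: in the transformers flash-attention path, _get_unpad_data turns attention_mask into (indices, cu_seqlens, max_seqlen_in_batch), and flash-attention's variable-length kernel (flash_attn_varlen_func) only lets a token attend within the sequence that cu_seqlens delimits. The stock function computes one length per row as attention_mask.sum(dim=-1), so a naive 0/1 mask puts both packed documents into one sequence, and a 1/2 mask would report the wrong length (3 + 2 × 5 = 13 instead of 8). Below is a minimal, dependency-free sketch of what an extended-mask version has to compute, written over plain lists instead of torch tensors. unpad_data_from_segments is a hypothetical name; this is not the function functionary ships.

```python
from itertools import accumulate


def unpad_data_from_segments(segment_mask):
    """Hypothetical sketch of an extended-mask _get_unpad_data, on plain lists.

    segment_mask: one list per row; 0 = padding, 1, 2, ... = packed documents.
    Assumes each document's tokens are contiguous and IDs increase along a row.
    Returns (indices, cu_seqlens, max_seqlen) like the transformers helper.
    """
    # positions of real (non-padding) tokens in the flattened batch
    flat = [v for row in segment_mask for v in row]
    indices = [i for i, v in enumerate(flat) if v != 0]

    # one length per packed document, in row-major order
    seqlens = []
    for row in segment_mask:
        for seg_id in range(1, max(row, default=0) + 1):
            n = row.count(seg_id)
            if n:
                seqlens.append(n)

    # document k covers unpadded tokens cu_seqlens[k]:cu_seqlens[k + 1]
    cu_seqlens = [0] + list(accumulate(seqlens))
    return indices, cu_seqlens, max(seqlens, default=0)


packed = [[1, 1, 1, 2, 2, 2, 2, 2, 0, 0]]  # segment-ID mask from the example
naive = [[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]   # naive packing mask
print(unpad_data_from_segments(packed)[1])  # [0, 3, 8]: two documents (3 + 5 tokens)
print(unpad_data_from_segments(naive)[1])   # [0, 8]: one 8-token sequence
print([sum(row) for row in packed])         # [13]: a per-row sum miscounts the 8 tokens
```

With the naive mask the 8 tokens form a single causal sequence, so the tokens of the second data point can attend to the first; with the segment-ID mask they form two sequences of 3 and 5 tokens.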

khai-meetkai commented 1 week ago

What do you mean by "no change in the loss"? Do you mean loss(naive_packing(a, b)) == loss(packing_without_cross_contamination(a, b))?

qibao77 commented 1 week ago

Yes, in my experiment loss(naive_packing(a, b)) == loss(packing_without_cross_contamination(a, b)), and I have checked that the _get_unpad_data function was replaced correctly.

khai-meetkai commented 1 week ago

@qibao77 Can you share your experimental code showing loss(naive_packing(a, b)) == loss (packing_without_cross_contamination(a, b)) ?

vgoklani commented 1 week ago

@qibao77 were you pre-training or fine-tuning?

curious, was the loss exactly matching step by step, or was that much later?

qibao77 commented 6 days ago

> @qibao77 Can you share your experimental code showing loss(naive_packing(a, b)) == loss (packing_without_cross_contamination(a, b)) ?

For loss(packing_without_cross_contamination(a, b)), the code is as follows: ...

monkey_patch_packing_for_model(self.local_dir)
self.gpt = LlamaForCausalLM.from_pretrained(
    self.local_dir,
    config=self.hf_config,
    trust_remote_code=True,
    revision='main',
    offload_state_dict=True,
    attn_implementation="flash_attention_2",
)

...

attention_mask = generate_attention_mask(
    input_ids,
    special_token_end=self.tokenizer.eos_token_id,
    pad_token_id=self.tokenizer.pad_token_id,
)
model_out = self.gpt(
    input_ids=input_ids, attention_mask=attention_mask, labels=labels
)

For the definition of generate_attention_mask:

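import torch

def generate_attention_mask(input_ids, special_token_end=3, pad_token_id=0):
    """Label packed documents 1, 2, 3, ... and padding 0.

    A document ends at (and includes) the special_token_end token;
    scanning a row stops at the first pad token.
    """
    batch_size, seq_len = input_ids.shape
    mask = torch.zeros_like(input_ids)
    for i in range(batch_size):
        current_label = 1
        for j in range(seq_len):
            if input_ids[i, j] == special_token_end:
                # the end token belongs to the document it closes
                mask[i, j] = current_label
                current_label += 1
            elif input_ids[i, j] == pad_token_id:
                break  # the rest of the row is padding and stays 0
            else:
                mask[i, j] = current_label

    return mask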

For loss(naive_packing(a, b)), generate_attention_mask is not used, and attention_mask is 1 everywhere except the padding positions.
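One thing worth checking in generate_attention_mask above: it tests for special_token_end before pad_token_id. If the pad token is the same as the eos token (a common workaround when a tokenizer defines no pad token), trailing padding is labeled as new documents instead of 0; for example, with eos = pad = 3 the row [5, 6, 3, 7, 8, 3, 3, 3] becomes [1, 1, 1, 2, 2, 2, 3, 4]. Whether this applies to the setup in this thread is not known. A minimal sketch of an alternative, assuming the ordinary 0/1 padding mask is still available, takes padding from that mask and uses eos only to split documents; segment_ids_from_eos is a hypothetical helper written over plain lists.

```python
def segment_ids_from_eos(input_ids, pad_mask, eos_id):
    """Hypothetical helper: 1, 2, 3, ... per packed document, 0 for padding.

    Padding comes from the 0/1 pad_mask rather than from a pad token ID,
    so it still works when the pad token equals the eos token.
    """
    result = []
    for ids, keep in zip(input_ids, pad_mask):
        row, seg = [], 1
        for tok, k in zip(ids, keep):
            row.append(seg if k else 0)
            if k and tok == eos_id:
                seg += 1  # eos closes this document; the next token starts a new one
        result.append(row)
    return result


# eos_id == pad id == 3: the last two 3s are padding, not new documents
ids = [[5, 6, 3, 7, 8, 3, 3, 3]]
pad_mask = [[1, 1, 1, 1, 1, 1, 0, 0]]
print(segment_ids_from_eos(ids, pad_mask, eos_id=3))  # [[1, 1, 1, 2, 2, 2, 0, 0]]
```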

qibao77 commented 6 days ago

> @qibao77 were you pre-training or fine-tuning?
>
> curious, was the loss exactly matching step by step, or was that much later?

Pre-training, and the losses matched step by step.

vgoklani commented 6 days ago

@qibao77 it's unclear how you could be matching step by step if the attention masks are different.

khai-meetkai commented 4 days ago

@qibao77 you can actually run this script to see that naive packing gives a different loss from packing without cross-contamination. The script assumes 2 data points, a = [1, 2, 3] and b = [4, 5, 6, 7, 8], and compares:
1) loss(a) + loss(b)
2) loss(naive_pack(a, b))
3) loss(packing_without_cross_contamination(a, b))

The results are:
1) loss(a) + loss(b) = 44.141
2) loss(naive_pack(a, b)) = 37.55
3) loss(packing_without_cross_contamination(a, b)) = 44.17

As you can see, naive packing is problematic: its loss differs from loss(a) + loss(b), while packing without cross-contamination nearly matches it.

from transformers import AutoModelForCausalLM, AutoTokenizer
import monkey_patch_packing 
import torch

def main():
    # pad_token = 0
    # max_length = 10
    pretrained_path = "meta-llama/Meta-Llama-3.1-8B"
    input_ids1 = [1, 2, 3] + [0 for _ in range(7)]
    labels1 = [1, 2, 3] + [-100 for _ in range(7)]
    attention1 = [1, 1, 1] + [0 for _ in range(7)]

    input_ids2 = [4, 5, 6, 7, 8] + [0 for _ in range(5)]
    labels2 = [4, 5, 6, 7, 8] + [-100 for _ in range(5)]
    attention2 = [1, 1, 1, 1, 1] + [0 for _ in range(5)]
    # packing
    packed_inputs = [1,2,3,4,5,6,7,8, 0, 0]
    # 4 is the first token of b: mask its label so that a's last token (3) is not
    # trained to predict it (without packing, the first token is never a target either)
    packed_labels = [1,2,3,-100,5,6,7,8, -100, -100]

    naive_attention = [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]
    correct_packed_attention = [1, 1, 1, 2, 2, 2, 2, 2, 0, 0]

    assert len(input_ids1) == len(input_ids2) == len(attention1) == len(attention2) == len(naive_attention) == len(correct_packed_attention) == len(packed_inputs)

    # Load model without monkey-patching
    model = AutoModelForCausalLM.from_pretrained(
        pretrained_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        attn_implementation="flash_attention_2",
        trust_remote_code=True
    )
    # loss without using packing
    loss1, num_tok1 = compute_loss(model, input_ids1, attention1, labels1)
    loss2, num_tok2 = compute_loss(model, input_ids2, attention2, labels2)
    total_original_loss = loss1 + loss2 
    total_original_num_tokens = num_tok1 + num_tok2
    print(f"total original loss: {total_original_loss}; total_original_num_tokens={total_original_num_tokens}")    
    # loss with naive packing
    naive_loss, naive_num_tok = compute_loss(model, packed_inputs, naive_attention, packed_labels)
    print(f"naive loss: {naive_loss}; num_token: {naive_num_tok}")

    # loss with packing without cross-contamination
    # free the first model, then reload using the monkey-patched code
    del model
    torch.cuda.empty_cache()
    monkey_patch_packing.monkey_patch_packing_for_model(pretrained_path)
    model = AutoModelForCausalLM.from_pretrained(
        pretrained_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        attn_implementation="flash_attention_2",
        trust_remote_code=True
    )

    correct_loss, correct_num_tok = compute_loss(model, packed_inputs, correct_packed_attention, packed_labels)
    print(f"correct_loss: {correct_loss}; num_token={correct_num_tok}")

def compute_loss(model, input_ids, attention, labels):
    """Return (summed loss over label tokens, number of label tokens) for one example."""
    inputs = {
        "input_ids": torch.tensor([input_ids]).to(model.device),
        "labels": torch.tensor([labels]).to(model.device),
        "attention_mask": torch.tensor([attention]).to(model.device)
    }
    with torch.no_grad():
        # HF shifts the labels internally and returns the mean loss over non-ignored tokens
        avg_loss = model(**inputs).loss.item()
        # count the tokens that contributed to the loss (same shift as HF)
        shift_labels = inputs["labels"][..., 1:].reshape(-1)
        num_tokens = (shift_labels != -100).sum().item()
    # turn the mean back into a sum so losses of different examples can be added
    return avg_loss * num_tokens, num_tokens

if __name__ == "__main__":
    main()
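The packed values in main() (packed_inputs, packed_labels, correct_packed_attention) are written by hand. A minimal sketch of how they could be derived from the individual examples follows; pack_examples and count_loss_tokens are hypothetical helpers, not part of functionary. The token count also shows why the three totals in the results are comparable: compute_loss turns the mean loss back into a sum, and all three setups score the same 6 target tokens.

```python
def pack_examples(examples, max_length, pad_id=0, ignore_index=-100):
    """Hypothetical helper: pack non-empty token lists into one row.

    Returns (input_ids, labels, segment_mask) laid out like packed_inputs,
    packed_labels and correct_packed_attention in the script.
    """
    input_ids, labels, segment_mask = [], [], []
    for seg_id, tokens in enumerate(examples, start=1):
        input_ids += tokens
        # mask the first token of every later example: it would otherwise be
        # predicted from the previous example's last token
        first = ignore_index if seg_id > 1 else tokens[0]
        labels += [first] + tokens[1:]
        segment_mask += [seg_id] * len(tokens)
    if len(input_ids) > max_length:
        raise ValueError("packed examples exceed max_length")
    n_pad = max_length - len(input_ids)
    return (input_ids + [pad_id] * n_pad,
            labels + [ignore_index] * n_pad,
            segment_mask + [0] * n_pad)


def count_loss_tokens(labels, ignore_index=-100):
    # the causal-LM loss scores labels[1:] against the preceding positions'
    # logits, so labels[0] is never a target (mirrors compute_loss)
    return sum(1 for lab in labels[1:] if lab != ignore_index)


ids, labels, seg = pack_examples([[1, 2, 3], [4, 5, 6, 7, 8]], max_length=10)
print(ids)     # [1, 2, 3, 4, 5, 6, 7, 8, 0, 0]
print(labels)  # [1, 2, 3, -100, 5, 6, 7, 8, -100, -100]
print(seg)     # [1, 1, 1, 2, 2, 2, 2, 2, 0, 0]
unpacked = count_loss_tokens([1, 2, 3] + [-100] * 7) + count_loss_tokens([4, 5, 6, 7, 8] + [-100] * 5)
print(unpacked, count_loss_tokens(labels))  # 6 6
```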