[Open] qibao77 opened this issue 1 week ago
@qibao77 in our implementation, we changed 2 things:
First, we extend the format of attention_mask to represent the packing, marking the start and end of each packed input. Assume that max_input_length is 10 and padding_token_id = 0. Without packing, we have 2 data points:
input_ids1 = [1, 2, 3, 0, 0, 0, 0, 0, 0, 0]; attention_mask1 = [1, 1, 1, 0, 0, 0, 0, 0, 0, 0]
input_ids2 = [4, 5, 6, 7, 8, 0, 0, 0, 0, 0]; attention_mask2 = [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]
When we pack these 2 data points into 1 data point:
input_ids = [1, 2, 3, 4, 5, 6, 7, 8, 0, 0]; attention_mask = [1, 1, 1, 2, 2, 2, 2, 2, 0, 0]
Here the attention_mask marks the boundary of the individual data points: 1 for data point 1, 2 for data point 2, and 0 for padding (the same as without packing).
Second, with the extended attention_mask, the current code (the function _get_unpad_data) doesn't work, as it was implemented to only accept 0 and 1, so we overwrite _get_unpad_data with a version that accepts the extended attention_mask.
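For intuition, here is a minimal sketch of what such an overridden helper can look like (illustrative only, not the exact code of this repo; the name get_unpad_data_packed is made up here): it builds flash-attention's cu_seqlens from the segment ids (1, 2, ...) instead of from a 0/1 mask, so each packed sub-sequence is treated as its own sequence by the varlen attention kernel.

import torch
import torch.nn.functional as F

def get_unpad_data_packed(attention_mask):
    # attention_mask: (batch, seq_len); 0 = padding, 1..K = id of the packed sub-sequence
    seqlens = []
    for row in attention_mask:
        num_segments = int(row.max().item())
        for k in range(1, num_segments + 1):
            # length of the k-th packed sub-sequence in this row
            seqlens.append(int((row == k).sum().item()))
    seqlens_in_batch = torch.tensor(seqlens, dtype=torch.int32, device=attention_mask.device)
    # indices of all non-padding tokens, flattened over the batch
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = int(seqlens_in_batch.max().item())
    # cumulative sequence lengths with a leading 0, as expected by flash-attn's varlen API
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, max_seqlen_in_batch

With the toy example above, attention_mask = [[1, 1, 1, 2, 2, 2, 2, 2, 0, 0]] gives cu_seqlens = [0, 3, 8] (two sequences of lengths 3 and 5), whereas a plain 0/1 mask over the whole packed row would give cu_seqlens = [0, 8], i.e. one sequence of length 8, which is exactly what lets tokens of data point 2 attend to data point 1.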
Thank you for your reply! I want to add this feature to my pretraining code (Llama-3 style), but I found that there is no change in the loss compared to naive packing. Is there any advice?
What do you mean by no change in the loss? You mean: loss(naive_packing(a, b)) == loss(packing_without_cross_contamination(a, b))?
Yes, in my experiment, loss(naive_packing(a, b)) == loss(packing_without_cross_contamination(a, b)), and I have checked that the _get_unpad_data function was replaced correctly.
@qibao77 Can you share your experimental code showing loss(naive_packing(a, b)) == loss (packing_without_cross_contamination(a, b)) ?
@qibao77 were you pre-training or fine-tuning?
Curious, was the loss matching exactly step by step, or only much later?
@qibao77 Can you share your experimental code showing loss(naive_packing(a, b)) == loss (packing_without_cross_contamination(a, b)) ?
For loss(packing_without_cross_contamination(a, b)), the code is as follows:
...
monkey_patch_packing_for_model(self.local_dir)
self.gpt = LlamaForCausalLM.from_pretrained(
    self.local_dir, config=self.hf_config, trust_remote_code=True, revision='main',
    offload_state_dict=True, attn_implementation="flash_attention_2"
)
...
attention_mask = generate_attention_mask(
    input_ids, special_token_end=self.tokenizer.eos_token_id, pad_token_id=self.tokenizer.pad_token_id
)
model_out = self.gpt(
    input_ids=input_ids, attention_mask=attention_mask, labels=labels
)
For the definition of generate_attention_mask:

def generate_attention_mask(input_ids, special_token_end=3, pad_token_id=0):
    batch_size, seq_len = input_ids.shape
    mask = torch.zeros_like(input_ids)
    for i in range(batch_size):
        current_label = 1
        for j in range(seq_len):
            if input_ids[i, j] == special_token_end:
                # the end-of-sequence token closes the current segment
                mask[i, j] = current_label
                current_label += 1
            elif input_ids[i, j] == pad_token_id:
                # padding starts: nothing left to mark in this row
                break
            else:
                mask[i, j] = current_label
    return mask
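As a quick sanity check, using the toy example from earlier in this thread and assuming token id 3 plays the role of the end-of-sequence token:

import torch
input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 0, 0]])
print(generate_attention_mask(input_ids, special_token_end=3, pad_token_id=0))
# tensor([[1, 1, 1, 2, 2, 2, 2, 2, 0, 0]])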
For loss(naive_packing(a, b)), the function generate_attention_mask is not used; the attention_mask is all 1s except for the padding part.
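For the same packed toy example, that means:

naive attention_mask  = [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]
packed attention_mask = [1, 1, 1, 2, 2, 2, 2, 2, 0, 0]   # from generate_attention_mask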
@qibao77 were you pre-training or fine-tuning?
Curious, was the loss matching exactly step by step, or only much later?
pre-training, matching step by step
@qibao77 it's unclear how you could be matching step by step if the attention masks are different.
@qibao77 actually you can run this script to see that naive packing will give a different loss compared with packing without cross-contamination. In this script, assume that there are 2 data points: a = [1, 2, 3] and b = [4, 5, 6, 7, 8]. I compare the loss of:
1) loss(a) + loss(b)
2) loss(naive_pack(a, b))
3) loss(packing_without_cross_contamination(a, b))
The result is:
1) loss(a) + loss(b) = 44.141
2) loss(naive_pack(a, b)) = 37.55
3) loss(packing_without_cross_contamination(a, b)) = 44.17
You see that naive packing is problematic, right?
from transformers import AutoModelForCausalLM, AutoTokenizer
import monkey_patch_packing
import torch


def main():
    # pad_token = 0
    # max_length = 10
    pretrained_path = "meta-llama/Meta-Llama-3.1-8B"
    input_ids1 = [1, 2, 3] + [0 for _ in range(7)]
    labels1 = [1, 2, 3] + [-100 for _ in range(7)]
    attention1 = [1, 1, 1] + [0 for _ in range(7)]

    input_ids2 = [4, 5, 6, 7, 8] + [0 for _ in range(5)]
    labels2 = [4, 5, 6, 7, 8] + [-100 for _ in range(5)]
    attention2 = [1, 1, 1, 1, 1] + [0 for _ in range(5)]

    # packing
    packed_inputs = [1, 2, 3, 4, 5, 6, 7, 8, 0, 0]
    # note here that 4 is the first token of the second data point, so it will not be included for computing loss
    packed_labels = [1, 2, 3, -100, 5, 6, 7, 8, -100, -100]
    naive_attention = [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]
    correct_packed_attention = [1, 1, 1, 2, 2, 2, 2, 2, 0, 0]
    assert len(input_ids1) == len(input_ids2) == len(attention1) == len(attention2) == len(naive_attention) == len(correct_packed_attention) == len(packed_inputs)

    # Load model without monkey-patching
    model = AutoModelForCausalLM.from_pretrained(
        pretrained_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        attn_implementation="flash_attention_2",
        trust_remote_code=True
    )
    # loss without packing
    loss1, num_tok1 = compute_loss(model, input_ids1, attention1, labels1)
    loss2, num_tok2 = compute_loss(model, input_ids2, attention2, labels2)
    total_original_loss = loss1 + loss2
    total_original_num_tokens = num_tok1 + num_tok2
    print(f"total original loss: {total_original_loss}; total_original_num_tokens={total_original_num_tokens}")

    # loss with naive packing
    naive_loss, naive_num_tok = compute_loss(model, packed_inputs, naive_attention, packed_labels)
    print(f"naive loss: {naive_loss}; num_token: {naive_num_tok}")

    # loss with packing without cross-contamination
    # need to reload the model using the monkey-patched code
    monkey_patch_packing.monkey_patch_packing_for_model(pretrained_path)
    model = AutoModelForCausalLM.from_pretrained(
        pretrained_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        attn_implementation="flash_attention_2",
        trust_remote_code=True
    )
    correct_loss, correct_num_tok = compute_loss(model, packed_inputs, correct_packed_attention, packed_labels)
    print(f"correct_loss: {correct_loss}; num_token={correct_num_tok}")


def compute_loss(model, input_ids, attention, labels):
    inputs = {
        "input_ids": torch.tensor([input_ids]).to(model.device),
        "labels": torch.tensor([labels]).to(model.device),
        "attention_mask": torch.tensor([attention]).to(model.device)
    }
    total_num_loss_tokens = 0
    total_loss = 0
    with torch.no_grad():
        avg_loss = model.forward(**inputs).loss.item()
        # compute the number of tokens used for computing the loss
        labels = inputs["labels"]
        shift_labels = labels[..., 1:].contiguous()
        shift_labels = shift_labels.view(-1)
        ignore_count = (shift_labels == -100).sum()
        num_tokens = shift_labels.size(0) - ignore_count
        total_num_loss_tokens += num_tokens.item()
        total_loss += avg_loss * num_tokens.item()
    return total_loss, total_num_loss_tokens


if __name__ == "__main__":
    main()
Thanks for your great work! Why does this operation (overwriting the function _get_unpad_data with a monkey-patched function) implement packing without cross-contamination in attention? Can you explain it in more detail or point me to some references? Thank you very much!