siyan-zhao / prepacking

The source code of our work "Prepacking: A Simple Method for Fast Prefilling and Increased Throughput in Large Language Models"
https://arxiv.org/abs/2404.09529

AMAZING WORK! 4d mask support. #1

Open aldopareja opened 6 months ago

aldopareja commented 6 months ago

First of all, congrats on the amazing work! I have been working on something like this myself for the last week, so seeing this is a relief: it confirms the idea works, and there is even an implementation ready.

Hugging Face recently merged 4D attention mask support into their main branch: https://huggingface.co/blog/poedator/4d-masks

With that, there should be no need for the custom model; it should just work. However, the input attention masks would need to be 4D, and yours are 2D. I'm sure it's mathematically the same; I just need to understand how to build it.

Here is an example of how to use 4D masks in HF.

The way you call the model is quite similar as well:

    # Call into the prepacking custom model with the packed batch:
    packed_outputs = custom_model(
        input_ids=packed_tokens.to(device),          # prompts packed into shared rows
        attention_mask=independent_mask.to(device),  # keeps packed prompts from attending to each other
        position_ids=restart_positions.to(device),   # position ids restart at 0 for each packed prompt
        return_dict=True,
        output_hidden_states=True,
    )

Do you have any pointers on how I could reuse your data processor directly with Hugging Face 4D masks, so that I don't need a custom model and can train any model that supports this API in HF?
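
For concreteness, here is a rough sketch of how such a 4D mask could be built, assuming the packed batch can be summarized as per-token segment ids (one id per packed sequence, 0 for padding); `segment_ids` and `build_4d_mask` are placeholder names for illustration, not part of the prepacking code:

    import torch

    def build_4d_mask(segment_ids: torch.Tensor) -> torch.Tensor:
        # segment_ids: (batch, seq_len) integers; tokens of the same packed sequence
        # share an id, padding is 0. Returns a (batch, 1, seq_len, seq_len) boolean
        # mask that is block-diagonal across sequences and causal within each block.
        same_seq = segment_ids.unsqueeze(2) == segment_ids.unsqueeze(1)
        not_pad = (segment_ids != 0).unsqueeze(2) & (segment_ids != 0).unsqueeze(1)
        seq_len = segment_ids.shape[1]
        causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=segment_ids.device))
        return (same_seq & not_pad & causal).unsqueeze(1)

Whether a boolean 0/1 mask of this shape is accepted as-is depends on the transformers version and attention implementation, so it is worth double-checking against the 4D-mask blog post above.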

aldopareja commented 6 months ago

I did implement a naive version of 4D masking using a custom collate function that puts everything into a single sequence and builds an independent causal mask for each sample.

Basically did this:

[image: illustration of multiple samples concatenated into one sequence, each with its own causal attention block]

with this function:

    import numpy as np
    import torch

    # `max_batch_len` (token budget per batch) and `rank` (used only for logging)
    # are defined in the enclosing scope.
    def pad_collate_fn_4d_mask(batch):
        all_inputs = []
        all_labels = []
        all_position_ids = []

        # Keep only as many samples as fit within the token budget.
        lens = np.array([len(item["input_ids"]) for item in batch])
        cumsum_lens = np.cumsum(lens)
        valid_up_to = int((cumsum_lens < max_batch_len).sum())
        total_len = int(lens[:valid_up_to].sum())
        print(f"\033[96m total batch len: {total_len} -- rank: {rank}\033[0m")

        # One long "sentence": a (1, 1, total_len, total_len) boolean mask holding
        # an independent causal block per sample.
        attention_masks = torch.zeros((1, 1, total_len, total_len), dtype=torch.bool)
        cur_len = 0
        for item in batch[:valid_up_to]:
            input_ids = item["input_ids"]
            len_ids = len(input_ids)

            labels = item["labels"]
            position_ids = torch.arange(len_ids)

            all_inputs.extend(input_ids)
            all_labels.extend(labels)
            all_position_ids.extend(position_ids.tolist())

            # Causal (lower-triangular) block restricted to this sample's span.
            attention_masks[:, :, cur_len:cur_len + len_ids, cur_len:cur_len + len_ids] = torch.tril(
                torch.ones((len_ids, len_ids), dtype=torch.bool)
            )
            cur_len += len_ids

        return {
            "input_ids": torch.tensor(all_inputs).unsqueeze(0),
            "labels": torch.tensor(all_labels).unsqueeze(0),
            "attention_mask": attention_masks,
            "position_ids": torch.tensor(all_position_ids).unsqueeze(0),
        }
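
For context, this collate function is wired into a standard PyTorch DataLoader; a hypothetical example (the dataset name and batch size are placeholders):

    from torch.utils.data import DataLoader

    # batch_size only bounds how many samples the collate function may consider;
    # the token budget (max_batch_len) decides how many actually get packed.
    loader = DataLoader(dataset, batch_size=64, shuffle=True, collate_fn=pad_collate_fn_4d_mask)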

So I accumulated samples up to a maximum number of tokens and generated an independent causal mask for each sample inside the combined attention mask.

This DOES NOT WORK, though. For each individual sample, all the other tokens in the long concatenated sequence become analogous to padding: in a transformer, every operation runs over all tokens, and attention masking merely ignores the masked values rather than skipping the computation, so the compute is still wasted. As a result, this ended up being much slower than naive sampling.
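
To put rough numbers on the attention part of that waste (an illustrative count of attention-score entries, not a measurement):

    # Toy comparison, assuming k sequences of length n and a dense attention kernel
    # that does not skip masked blocks; the numbers are made up for illustration.
    k, n = 8, 512
    one_long_row  = (k * n) ** 2   # block-diagonal mask over a single packed row
    separate_rows = k * n ** 2     # the same k sequences batched as independent rows
    print(one_long_row / separate_rows)  # == k, i.e. 8x more score entries computed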

siyan-zhao commented 6 months ago

Hi Aldo,

Thank you for your interest! We agree and will add 4D mask support soon.

Regarding your second comment, are you profiling the training time? Our method aims to improve throughput for generation (with metrics such as prefilling time / TTFT). The figure you sent appears to be from the multipack documentation, which aims to improve training throughput instead.

The time savings come from replacing computation on padding tokens with computation on actual sentences. Even though the other sentences in the same row are treated as padding, as you mentioned, the overall computation is still reduced. Please let me know if I have understood your questions correctly and whether this addresses them.
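
As a toy illustration of that trade-off (lengths and bin assignment are invented for illustration, not taken from our benchmarks):

    # Pad-to-longest spends most slots on padding for short prompts; packing fills
    # those slots with real tokens instead (bins hand-picked here for illustration).
    lengths = [500, 60, 40, 20, 10, 10]
    padded_slots = len(lengths) * max(lengths)      # 6 rows padded to 500 -> 3000 slots
    packed_rows  = [[500], [60, 40, 20, 10, 10]]    # packed into 2 rows within a 500-token budget
    packed_slots = len(packed_rows) * max(lengths)  # 2 rows of length 500 -> 1000 slots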

aldopareja commented 6 months ago

Yes, overall computation should be reduced if you concatenate vertically (into separate rows) instead of naively putting everything into a single sequence! That's probably why I got such bad throughput myself.

I recently included this instead: https://github.com/imoneoi/multipack_sampler

It greatly reduces padding, but it biases the batch distribution relative to random sampling (although I'm still testing this claim).

I also included this, but it requires non-trivial modifications to the forward pass: https://huggingface.co/blog/mayank-mishra/padding-free-transformer

And yes, I'm talking about training efficiency, although inference efficiency should be almost analogous.

In any case, any thoughts are greatly appreciated; otherwise, my concerns are addressed!