pytorch / torchtune

PyTorch native finetuning library
https://pytorch.org/torchtune/main/
BSD 3-Clause "New" or "Revised" License

[RFC] Long context fine tuning in torchtune #1244

Open felipemello1 opened 3 months ago

felipemello1 commented 3 months ago

Goal


This RFC aims to define which type of solution to prioritize so we can bridge the gap between the context lengths users want to finetune at and what torchtune currently supports.

What number do we hope to achieve, with which model and which GPU?

There are three ways I propose we think about it:

Sanity check - bsz, num_steps, memory required

Using llama2-7b (~half of the memory llama3-8b needs):

LongLora (https://arxiv.org/pdf/2309.12307) finetunes for 1000 steps (bsz=64) using 8x A100 with DDP.


YaRN paper (https://arxiv.org/pdf/2309.00071) does it for 400 steps (bsz=64, seq_len=64k, +200 steps for seq_len=128k)

Ring attention and Memory Efficient Attention (https://arxiv.org/pdf/2310.01889)


Why


[image] Source: https://www.reddit.com/r/LocalLLaMA/comments/18t63ub/is_context_length_32k_actually_useful_to_you/

Below are some use cases, but one argument we can make is that supporting long context will enable new use cases, even ones that don't exist today.

Possible data flavors:

Possible use cases:

What about RAG?

RAG works in some scenarios and is cheap, but:

However, long context and RAG are complementary approaches: with long context, retrieval can return longer documents instead of small paragraphs. In other words, long context enables new RAG capabilities.

Our current snapshot:


Currently we support at most ~24k tokens with Llama 3 8B on an A100-80GB, for LoRA, QLoRA, and FFT, with or without FSDP. After some memory optimizations we can probably increase this a little, but likely not enough to reach 32k.


You may think: shouldn't QLoRA support much longer context than FFT? Not by much: at higher context lengths the activations consume most of the memory, and quantizing the weights does not reduce activation memory.
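To make this concrete, here is a rough back-of-envelope estimate (my own simplifying assumptions, spelled out in the comments below, not numbers from the snapshot above): at ~24k context, activation memory alone can exceed what 4-bit weight quantization saves by a wide margin.

# Back-of-envelope sketch, not measurements from torchtune. Assumptions:
# Llama-3-8B-like shapes, bf16 activations (2 bytes), batch size 1,
# SDPA/flash attention (no seq^2 score matrix materialized),
# NO activation checkpointing, and only the major per-layer tensors counted
# (~4 hidden-sized tensors for attention/residual + ~3 MLP-intermediate-sized
# tensors for gate/up/SiLU).
seq, hidden, mlp_hidden, n_layers = 24_576, 4096, 14_336, 32
bytes_bf16 = 2

per_layer = seq * (4 * hidden + 3 * mlp_hidden) * bytes_bf16
activations_gb = per_layer * n_layers / 1e9
qlora_weight_savings_gb = 8e9 * (2 - 0.5) / 1e9   # bf16 weights -> ~4-bit weights

print(f"activations at ~24k context : ~{activations_gb:.0f} GB")
print(f"QLoRA weight savings        : ~{qlora_weight_savings_gb:.0f} GB")

Even with activation checkpointing cutting the activation term down substantially, it still grows with sequence length, while the weight savings stay fixed.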

Our levers


1) Inference solutions, such as:

(1) is not relevant, since our focus is finetuning. (2) is not widely adopted, and it is risky to add new modules that change the model's inference behavior. (3) We already support QLoRA and 8-bit optimizers; the issue is the quadratic growth of activations with sequence length, and such methods don't tackle it. (4) We rely on PyTorch for it. This leaves us with (5) and (6).

LongLora:


Pros:

Cons:

Pseudocode: [image]

Ring Attention

Pseudocode

import torch

def naive_softmax(x):
    return x.exp() / x.exp().sum()

x = torch.randn(16)

# done in vanilla attention
target = naive_softmax(x)

# softmax over chunks (each chunk is normalized only by its own sum)
x1, x2 = torch.chunk(x, 2)
softmax1 = naive_softmax(x1)
softmax2 = naive_softmax(x2)

# incremental softmax: renormalize each chunk by the global sum of exponentials
sum_exp_1 = x1.exp().sum()
sum_exp_2 = x2.exp().sum()

softmax1_corrected = softmax1 * sum_exp_1 / (sum_exp_1 + sum_exp_2)
softmax2_corrected = softmax2 * sum_exp_2 / (sum_exp_1 + sum_exp_2)

# the corrected chunks concatenate back into the full softmax
softmax_combined = torch.cat([softmax1_corrected, softmax2_corrected])
assert torch.allclose(target, softmax_combined)

[image] Source: https://www.youtube.com/watch?v=ws7angQYIxI

Pros:

Cons:

Code pointers: Pytorch: https://github.com/pytorch/pytorch/pull/129515 Torchtitan: https://github.com/pytorch/torchtitan/pull/433

Current constraints of the PyTorch implementation that they are actively working on:

Suggestion

My suggestion is to focus on the experimental ring attention implementation by PyTorch, since it scales well and will be supported by them. If their CP implementation is optimal, then 8x A100 would allow us to get to 8 x 24k = 192k context length. Most likely LongLora wouldn't be enough to reach 64k with llama3.

Timeline


According to our PyTorch PoC, the code to implement context parallelism is short, and can be done in a week. Therefore, I propose the following:

Week July 29th - Formalize script to, given a model, find max_seq_length for different settings, so we can test at 16GB, 24GB, 40GB, 80GB (see the sketch below)
Week Aug 5th - Understand pytorch CP implementation
Week Aug 11th - Proof-of-concept for llama3.1 8b, stress test max seq len
Week Aug 19th - Propose and discuss implementation in torchtune
Week Aug 26th - Evaluation on key long context datasets
Week Sept 2nd - Optimize + unit test (maybe do it before eval?)
Week Sept 9th - Implement for the rest of the models
Week Sept 16th - Tutorials + best practices
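For the first item, here is a rough sketch of what such a script could look like (a hypothetical helper; run_train_step and the search bounds are illustrative, not an existing torchtune utility): binary-search the sequence length and treat a CUDA OOM as the failure signal.

import torch

def find_max_seq_len(run_train_step, lo=1_024, hi=262_144, tol=1_024):
    # run_train_step(seq_len) is assumed to run one forward/backward pass
    # at that sequence length with the config under test.
    best = 0
    while hi - lo > tol:
        mid = (lo + hi) // 2
        try:
            run_train_step(mid)
            best, lo = mid, mid          # fits: search higher
        except torch.cuda.OutOfMemoryError:
            hi = mid                     # OOM: search lower
        finally:
            torch.cuda.empty_cache()
    return best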

The timeline should look similar for LongLora. There is a risk in focusing too much on llama 3.1 and ending up with something that needs refactoring to work for other models. We should try to minimize this risk and make it model-independent, but prioritizing one family of models will allow us to iterate faster, learn, and make adjustments later.

Evaluation


[image] Source: https://arxiv.org/pdf/2404.02060

Therefore, for evaluating our trained models I propose three strategies (text only):

As a starting point, we can use LongAlpaca (https://huggingface.co/datasets/Yukang/LongAlpaca-12k), since it was also used by LongLora. It contains 3k short examples and 9k long examples from the Alpaca dataset, ranging from 35 characters to 191k characters.

This strategy can be improved during the evaluation focus time.

joecummings commented 3 months ago

Same as above, but aim for 64k, and rely on RoPE scaling to extend context further

This confuses me as I consider RoPE scaling and these techniques to go hand-in-hand.

felipemello1 commented 3 months ago

Right! This is more about our "success criteria": do we want to finetune at 128k max, or would we be comfortable getting to 64k max, with 128k achieved through scaling?
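For concreteness, "scaling" here refers to something like RoPE position interpolation; below is a minimal sketch under that assumption (illustrative names, not torchtune's or Llama's API): positions are compressed by a scale factor so a model finetuned at 64k can be run at roughly 64k * scale.

import torch

def rope_angles(head_dim, max_seq_len, base=10_000.0, scale=1.0):
    # scale > 1 compresses positions so that a model finetuned at, say, 64k
    # can attend over 64k * scale positions (e.g. scale=2.0 -> 128k),
    # usually at some quality cost.
    inv_freq = 1.0 / base ** (torch.arange(0, head_dim, 2).float() / head_dim)
    positions = torch.arange(max_seq_len).float() / scale
    angles = torch.outer(positions, inv_freq)    # [max_seq_len, head_dim // 2]
    return angles.cos(), angles.sin()

cos, sin = rope_angles(head_dim=128, max_seq_len=131_072, scale=2.0)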

joecummings commented 3 months ago

reddit.com/r/LocalLLaMA/comments/18t63ub/is_context_length_32k_actually_useful_to_you

This is awesome motivation :)

joecummings commented 3 months ago

Right! This is more about our "success criteria": do we want to finetune at 128k max, or would we be comfortable getting to 64k max, with 128k achieved through scaling?

Ahhh I see, I see. Personally, I think the latter is what I consider success.

joecummings commented 3 months ago

Albania Wikipedia Page: 43k tokens

I was just thinking about how I desperately want to finetune a model on the Albania Wikipedia page...

joecummings commented 3 months ago

Maybe a dumb question, but Ring Attention really only helps us improve context length based on num of GPUs...so a single GPU QLoRA would never be able to get to 128k under this plan. Is that correct? Are we just saying we don't have the BW to take on single GPU long context? If so, I want this to be explicit and called out here.

joecummings commented 3 months ago

How does this RFC (#1183) on RoPE scaling fit in here? Again, in my mind they are related, but it doesn't seem like this is accounted for in your work timeline.

felipemello1 commented 3 months ago

so a single GPU QLoRA would never be able to get to 128k under this plan. Is that correct?

Correct, probably for the same reason we don't have 405B on a single GPU. It is possible to do long(er) context on a single GPU using the same ring-attention concept, but in series (you chunk QKV and compute it in blocks in a for loop). This is "memory-efficient attention": https://arxiv.org/abs/2112.05682

implementation: https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/d54f391370ecbf843a871f0e260425d076995550/memory_efficient_attention_pytorch/memory_efficient_attention.py#L119

However, compute still grows quadratically with sequence length (only the memory becomes linear), so it is slow, and most likely still not enough to get to 128k.
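For illustration, here is a minimal single-GPU sketch of that idea (my own simplified version, not the lucidrains code linked above): iterate over key/value chunks and keep a running, numerically stable softmax, so the full seq x seq score matrix is never materialized at once.

import torch

def chunked_attention(q, k, v, chunk_size=256):
    # q, k, v: [seq_len, head_dim]; single head, no causal mask, for brevity
    scale = q.shape[-1] ** -0.5
    out = torch.zeros_like(q)
    running_max = torch.full((q.shape[0], 1), float("-inf"))
    running_sum = torch.zeros(q.shape[0], 1)

    for k_chunk, v_chunk in zip(k.split(chunk_size), v.split(chunk_size)):
        scores = (q @ k_chunk.T) * scale                     # [seq, chunk]
        chunk_max = scores.max(dim=-1, keepdim=True).values
        new_max = torch.maximum(running_max, chunk_max)
        corr = (running_max - new_max).exp()                 # rescale old state
        exp_scores = (scores - new_max).exp()
        out = out * corr + exp_scores @ v_chunk
        running_sum = running_sum * corr + exp_scores.sum(-1, keepdim=True)
        running_max = new_max

    return out / running_sum

q, k, v = (torch.randn(2048, 64) for _ in range(3))
ref = torch.softmax((q @ k.T) * 64 ** -0.5, dim=-1) @ v
assert torch.allclose(chunked_attention(q, k, v), ref, atol=1e-4)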

felipemello1 commented 3 months ago

How does this RFC (#1183) on RoPE scaling fit in here? Again, in my mind they are related, but it doesn't seem like this is accounted for in your work timeline.

Good question! Since we released llama 3.1, which supports 128k, I thought this was a solved problem and that that RFC was more about design discussion. Are you still going to implement it? If so, would it be ready before "Week Sept 9th - Implement for the rest of the models" in the timeline?

ebsmothers commented 3 months ago

Thanks for the detailed RFC! Overall I agree with the proposal to start trying out context parallel support using the latest APIs from distributed. I do have two main comments on a first pass (though I reserve the right to come back later with more 😅 ).

(1) Sequencing: I would like to get to "Proof-of-concept for llama3.1 8b, stress test max seq len" as fast as possible. This is most likely a prerequisite for any asks we have to distributed or any other memory-focused work we need to do. So I would suggest frontloading this portion of the work; imo the script and design discussions are moot until we have that. Even understanding the distributed CP implementation can be done in parallel.

(2) We discussed this offline a bit already, but I would love to see an effort to increase the max context length on a single device as well. 24k is not great imo so we should explore some of the other techniques you described (along with other memory-saving techniques) in parallel. This is not necessarily something you need to consider as part of the context parallelism effort so nbd for this particular RFC, just want to call it out.

felipemello1 commented 3 months ago

So I would suggest frontloading this portion of the work

sounds good, I changed it to go before the design

RdoubleA commented 3 months ago

So is my understanding correct: for long context, ring attention / CP is the most promising but will require X number of GPUs, and if you want to use a single GPU you need to rely on RoPE scaling?

Also, do you have an idea of how ring attention will be called in torchtune? Does it happen under the hood with SDPA if we're using DTensors (based on my brief skim through the code snippets)?

felipemello1 commented 3 months ago

for long context, ring attention / CP is the most promising but will require X number of GPUs. Whatever context length you can fit on one GPU, you can multiply it by n_gpus.

Correct

if you want to use a single GPU you need to rely on RoPE scaling

It doesn't fit in memory. It's like trying to fit 405B on a single GPU: it's just too much, unless you do a lot of CPU/GPU wizardry, but then it's slow.

Also, do you have an idea of how ring attention will be called in torchtune? Does it happen under the hood with SDPA if we're using DTensors (based on my brief skim through the code snippets)?

I will be able to answer it better next week

@RdoubleA

thusinh1969 commented 1 month ago

Any movement here? We truly need long context, say 128k. Any enhancement would be highly appreciated. We tend to use Llama-3.1-8B.

Thanks, Steve

felipemello1 commented 1 month ago

@thusinh1969, the library is much more memory efficient now. If you use all the memory flags, you should be able to get there, depending on how much memory and how many GPUs you have.

keep the loss as chunked cross entropy
compile=True
enable_activation_checkpointing=True
tokenizer.max_seq_len=your_seq_len
dataset.packed=True # makes it faster

Then, if you are training LoRA, you can also use activation offloading (enable_activation_offloading=True). I still need to implement it for the full finetune recipes.

If you are doing distributed training, you can also use the FSDP CPU offloading option in the config.
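For context, the "chunked cross entropy" mentioned above works roughly like the sketch below (a simplified illustration, not torchtune's actual implementation, which has more tricks around compile and dtype upcasting): the loss is computed one sequence chunk at a time, so the full [seq_len, vocab_size] logit tensor never has to be materialized in one piece.

import torch
import torch.nn.functional as F

def chunked_ce(hidden, output_proj, labels, num_chunks=8):
    # hidden: [seq_len, dim], labels: [seq_len] (long); output_proj maps dim -> vocab
    total, count = hidden.new_zeros(()), 0
    for h, y in zip(hidden.chunk(num_chunks), labels.chunk(num_chunks)):
        logits = output_proj(h)                              # [chunk, vocab]
        total = total + F.cross_entropy(logits, y, reduction="sum")
        count += y.numel()
    return total / count

proj = torch.nn.Linear(512, 32_000, bias=False)
loss = chunked_ce(torch.randn(2048, 512), proj, torch.randint(0, 32_000, (2048,)))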

krammnic commented 1 month ago

Is this still active? I don't see any related PRs, but the things proposed here are pretty reasonable.

felipemello1 commented 1 month ago

hey @krammnic, not really. We ended up going down the path of memory optimization instead: we implemented activation offloading, chunked cross entropy, and better compile support. This saved tons of memory, allowing larger context. In the future we may come back to parallelism.

lulmer commented 1 day ago

Hey @felipemello1, very interesting issue with a lot of material, thank you for taking the time to write it. I ended up here because I am also an unsloth user and I was wondering how context length is handled in torchtune, as I never provide info about MAX_SEQ_LEN.

What is the state of it in December 2024? Do you need help with anything?

felipemello1 commented 20 hours ago

hey @lulmer , glad to hear you are interested!

I don't have the exact number, but depending on the model size and tokenizer.max_seq_len, I believe you can get up to 128k on a single device (80GiB) using all of our flags. You can see them here: https://github.com/pytorch/torchtune#optimization-flags

Regarding needing help, we always love when people contribute! We don't have any easy / low-hanging fruit for long context specifically, but we have many tickets marked as "community help wanted" and "good first issue". If you would like to get used to the codebase by submitting some PRs, then I am sure more complex projects will appear. Let me know if you have any in mind. :)

felipemello1 commented 14 hours ago

@lulmer, if you are interested in exploring LongLora (or another technique), please feel free to open an RFC with some naive code snippets. We can review it, give some pointers, and if it makes sense, it would be great to have you as a contributor.