pytorch / torchtune

A Native-PyTorch Library for LLM Fine-tuning
https://pytorch.org/torchtune/main/

[RFC] Long context fine tuning in torchtune #1244

Open felipemello1 opened 1 month ago

felipemello1 commented 1 month ago

Goal


This RFC aims to define which type of solution we should prioritize so we can bridge the gap between the context lengths torchtune currently supports (~24k on an A100-80GB) and the much longer contexts (64k-128k) users are asking for.

What number do we hope to achieve, with which model and which GPU?

There are three ways I propose we think about it:

Sanity check - bsz, num_steps, memory required

Using llama2-7b (~half of the memory llama3-8b needs):

LongLora (https://arxiv.org/pdf/2309.12307) finetunes for 1000 steps (bsz=64) using 8xA100 with DDP.


YaRN paper (https://arxiv.org/pdf/2309.00071) does it for 400 steps (bsz=64, seq_len=64k), plus 200 more steps at seq_len=128k (rough token-budget arithmetic below).

Ring attention and Memory Efficient Attention (https://arxiv.org/pdf/2310.01889)

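For a rough sense of scale, the training-token budgets implied by the YaRN numbers above (worked arithmetic):

# total tokens per stage = steps x batch size x sequence length
stage_64k  = 400 * 64 * 64 * 1024    # ~1.7B tokens for the 64k stage
stage_128k = 200 * 64 * 128 * 1024   # ~1.7B tokens for the 128k stage
print(f"{stage_64k / 1e9:.1f}B tokens, {stage_128k / 1e9:.1f}B tokens")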

Why


(Figure) Source: https://www.reddit.com/r/LocalLLaMA/comments/18t63ub/is_context_length_32k_actually_useful_to_you/

Below are some use cases, but one argument we can make is that supporting long context will enable new use cases, even if they don't exist today.

Possible data flavors:

Possible use cases:

What about RAG?

RAG works in some scenarios and is cheap, but:

However, long context and RAG can be complementary approaches: with long context, people can return longer documents instead of small paragraphs. In other words, long context enables new RAG capabilities.

Our current snapshot:


Currently we support at most ~24k tokens with llama3-8b, on an A100-80GB, for LoRA, QLoRA, and FFT, with or without FSDP. After some memory optimizations we can probably increase it a little, but probably not enough for 32k.


You may think: shouldn't QLoRA support a much longer context than FFT? Not by much: at higher context lengths, activations consume most of the memory, and quantizing the weights does little about that. A rough back-of-envelope below illustrates why.
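Back-of-envelope, assuming bf16 activations and Llama3-8B dimensions (32 layers, hidden_dim=4096, MLP intermediate_dim=14336); attention-score matrices are ignored since flash/SDPA kernels avoid materializing them:

seq_len = 24_000
bytes_per_elem = 2  # bf16

hidden_state   = seq_len * 4096  * bytes_per_elem   # ~0.20 GB per saved tensor
mlp_activation = seq_len * 14336 * bytes_per_elem   # ~0.69 GB per saved tensor

# Even a few saved tensors per layer, times 32 layers, is tens of GB of
# activations kept for the backward pass -- independent of whether the
# weights are quantized (QLoRA) or fully fine-tuned.
print(f"{hidden_state / 1e9:.2f} GB, {mlp_activation / 1e9:.2f} GB")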

Our levers


1) Inference solutions, such as:

(1) is not relevant, since our focus is finetuning. (2) is not widely adopted, and it is risky to add new modules that change the model's inference behavior. (3) We already support QLoRA and 8-bit optimizers; the issue is the quadratic growth of activations with sequence length, which such methods don't tackle. (4) We rely on PyTorch for it. This leaves us with (5) and (6).

LongLora:


Pros:

Cons:

Pseudocode: (figure)
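Roughly, the trick is shifted sparse attention (S²-Attn): attention is computed within fixed-size groups of tokens, and half of the attention heads are shifted by half a group so information flows between neighboring groups. Below is a simplified, runnable sketch (non-causal, and omitting the special mask for the group that wraps around the sequence boundary); it is an illustration, not torchtune code.

import torch
import torch.nn.functional as F

def s2_attn(q, k, v, group_size):
    # q, k, v: (batch, heads, seq, head_dim)
    B, H, S, D = q.shape
    G = group_size

    def shift(x, direction):
        # shift the second half of the heads by half a group along the sequence
        x = x.clone()
        x[:, H // 2:] = x[:, H // 2:].roll(direction * G // 2, dims=2)
        return x

    def group(x):
        # (B, H, S, D) -> (B * S/G, H, G, D): fold groups into the batch dim
        return x.view(B, H, S // G, G, D).permute(0, 2, 1, 3, 4).reshape(-1, H, G, D)

    q, k, v = (group(shift(t, -1)) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v)      # exact attention per group
    out = out.reshape(B, S // G, H, G, D).permute(0, 2, 1, 3, 4).reshape(B, H, S, D)
    return shift(out, +1)                               # shift the half-heads back

q, k, v = (torch.randn(2, 8, 4096, 64) for _ in range(3))
out = s2_attn(q, k, v, group_size=1024)   # (2, 8, 4096, 64)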

Ring Attention

Pseudocode

import torch

def naive_softmax(x):
    # numerically unstable, but fine for illustrating the idea
    return x.exp() / x.exp().sum()

x = torch.randn(8)

# done in vanilla attention
target = naive_softmax(x)

# softmax over chunks
x1, x2 = torch.chunk(x, 2)
softmax1 = naive_softmax(x1)
softmax2 = naive_softmax(x2)

# incremental softmax: rescale each chunk by its share of the total normalizer
sum_exp_1 = x1.exp().sum()
sum_exp_2 = x2.exp().sum()

softmax1_corrected = softmax1 * sum_exp_1 / (sum_exp_1 + sum_exp_2)
softmax2_corrected = softmax2 * sum_exp_2 / (sum_exp_1 + sum_exp_2)

softmax_combined = torch.cat([softmax1_corrected, softmax2_corrected])
assert torch.allclose(target, softmax_combined)

(Figure) Source: https://www.youtube.com/watch?v=ws7angQYIxI
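To make the ring part concrete, here is a toy single-process simulation (my own sketch, not the PyTorch/torchtitan CP API): each "rank" owns one Q chunk and one K/V chunk, the K/V blocks rotate around the ring, and each rank folds them into its output with a numerically stable version of the running softmax above.

import torch

def ring_attention_sim(q, k, v, world_size=4):
    # q, k, v: (seq, dim). Rank r owns q_chunks[r] and kv_chunks[r]; at each
    # step it "receives" the next K/V block around the ring and folds it into
    # a running (max, sum, output) accumulator -- an online softmax.
    scale = q.shape[-1] ** -0.5
    q_chunks = q.chunk(world_size)
    kv_chunks = list(zip(k.chunk(world_size), v.chunk(world_size)))

    outputs = []
    for rank, q_local in enumerate(q_chunks):
        out = torch.zeros_like(q_local)
        row_max = torch.full((q_local.shape[0], 1), float("-inf"))
        row_sum = torch.zeros(q_local.shape[0], 1)
        for step in range(world_size):
            k_blk, v_blk = kv_chunks[(rank + step) % world_size]  # ring "recv"
            scores = (q_local @ k_blk.T) * scale
            new_max = torch.maximum(row_max, scores.amax(dim=-1, keepdim=True))
            correction = (row_max - new_max).exp()   # rescale old accumulators
            probs = (scores - new_max).exp()
            out = out * correction + probs @ v_blk
            row_sum = row_sum * correction + probs.sum(dim=-1, keepdim=True)
            row_max = new_max
        outputs.append(out / row_sum)
    return torch.cat(outputs)

q, k, v = (torch.randn(2048, 64) for _ in range(3))
ref = torch.softmax(q @ k.T / 64 ** 0.5, dim=-1) @ v
assert torch.allclose(ring_attention_sim(q, k, v), ref, atol=1e-4)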

Pros:

Cons:

Code pointers:
PyTorch: https://github.com/pytorch/pytorch/pull/129515
torchtitan: https://github.com/pytorch/torchtitan/pull/433

Current constraints of the PyTorch implementation that they are actively working on:

Suggestion

My suggestion is to focus on the experimental ring attention implementation from PyTorch, since it scales well and will be supported by them. If their CP implementation scales linearly, 8xA100 would allow us to reach 8x24k = 192k context length. Most likely LongLora wouldn't be enough to reach 64k with llama3.

Timeline


According to our PyTorch PoC, the code to implement context parallelism is short, and can be done in a week. Therefore, I propose the following:

Week July 29th - Formalize a script that, given a model, finds max_seq_length for different settings, so we can test at 16GB, 24GB, 40GB, 80GB (sketch below)
Week Aug 5th - Understand the PyTorch CP implementation
Week Aug 11th - Proof-of-concept for llama3.1 8b, stress test max seq len
Week Aug 19th - Propose and discuss implementation in torchtune
Week Aug 26th - Evaluation on key long context datasets
Week Sept 2nd - Optimize + unit test (maybe do it before eval?)
Week Sept 9th - Implement for the rest of the models
Week Sept 16th - Tutorials + best practices
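For the first item, a minimal sketch of what such a script could look like. Here `try_step` is a hypothetical callable (not an existing torchtune utility) that runs one forward+backward step at the given sequence length.

import torch

def find_max_seq_len(try_step, low=1_024, high=262_144):
    # Binary-search the longest sequence length for which one training step
    # fits in GPU memory. `try_step(seq_len)` runs a single forward+backward.
    best = 0
    while low <= high:
        mid = (low + high) // 2
        try:
            try_step(mid)
            best, low = mid, mid + 1      # fits: search longer lengths
        except torch.cuda.OutOfMemoryError:
            high = mid - 1                # OOM: search shorter lengths
        finally:
            torch.cuda.empty_cache()
    return best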

The timeline should look similar for LongLora. There is a risk in focusing too much on llama 3.1 and ending up with something that needs refactoring to work for other models. We should try to minimize this risk and keep the implementation model-independent, but prioritizing one family of models will allow us to iterate faster, learn, and make adjustments later.

Evaluation


(Figure) Source: https://arxiv.org/pdf/2404.02060

Therefore, for evaluating our trained models I propose three strategies (text only):

As a starter, we can use LongAlpaca (https://huggingface.co/datasets/Yukang/LongAlpaca-12k), since it was also used by LongLora. It contains 3k short examples and 9k long examples built on the Alpaca dataset, ranging from 35 characters to 191k characters.

This strategy can be improved during the evaluation focus time.
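For profiling sequence lengths in that dataset, something like the snippet below could work. Note the column names are an assumption (standard Alpaca-style "instruction"/"output" fields); adjust to the actual schema.

from datasets import load_dataset

# Assumes Alpaca-style "instruction"/"output" columns; adjust if the schema differs.
ds = load_dataset("Yukang/LongAlpaca-12k", split="train")
char_lens = sorted(len(ex["instruction"]) + len(ex.get("output", "")) for ex in ds)
print(char_lens[0], char_lens[len(char_lens) // 2], char_lens[-1])  # min / median / max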

joecummings commented 1 month ago

Same as above, but aim for 64k, and rely on RoPE scaling to extend context further

This confuses me as I consider RoPE scaling and these techniques to go hand-in-hand.

felipemello1 commented 1 month ago

Right! This is more about our "success criteria": do we want to finetune to 128k max, or would we be comfortable getting to 64k max, with 128k achieved through RoPE scaling?

joecummings commented 1 month ago

reddit.com/r/LocalLLaMA/comments/18t63ub/is_context_length_32k_actually_useful_to_you

This is awesome motivation :)

joecummings commented 1 month ago

Right! This is more about our "success criteria": do we want to finetune to 128k max, or would we be comfortable getting to 64k max, with 128k achieved through RoPE scaling?

Ahhh I see, I see. Personally, I think the latter is what I consider success.

joecummings commented 1 month ago

Albania Wikipedia Page: 43k tokens

I was just thinking about how I desperately want to finetune a model on the Albania Wikipedia page...

joecummings commented 1 month ago

Maybe a dumb question, but Ring Attention really only helps us improve context length based on num of GPUs...so a single GPU QLoRA would never be able to get to 128k under this plan. Is that correct? Are we just saying we don't have the BW to take on single GPU long context? If so, I want this to be explicit and called out here.

joecummings commented 1 month ago

How does this RFC (#1183) on RoPE scaling fit in here? Again, in my mind they are related, but it doesn't seem like this is accounted for in your work timeline.

felipemello1 commented 1 month ago

so a single GPU QLoRA would never be able to get to 128k under this plan. Is that correct?

Correct, probably for the same reason we don't have 405B for a single GPU. It is possible to do long(er) context on a single GPU using the same ring-attention concept, but in series (you chunk Q/K/V and compute attention in blocks in a for loop). That is "memory efficient attention": https://arxiv.org/abs/2112.05682

implementation: https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/d54f391370ecbf843a871f0e260425d076995550/memory_efficient_attention_pytorch/memory_efficient_attention.py#L119

However, it is quadratically slower, and most likely still not enough to get to 128k.
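For intuition, here is a naive sketch of that serial, single-GPU idea: loop over K/V blocks and combine partial results with the same softmax trick as in the ring attention pseudocode above. It deliberately skips the running-max bookkeeping that real implementations (like the one linked) use for numerical stability.

import torch

def blockwise_attention(q, k, v, block=1024):
    # q, k, v: (seq, dim). Attention over K/V blocks in a for loop, never
    # materializing the full (seq x seq) score matrix.
    scale = q.shape[-1] ** -0.5
    num = torch.zeros_like(q)            # running numerator: sum_j exp(s_j) @ v_j
    den = torch.zeros(q.shape[0], 1)     # running denominator: sum_j exp(s_j)
    for k_blk, v_blk in zip(k.split(block), v.split(block)):
        w = ((q @ k_blk.T) * scale).exp()
        num += w @ v_blk
        den += w.sum(dim=-1, keepdim=True)
    return num / den

q, k, v = (torch.randn(4096, 64) for _ in range(3))
ref = torch.softmax(q @ k.T / 64 ** 0.5, dim=-1) @ v
assert torch.allclose(blockwise_attention(q, k, v), ref, atol=1e-4)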

felipemello1 commented 1 month ago

How does this RFC (#1183) on RoPE scaling fit in here? Again, in my mind they are related, but it doesn't seem like this is accounted for in your work timeline.

Good question! Since we released llama 3.1 with 128k support, I thought this was a solved problem and that the RFC was more about design discussion. Are you still going to implement it? If so, would it be ready before "Week Sept 9th - Implement for the rest of the models" in the timeline?

ebsmothers commented 1 month ago

Thanks for the detailed RFC! Overall I agree with the proposal to start trying out context parallel support using the latest APIs from distributed. I do have two main comments on a first pass (though I reserve the right to come back later with more 😅 ).

(1) Sequencing: I would like to get to "Proof-of-concept for llama3.1 8b, stress test max seq len" as fast as possible. This is most likely a prerequisite for any asks we have for distributed or for any other memory-focused work we need to do. So I would suggest frontloading this portion of the work; imo the script and design discussions are moot until we have that. Even understanding the distributed CP implementation can be done in parallel.

(2) We discussed this offline a bit already, but I would love to see an effort to increase the max context length on a single device as well. 24k is not great imo so we should explore some of the other techniques you described (along with other memory-saving techniques) in parallel. This is not necessarily something you need to consider as part of the context parallelism effort so nbd for this particular RFC, just want to call it out.

felipemello1 commented 1 month ago

So I would suggest frontloading this portion of the work

sounds good, I changed it to go before the design

RdoubleA commented 1 month ago

So is my understanding correct: for long context, ring attention / CP is the most promising but will require X number of GPUs, and if you want to use a single GPU you need to rely on RoPE scaling?

Also, do you have an idea of how ring attention will be called in torchtune? Does it happen under the hood with SDPA if we're using DTensors (based on my brief skim through the code snippets)?

felipemello1 commented 1 month ago

for long context, ring attention / CP is the most promising but will require X number of GPUs

Correct. Whatever context length you can fit on one GPU, you can multiply it by n_gpus.

if you want to use a single GPU you need to rely on RoPE scaling

It doesn't fit in memory. It's like trying to fit 405B on a single GPU: it's just too much, unless you do a lot of CPU/GPU wizardry, but then it's slow.

Also, do you have an idea of how ring attention will be called in torchtune? Does it happen under the hood with SDPA if we're using DTensors (based on my brief skim through the code snippets)?

I will be able to answer it better next week

@RdoubleA