Lightning-AI / litgpt

20+ high-performance LLMs with recipes to pretrain, finetune and deploy at scale.
https://lightning.ai
Apache License 2.0

LongLora fine-tuning support #1237

Open belerico opened 5 months ago

belerico commented 5 months ago

LongLora is "an efficient fine-tuning approach that extends the context sizes of pre-trained large language models". They propose to fine-tune a model with a sparse local attention while maintaining dense attention during inference. The Shifted-Sparse Attention (S^2-Attn) is depicted in the following (from the paper):

[Figure: illustration of Shifted-Sparse Attention (S^2-Attn) from the LongLoRA paper]

Moreover, the implied modification is relatively simple:

# B: batch size;
# N: sequence length (number of tokens);
# G: group size;
# H: number of attention heads;
# D: dimension of each attention head
# qkv in shape (B, N, 3, H, D): projected queries, keys, and values
# (cat is torch.cat; N must be a multiple of G and H must be even)

# Key line 1: split qkv on H into 2 chunks, and shift the second chunk by G/2 on N
qkv = cat((qkv.chunk(2, 3)[0], qkv.chunk(2, 3)[1].roll(-G // 2, 1)), 3).view(B * N // G, G, 3, H, D)

# standard self-attention function, applied within each group
out = self_attn(qkv)

# out in shape (B, N, H, D), after reshaping back from (B*N/G, G, H, D)
# Key line 2: split out on H into 2 chunks, and then roll the second chunk back by G/2 on N
out = cat((out.chunk(2, 2)[0], out.chunk(2, 2)[1].roll(G // 2, 1)), 2)

This only needs to be enabled during the fine-tuning phase; standard dense attention can still be used during inference.
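For concreteness, here is a minimal, self-contained PyTorch sketch of the same idea; the function name, the (B, N, 3, H, D) layout, and the use of scaled_dot_product_attention are my own choices for illustration, not the actual implementation (causal masking and attention dropout are omitted for brevity):

import torch
import torch.nn.functional as F

def shifted_sparse_attention(qkv: torch.Tensor, group_size: int) -> torch.Tensor:
    # qkv: (B, N, 3, H, D); N must be a multiple of group_size, H must be even
    B, N, _, H, D = qkv.shape
    G = group_size

    # 1) Shift the second half of the heads by -G/2 tokens so that groups overlap
    half1, half2 = qkv.chunk(2, dim=3)
    qkv = torch.cat((half1, half2.roll(-G // 2, dims=1)), dim=3)

    # 2) Fold the sequence into groups of size G and attend within each group
    qkv = qkv.view(B * N // G, G, 3, H, D)
    q, k, v = (t.squeeze(2).transpose(1, 2) for t in qkv.chunk(3, dim=2))  # (B*N/G, H, G, D)
    out = F.scaled_dot_product_attention(q, k, v)  # dense attention inside each group
    out = out.transpose(1, 2).reshape(B, N, H, D)

    # 3) Roll the shifted half of the heads back
    out1, out2 = out.chunk(2, dim=2)
    return torch.cat((out1, out2.roll(G // 2, dims=1)), dim=2)

# Smoke test: 2 sequences of 16 tokens, groups of 4 tokens, 8 heads of dim 4
out = shifted_sparse_attention(torch.randn(2, 16, 3, 8, 4), group_size=4)
print(out.shape)  # torch.Size([2, 16, 8, 4])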

Another thing that should be modified is the padded sequence length, which should be a multiple of the group-size.
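For example, a minimal helper that rounds the padded length up to the next multiple of the group size could look like this (pad_to_multiple is just illustrative, not an existing litgpt function):

def pad_to_multiple(seq_len: int, group_size: int) -> int:
    # Round the padded sequence length up to the next multiple of the group size
    return ((seq_len + group_size - 1) // group_size) * group_size

print(pad_to_multiple(1000, 256))  # 1024: the batch would get 24 extra padding tokens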

If you think that this can be added to litgpt, I'm willing to contribute a PR (I already have something working which I plan to test).

Edit:

I forgot to mention that they also use Position Interpolation to rescale the position indices. If I'm not mistaken, this can be achieved by simply changing the rope_condense_ratio to account for the increased context size.
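To illustrate what I mean, here is a simplified, stand-alone version of how a condense ratio enters a RoPE cache (this is just a sketch, not litgpt's actual build_rope_cache, and the 4096 -> 8192 numbers are only an example):

import torch

def build_rope_cache(seq_len: int, head_dim: int, base: int = 10000, condense_ratio: int = 1):
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    # Position Interpolation: divide the position indices by the condense ratio,
    # so the extended positions are squeezed back into the range seen during pre-training
    positions = torch.arange(seq_len).float() / condense_ratio
    angles = torch.outer(positions, inv_freq)
    return torch.cos(angles), torch.sin(angles)

# Extending a 4096-token model to 8192 tokens -> condense ratio of 2
cos, sin = build_rope_cache(seq_len=8192, head_dim=128, condense_ratio=8192 // 4096)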

belerico commented 5 months ago

I've put together something here.

To further reduce the memory consumption I've also added the possibility to remove the last n layers of the model, as described in "The Unreasonable Ineffectiveness of the Deeper Layers", Sec. 4.4.
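Roughly, the idea is something like the following hypothetical helper (it assumes a GPT-style model whose blocks live in model.transformer.h, as in litgpt, and remove_perc would correspond to the --train.remove_last_perc_layers flag in my branch):

import torch.nn as nn

def remove_last_layers(model, remove_perc: float) -> None:
    # Drop the last `remove_perc` fraction of transformer blocks in place
    blocks = model.transformer.h  # nn.ModuleList of transformer blocks
    n_keep = len(blocks) - int(len(blocks) * remove_perc)
    model.transformer.h = nn.ModuleList(blocks[:n_keep])
    model.config.n_layer = n_keep  # keep the config in sync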

I've trained a model on a Lightning Studio on an A10G with the following hyperparameters:

litgpt finetune lora \
--config "/teamspace/studios/this_studio/litgpt/config_hub/finetune/mistral-7b/lora.yaml" \
--lora_r 16 \
--lora_dropout 0 \
--lora_query true \
--lora_key true \
--lora_value true \
--lora_projection true \
--longlora_n_groups 4 \
--longlora_context_length 8192 \
--logger_name "tensorboard" \
--data.pad_multiple_of 4 \
--checkpoint_dir "/teamspace/studios/this_studio/checkpoints/mistralai/Mistral-7B-Instruct-v0.1/" \
--train.micro_batch_size 1 \
--train.max_seq_length 8192 \
--train.remove_last_perc_layers 0.0 \
--train.get_longest_seq_length true \
--train.trainable_params "wte,norm_" \
--precision "bf16-true"

With those settings I have:

Here are two generations:

Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
Recommend a movie to watch on the weekend.

### Response:
One great movie to watch on the weekend is "The Shawshank Redemption". It's a classic drama with a compelling story, great acting, and a positive message.

Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
Recommend a movie to watch on the weekend.

### Response:
One great movie to watch on the weekend is "The Shawshank Redemption". It's a timeless classic with a compelling storyline and excellent performances.

rasbt commented 5 months ago

Thanks for sharing and writing up this thorough description. I saw the paper a few months back but must admit that I didn't have time to read it.

Btw, I am all for supporting interesting research techniques that help with real & common issues (e.g., high memory requirements, LLMs not being able to handle long contexts, etc.)

In general, something I am wondering about is whether it's really LoRA-specific, or whether it could also be used with "full"-parameter finetuning?

The --train.remove_last_perc_layers option is also a nice-to-have. I would apply it in a separate PR, and I think it's useful to have.

What do you think @awaelchli @carmocca ?

belerico commented 5 months ago

Hi @rasbt,

In general, something I am wondering about is whether it's really LoRA-specific, or whether it could also be used with "full"-parameter finetuning?

Even though in the paper they specifically designed everything for fine-tuning with LoRA, it's something I also thought about. The concern I have is the flow of information between the first and the last token, which is mitigated during fine-tuning since the pre-training has already been done on a ton of data. The authors ablate this in Appendix B.3 and find that it doesn't hurt the fine-tuning. Maybe it could be applied during pre-training by adopting Variant 2 from B.3, where they use a separate group for the shifted tokens?

The --train.remove_last_perc_layers option is also a nice-to-have. I would apply it in a separate PR, and I think it's useful to have.

Sure

belerico commented 5 months ago

Hi guys, I'm catching up here. I've spotted a little bug due to a missing reshape, and I've also implemented LongLoRA for the full finetune. If it's OK with you, I'll open a PR.

cc @rasbt @carmocca @awaelchli

rasbt commented 5 months ago

To me both would be welcome and valuable contributions :). I would maybe do them in separate PRs, as that would make the code review a bit easier.