CStanKonrad / long_llama

LongLLaMA is a large language model capable of handling long contexts. It is based on OpenLLaMA and fine-tuned with the Focused Transformer (FoT) method.
Apache License 2.0

Support for gradient_checkpointing #9

Open Richar-Du opened 1 year ago

Richar-Du commented 1 year ago

Thanks for your awesome work! There is a small problem: when I fine-tune long_llama with gradient_checkpointing, it raises an error (the screenshot of the error is not preserved here). Could you please update the code in transformers so that long_llama supports gradient_checkpointing? I think it would be useful for the community. @CStanKonrad

CStanKonrad commented 1 year ago

Hi, thanks for the request. In a recent commit, I have added initial support for gradient checkpointing (it just skips memory layers). At the time of writing, it is not yet in the Hugging Face repository, so to use it you can download the code from the src directory of this repository and write something like this:

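```python
import torch
from transformers import LlamaTokenizer

# modeling_longllama.py comes from the src directory of this repository; the
# relative import assumes this script sits in the same package as that file.
from .modeling_longllama import LongLlamaForCausalLM

MODEL_PATH = "syzymon/long_llama_3b"

tokenizer = LlamaTokenizer.from_pretrained(MODEL_PATH)
model = LongLlamaForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=torch.float32)
# Turn on gradient checkpointing (standard Hugging Face PreTrainedModel method).
model.gradient_checkpointing_enable()
```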
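The comment says the initial support "just skips memory layers". One plausible reading, not verified against the repository code, is that activation checkpointing is applied to ordinary decoder layers while the layers that access the memory cache run normally. The toy PyTorch sketch below illustrates that reading only. `ToyLayer`, `ToyStack`, `grads_with` and the `has_memory` flag are hypothetical names, not part of LongLLaMA; the only library call relied on is `torch.utils.checkpoint.checkpoint`.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyLayer(nn.Module):
    """Hypothetical stand-in for a decoder layer (not the real LongLLaMA layer)."""

    def __init__(self, dim: int, has_memory: bool = False):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))
        # Hypothetical flag marking a layer that would read from the memory cache.
        self.has_memory = has_memory

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.ff(x)


class ToyStack(nn.Module):
    def __init__(self, dim: int, n_layers: int, memory_layers: set):
        super().__init__()
        self.layers = nn.ModuleList(
            [ToyLayer(dim, has_memory=(i in memory_layers)) for i in range(n_layers)]
        )
        self.gradient_checkpointing = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            if self.gradient_checkpointing and self.training and not layer.has_memory:
                # Activations are not stored; they are recomputed during backward.
                x = checkpoint(layer, x, use_reentrant=False)
            else:
                # Memory layers (and every layer when checkpointing is off) run normally.
                x = layer(x)
        return x


def grads_with(model: ToyStack, x: torch.Tensor, use_checkpointing: bool):
    """Run forward/backward and return a copy of every parameter gradient."""
    model.train()
    model.zero_grad(set_to_none=True)
    model.gradient_checkpointing = use_checkpointing
    model(x).sum().backward()
    return [p.grad.detach().clone() for p in model.parameters()]


torch.manual_seed(0)
model = ToyStack(dim=16, n_layers=4, memory_layers={1})
x = torch.randn(2, 8, 16)
plain = grads_with(model, x, use_checkpointing=False)
ckpt = grads_with(model, x, use_checkpointing=True)
# Checkpointing trades extra compute for lower memory; gradients should not change.
print(all(torch.allclose(a, b) for a, b in zip(plain, ckpt)))
```

Because memory layers are excluded, their activations are still kept in memory, so the savings come only from the ordinary layers.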
Richar-Du commented 1 year ago

Thanks for your commit!

Now I would like to fine-tune LongLLaMA, but the sequences are too long and I get a CUDA out-of-memory error (4 × 80 GB GPUs). I wonder whether I could fine-tune LongLLaMA with a regular training framework that has no long-context support (e.g., the Alpaca or Vicuna training code). If not, could you please release the fine-tuning code for LongLLaMA?

CStanKonrad commented 11 months ago

I apologize for the late response. We have recently published code that allows fine-tuning the model on a single A100 80GB GPU. We use a total context size of 2048, with last_context_length set to 1024. For shorter inputs, we randomly decide how much of the data goes into memory; we achieve this by randomly padding the input.
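The exact padding scheme is not described here, so the following is only a sketch of one way random padding could control how much real data lands in memory. It assumes that everything before the final last_context_length positions of the 2048-token window is treated as the memory part, and that padding is split randomly between the left and right of the input. `random_pad` and `PAD_ID` are hypothetical names, not taken from the repository.

```python
import random
from typing import List, Tuple

TOTAL_CONTEXT = 2048        # total context size from the comment
LAST_CONTEXT_LENGTH = 1024  # last_context_length from the comment
PAD_ID = 0                  # hypothetical pad token id


def random_pad(
    tokens: List[int],
    rng: random.Random,
    total: int = TOTAL_CONTEXT,
    last_context_length: int = LAST_CONTEXT_LENGTH,
    pad_id: int = PAD_ID,
) -> Tuple[List[int], List[int], int]:
    """Pad `tokens` to `total` with a random left/right split (hypothetical helper).

    Returns (input_ids, attention_mask, n_real_in_memory), where the memory part
    is assumed to be every position before the final `last_context_length` ones.
    """
    n = len(tokens)
    if n > total:
        raise ValueError("input is longer than the total context")
    # More right padding shifts real tokens earlier, i.e. further into memory.
    right = rng.randint(0, total - n)
    left = total - n - right
    input_ids = [pad_id] * left + tokens + [pad_id] * right
    attention_mask = [0] * left + [1] * n + [0] * right
    memory_boundary = total - last_context_length
    n_real_in_memory = sum(attention_mask[:memory_boundary])
    return input_ids, attention_mask, n_real_in_memory
```

For example, with a 300-token input, repeated calls to `random_pad(tokens, random.Random(0))` place anywhere from 0 to 300 real tokens in the memory part, while a full 2048-token input always puts exactly 1024 tokens there.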

You can try the instruction+chat fine-tuned model in the Colab.

For the Colab model, we provide the fine-tuning config and a log of the training loss.