CStanKonrad / long_llama

LongLLaMA is a large language model capable of handling long contexts. It is based on OpenLLaMA and fine-tuned with the Focused Transformer (FoT) method.
Apache License 2.0

Need clarification on token limit of input used for fine-tuning #20

Open lokesh-iterate opened 1 year ago

lokesh-iterate commented 1 year ago

Hi, I am going through the page https://huggingface.co/syzymon/long_llama_code_7b_instruct. I found the text "All inputs were truncated and randomly padded (left/right) to 3072 tokens" under Training. Is there a reason behind this truncation? I have noticed that, when creating the instruct version from the LongLLaMA base model, the context lengths used for fine-tuning are significantly smaller than the context length supported at inference time. I would like this clarified because I have prepared a dataset in a format similar to the jsonl file in this link: https://github.com/chrishayuk/opl-train/blob/main/JSONL/train.jsonl. Here each line corresponds to one input. Several of the inputs in my custom jsonl file are long, around 15K tokens. If I follow the fine-tuning script you provided here, https://github.com/CStanKonrad/long_llama/tree/main/instruction_fine_tuning, for my dataset, will the longer inputs get truncated or ignored after a certain number of tokens during fine-tuning? Or is there a possibility that during the fine-tuning process my longer inputs get split into several windows?

CStanKonrad commented 1 year ago

Thank you for your question!

The reason behind this is that most examples from OpenOrca and MathInstruct should fit within this context length (only chat examples were longer, but as they were sampled with a probability of 0.01 we decided to truncate them).

In the instruction tuning code, examples are truncated (a prefix is kept and the rest is discarded): https://github.com/CStanKonrad/long_llama/blob/6c21666ddf19e0a7a3aabaa77dc47b4f4f7bbdb2/instruction_fine_tuning/data_processing.py#L151 https://github.com/CStanKonrad/long_llama/blob/6c21666ddf19e0a7a3aabaa77dc47b4f4f7bbdb2/instruction_fine_tuning/data_processing.py#L257
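
For intuition, here is a minimal sketch of what that truncation means for a long example (illustrative only; the function name and parameters are mine, not the repository's):

```python
# Minimal sketch of prefix truncation (illustrative, not the repo's exact code).
# An example longer than max_total_length simply loses its tail; it is not
# split into several windows and it is not skipped.

from typing import List

def truncate_to_prefix(token_ids: List[int], max_total_length: int) -> List[int]:
    """Keep the first max_total_length tokens and discard the rest."""
    return token_ids[:max_total_length]

# e.g. a ~15K-token example trained with a 3072-token limit
# would keep only its first 3072 tokens.
```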

To handle both long and short inputs without running everything at a 15K context, you can disable the always_pad option https://github.com/CStanKonrad/long_llama/blob/6c21666ddf19e0a7a3aabaa77dc47b4f4f7bbdb2/instruction_fine_tuning/example_instchat_ft_3bv1.1_low_budget.sh#L36 and set max_*_length to the maximum value in the dataset.

The DataCollator will then pad to the maximum length within the batch of examples: https://github.com/CStanKonrad/long_llama/blob/6c21666ddf19e0a7a3aabaa77dc47b4f4f7bbdb2/instruction_fine_tuning/data_processing.py#L727
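
Roughly, the padding-to-longest behavior looks like the sketch below (illustrative code with my own names, not the repository's DataCollator implementation):

```python
# Illustrative sketch of padding each batch to its longest example
# (what happens when always_pad is disabled); names are assumptions.

import torch

def collate_pad_to_longest(batch, pad_token_id: int):
    """Right-pad every example to the longest sequence in the batch."""
    max_len = max(len(ex["input_ids"]) for ex in batch)
    input_ids, attention_mask = [], []
    for ex in batch:
        pad = max_len - len(ex["input_ids"])
        input_ids.append(ex["input_ids"] + [pad_token_id] * pad)
        attention_mask.append([1] * len(ex["input_ids"]) + [0] * pad)
    return {
        "input_ids": torch.tensor(input_ids),
        "attention_mask": torch.tensor(attention_mask),
    }
```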

The parameters

last_context_length < max_*_length 
always_pad True
random_pad True

were used to simulate long-context scenarios using short-context data (so that the model does not forget how to use the memory layers; here, random padding decides, for short examples, how much goes to the memory and how much to the last local context). This is possible because FoT assigns the same positional encoding to all tokens from the memory (in memory layers). (In fact, one can ask about the positional encodings in the local context, as they are less utilized in our case, but we assume that the model got well accustomed to them during the whole pre-training procedure and will not forget how to use them.) Note that, as described in How LongLLaMA handles long inputs, the last_context_length parameter should always be <= max_position_embeddings from https://huggingface.co/syzymon/long_llama_code_7b/blob/main/config.json (in case you are tuning this model).
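
Schematically, the always_pad + random_pad behavior for a short example can be sketched as follows (illustrative code, not the repository's implementation; the random left padding determines how much of the real content falls outside the last local context and is therefore handled through memory):

```python
# Hedged sketch of the idea behind always_pad + random_pad (names are mine):
# short examples are padded to a fixed total length, with the padding randomly
# split between the left and right side, so a random portion of the real
# tokens ends up beyond the last local context and goes through the memory layers.

import random
from typing import List

def random_pad_to_length(token_ids: List[int], total_length: int, pad_token_id: int) -> List[int]:
    assert len(token_ids) <= total_length
    left_pad = random.randint(0, total_length - len(token_ids))
    right_pad = total_length - len(token_ids) - left_pad
    return [pad_token_id] * left_pad + token_ids + [pad_token_id] * right_pad
```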

Note that in the FoT continued-pretraining code we do not truncate long documents (keep a prefix and discard the rest) but instead move the remaining parts to the next batch. This is because standard language modeling on parts of a document still makes sense, whereas answering a question the model has not seen may not. https://github.com/CStanKonrad/long_llama/blob/6c21666ddf19e0a7a3aabaa77dc47b4f4f7bbdb2/fot_continued_pretraining/FoT/data_pipeline.py#L152
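
Schematically, that carry-over behaves like the sketch below (an illustrative sketch, not the actual data_pipeline code): leftover tokens are moved into the next chunk instead of being dropped.

```python
# Illustrative sketch: cut a stream of documents into context-sized chunks,
# carrying any leftover tokens forward rather than discarding them.

from typing import Iterable, Iterator, List

def chunk_documents(documents: Iterable[List[int]], context_length: int) -> Iterator[List[int]]:
    buffer: List[int] = []
    for doc in documents:
        buffer.extend(doc)
        while len(buffer) >= context_length:
            yield buffer[:context_length]
            buffer = buffer[context_length:]
    # a final partial chunk could be padded or dropped depending on the setup
```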

lokesh-iterate commented 1 year ago

Hi CStanKonrad, thank you for the details; that clarifies things. Currently I am going through your fine-tuning code that was used to create the instruct version of the model (https://github.com/CStanKonrad/long_llama/blob/main/instruction_fine_tuning/fine_tuning.py) and the argument setup (https://github.com/CStanKonrad/long_llama/blob/main/instruction_fine_tuning/arguments.py). Correct me if I am wrong, but you haven't set "mem_attention_grouping" in your fine-tuning code for the instruct version, right? Is there a reason for that? If you haven't set this parameter, how would the model remember to use all the memory layers while you are fine-tuning on a specific dataset? My follow-up question is: do you recommend QLoRA (4-bit) for the LongLLaMA 7B version, with all the changes for including the longer dataset as per your previous reply?

Should I set the "mem_layers" parameter, e.g. "mem_layers": [8, 16, 24], when I load the model for fine-tuning? Or are these parameters supposed to be used only during the inference stage? Thank you