lokesh-iterate opened this issue 1 year ago
Thank you for your question!
The reason behind this is that most examples from OpenOrca and MathInstruct should fit within this context length (only chat examples were longer, but as they were sampled with a probability of 0.01 we decided to truncate them).
In the instruction tuning code, examples are truncated (a prefix is kept and the rest is discarded): https://github.com/CStanKonrad/long_llama/blob/6c21666ddf19e0a7a3aabaa77dc47b4f4f7bbdb2/instruction_fine_tuning/data_processing.py#L151 https://github.com/CStanKonrad/long_llama/blob/6c21666ddf19e0a7a3aabaa77dc47b4f4f7bbdb2/instruction_fine_tuning/data_processing.py#L257
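A minimal sketch of what prefix truncation means for a single instruction example. It assumes the prompt and answer token ids are simply concatenated and that prompt tokens are masked out of the loss with -100 (a Hugging Face convention, assumed here). The helper name `truncate_example` is hypothetical and is not taken from data_processing.py.

```python
from typing import List, Tuple

IGNORE_INDEX = -100  # assumption: Hugging Face-style label value ignored by the loss


def truncate_example(
    prompt_ids: List[int], answer_ids: List[int], max_length: int
) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: concatenate prompt and answer, keep the first
    max_length tokens, and discard the rest. Prompt tokens are masked so that
    only answer tokens contribute to the loss."""
    input_ids = prompt_ids + answer_ids
    labels = [IGNORE_INDEX] * len(prompt_ids) + answer_ids
    # Take a prefix; everything past max_length is dropped.
    return input_ids[:max_length], labels[:max_length]


prompt = [101, 102, 103, 104, 105]
answer = [201, 202, 203]

ids, labels = truncate_example(prompt, answer, max_length=6)
print(ids)     # [101, 102, 103, 104, 105, 201]
print(labels)  # [-100, -100, -100, -100, -100, 201]

# If the prompt alone fills the limit, no answer tokens survive.
_, labels = truncate_example(prompt, answer, max_length=4)
print(labels)  # [-100, -100, -100, -100]
```

The second call shows the cost of truncation: when the prompt alone reaches the limit, the example contributes no answer tokens to training.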
To handle both long and short inputs without running everything with a 15K context, you can disable the always_pad option (https://github.com/CStanKonrad/long_llama/blob/6c21666ddf19e0a7a3aabaa77dc47b4f4f7bbdb2/instruction_fine_tuning/example_instchat_ft_3bv1.1_low_budget.sh#L36) and set the max_*_length parameters to the maximum length in your dataset.
The data collator will then pad only to the maximum length within each batch of examples: https://github.com/CStanKonrad/long_llama/blob/6c21666ddf19e0a7a3aabaa77dc47b4f4f7bbdb2/instruction_fine_tuning/data_processing.py#L727
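A rough sketch of per-batch (dynamic) padding, assuming each example is a dict with `input_ids` and `labels` lists. The function `pad_batch` and the pad and ignore values are placeholders for illustration; this is not the repository's collator, which would also return tensors rather than lists.

```python
from typing import Dict, List

PAD_TOKEN_ID = 0     # assumption: in practice use tokenizer.pad_token_id
IGNORE_INDEX = -100  # assumption: label value ignored by the loss


def pad_batch(examples: List[Dict[str, List[int]]]) -> Dict[str, List[List[int]]]:
    """Right-pad each example only up to the longest example in this batch."""
    batch_max = max(len(ex["input_ids"]) for ex in examples)
    out: Dict[str, List[List[int]]] = {"input_ids": [], "labels": [], "attention_mask": []}
    for ex in examples:
        n_pad = batch_max - len(ex["input_ids"])
        out["input_ids"].append(ex["input_ids"] + [PAD_TOKEN_ID] * n_pad)
        out["labels"].append(ex["labels"] + [IGNORE_INDEX] * n_pad)
        # 1 marks real tokens, 0 marks padding.
        out["attention_mask"].append([1] * len(ex["input_ids"]) + [0] * n_pad)
    return out


batch = pad_batch([
    {"input_ids": [1, 2, 3], "labels": [-100, 2, 3]},
    {"input_ids": [4, 5, 6, 7, 8], "labels": [-100, -100, 6, 7, 8]},
])
print(batch["input_ids"])       # [[1, 2, 3, 0, 0], [4, 5, 6, 7, 8]]
print(batch["attention_mask"])  # [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]
```

With this behavior, a batch of short examples stays short, and only a batch that actually contains a ~15K-token example is padded to that length.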
The following parameters:
- last_context_length < max_*_length
- always_pad = True
- random_pad = True
were used to simulate long-context scenarios using short-context data, so that the model does not forget how to use its memory layers. For short examples, random padding decides how much of the example goes to memory and how much goes to the last local context. This is possible because FoT assigns the same positional encoding to all tokens from memory (in the memory layers). (One could ask about the positional encodings in the local context, as they are less exercised in our setup, but we assume the model became well accustomed to them during pre-training and will not forget how to use them.)
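A sketch of how random left/right padding can shift a short example between memory and the last local context. It assumes, purely for illustration and not as the repository's code, that the padded sequence is split so that its final last_context_length positions form the local context and all earlier positions go to memory. The names `random_pad` and `memory_local_split` are hypothetical.

```python
import random
from typing import List, Tuple

PAD_TOKEN_ID = 0  # placeholder pad id (assumption)


def random_pad(token_ids: List[int], total_length: int, rng: random.Random) -> List[int]:
    """Pad to exactly total_length, splitting the padding randomly between left and right."""
    n_pad = total_length - len(token_ids)
    if n_pad < 0:
        raise ValueError("example is longer than total_length; truncate it first")
    left = rng.randint(0, n_pad)  # randint is inclusive on both ends
    return [PAD_TOKEN_ID] * left + token_ids + [PAD_TOKEN_ID] * (n_pad - left)


def memory_local_split(padded: List[int], last_context_length: int) -> Tuple[List[int], List[int]]:
    """Assumed split: the last last_context_length positions are the local context;
    everything before them goes to memory."""
    cut = len(padded) - last_context_length
    return padded[:cut], padded[cut:]


rng = random.Random(0)
example = list(range(1, 9))  # 8 real tokens; none equals PAD_TOKEN_ID
for _ in range(3):
    memory, local = memory_local_split(random_pad(example, 16, rng), last_context_length=6)
    in_memory = sum(t != PAD_TOKEN_ID for t in memory)
    print(f"{in_memory} example tokens in memory, {len(example) - in_memory} in local context")
```

In this toy setting the local context (6 positions) is smaller than the example (8 tokens), so at least 2 example tokens always land in memory, and the random padding varies how many more do.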
Note that, as described in "How LongLLaMA handles long inputs", the parameter last_context_length
should always be <= max_position_embeddings
from https://huggingface.co/syzymon/long_llama_code_7b/blob/main/config.json (if you are tuning this model).
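A small sanity check for that constraint, assuming the config is available as a JSON string with a `max_position_embeddings` field, as in a Hugging Face config.json. The helper name and the 2048 value in the example are placeholders, not the actual settings of long_llama_code_7b; read the real value from the linked config.json.

```python
import json


def check_last_context_length(config_json: str, last_context_length: int) -> int:
    """Raise if last_context_length exceeds max_position_embeddings;
    otherwise return max_position_embeddings."""
    config = json.loads(config_json)
    max_pos = int(config["max_position_embeddings"])
    if last_context_length > max_pos:
        raise ValueError(
            f"last_context_length={last_context_length} exceeds "
            f"max_position_embeddings={max_pos}"
        )
    return max_pos


# Placeholder config for illustration only.
example_config = '{"max_position_embeddings": 2048}'
print(check_last_context_length(example_config, 1024))  # 2048
```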
Note that in the FoT continued-pretraining code we do not truncate long documents (keep a prefix and discard the rest); instead, we move the remaining parts to the next batch. This is because standard language modeling on parts of a document still makes sense, whereas answering a question that is not fully visible may not. https://github.com/CStanKonrad/long_llama/blob/6c21666ddf19e0a7a3aabaa77dc47b4f4f7bbdb2/fot_continued_pretraining/FoT/data_pipeline.py#L152
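A simplified sketch of carrying the remainder of a long document over to the next batch instead of truncating it. It assumes each batch row holds one document at a time and consumes up to seq_len tokens per step. `stream_batches` is a hypothetical name and this is not the data_pipeline.py implementation, which may, for example, pack or pad the short tail segments differently. Here the final incomplete batch is simply dropped.

```python
from collections import deque
from typing import Deque, Iterator, List


def stream_batches(docs: List[List[int]], batch_size: int, seq_len: int) -> Iterator[List[List[int]]]:
    """Each row takes the next document when its current one is used up; the
    unconsumed part of a long document stays in the same row for the next batch."""
    queue: Deque[List[int]] = deque(d for d in docs if d)  # skip empty documents
    rows: List[List[int]] = [[] for _ in range(batch_size)]
    while True:
        # Refill empty rows with the next document; stop if a row cannot be refilled.
        for i in range(batch_size):
            if not rows[i]:
                if not queue:
                    return
                rows[i] = queue.popleft()
        batch = [row[:seq_len] for row in rows]
        # Keep the remainder instead of discarding it.
        rows = [row[seq_len:] for row in rows]
        yield batch


docs = [[1, 2, 3, 4, 5], [6, 7], [8, 9, 10]]
for b in stream_batches(docs, batch_size=2, seq_len=2):
    print(b)
# [[1, 2], [6, 7]]
# [[3, 4], [8, 9]]
# [[5], [10]]
```

The first document never loses its tail: [3, 4] and [5] reappear in the same row of later batches, which is the opposite of the prefix truncation used for instruction tuning.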
Hi CStanKonrad, thank you for the details; they clarified things for me. I am currently going through the fine-tuning code that was used to create the instruct version of the model (https://github.com/CStanKonrad/long_llama/blob/main/instruction_fine_tuning/fine_tuning.py) and its argument setup (https://github.com/CStanKonrad/long_llama/blob/main/instruction_fine_tuning/arguments.py). Correct me if I am wrong, but you haven't set "mem_attention_grouping" in your fine-tuning code for the instruct version, right? Is there a reason for that? If you haven't set this parameter, how would the model remember to use all the memory layers while you are fine-tuning on a specific dataset? My follow-up question: do you recommend QLoRA (4-bit) for the LongLLaMA 7B version, together with the changes for including longer data described in your previous reply?
Should I set the "mem_layers" parameter, e.g. "mem_layers": [8, 16, 24], when I load the model for fine-tuning, or are these parameters meant to be used only at inference time? Thank you
Hi, I am going through the page https://huggingface.co/syzymon/long_llama_code_7b_instruct. Under Training, it says "All inputs were truncated and randomly padded (left/right) to 3072 tokens". Is there a reason behind this truncation? I have noticed that, when the instruct version was created from the LongLLaMA model, the context lengths used for fine-tuning were significantly smaller than the context length used at inference time. I would like this clarified because I have prepared a dataset in a format similar to the JSONL file at https://github.com/chrishayuk/opl-train/blob/main/JSONL/train.jsonl, where each line is one input. Several inputs in my custom JSONL file are long, around 15K tokens. If I follow the fine-tuning script you provided at https://github.com/CStanKonrad/long_llama/tree/main/instruction_fine_tuning for my dataset, will the longer inputs be truncated, or ignored after a certain number of tokens, during fine-tuning? Or could my longer inputs be split into several windows during fine-tuning?