CStanKonrad / long_llama

LongLLaMA is a large language model capable of handling long contexts. It is based on OpenLLaMA and fine-tuned with the Focused Transformer (FoT) method.
Apache License 2.0

Help: questions about training on 8k input text length #23

Closed Force1ess closed 10 months ago

Force1ess commented 10 months ago

Hi, LongLLaMA looks very impressive given the results reported in the paper, and thank you for your great work. I'm interested in training LongLLaMA-3B on a long-text corpus, but I keep running out of memory on my A100-80G. Is there any way to fine-tune this model on text of 10k length? Do you have any ideas for reducing memory usage? I noticed in the paper that the model was trained at a length of 8k. Could you share your training script so I can learn from it?

Below is my training script:

#!/bin/bash
EXP_NAME="example_inst_ft_3b_low_budget"
accelerate launch -m instruction_fine_tuning.fine_tuning \
    --run_name "$EXP_NAME" \
    --ddp_find_unused_parameters False \
    --output_dir "$EXP_NAME"/ \
    --model_path "/mnt/shared_home/zhenghao2022/FormatGPT/long_llama/longllama" \
    --torch_dtype bfloat16 \
    --data_type "instructions" \
    --data_path "/mnt/shared_home/zhenghao2022/FormatGPT/long_llama/data/18-09-CC-NEWS-20180929021529-00549.json" \
    --data_revision "f0823c7ffc48c9d33a42f16cf0b885fed4a7d0a1" \
    --dataset_split "train" \
    --prompt_field "system_prompt" \
    --post_prompt_text "
" \
    --question_field "question" \
    --post_question_text "
" \
    --response_field "response" \
    --last_context_length 768 \
    --max_input_length 4096 \
    --max_output_length 4096 \
    --max_total_length 10240 \
    --always_pad False \
    --random_pad True \
    --max_steps 300 \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 1 \
    --learning_rate 1.0e-5 \
    --weight_decay 0. \
    --warmup_steps 100 \
    --lr_scheduler_type "cosine" \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_total_limit 3 \
    --save_steps 500 \
    --logging_strategy "steps" \
    --logging_steps 25 \
    --gradient_checkpointing False \
    --tf32 True \
    --bf16 True
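
For reference (not from the thread): the flags above appear to be standard Hugging Face Trainer arguments, so the usual memory levers are already exposed in this script, e.g. --gradient_checkpointing True, --gradient_accumulation_steps above 1 with --per_device_train_batch_size 1, and tighter --max_input_length / --max_output_length / --max_total_length limits. A minimal PyTorch-side sketch of the gradient-checkpointing lever follows; the checkpoint name and whether the custom LongLLaMA modeling code supports gradient_checkpointing_enable() are assumptions, not something confirmed in this thread.

import torch
from transformers import AutoModelForCausalLM

# Hypothetical sketch: load the model in bfloat16 and turn on activation
# checkpointing. The checkpoint name is assumed, and whether LongLLaMA's
# custom modeling code honors this call is also an assumption.
model = AutoModelForCausalLM.from_pretrained(
    "syzymon/long_llama_3b",        # assumed Hugging Face checkpoint
    torch_dtype=torch.bfloat16,     # matches --torch_dtype bfloat16 above
    trust_remote_code=True,         # LongLLaMA ships custom modeling code
)
model.config.use_cache = False      # the KV cache is not needed during training
model.gradient_checkpointing_enable()  # counterpart of --gradient_checkpointing True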
CStanKonrad commented 10 months ago

Hey! Thank you for your question. For long-context training, we used the JAX code, which is more heavily optimized for training on longer input sequences. The PyTorch code was initially intended for inference; however, if we find time, we will happily optimize it further. The two main optimizations in the JAX code are:

  1. gradient checkpointing of the memory layers;
  2. an attention implementation inspired by FlashAttention, enabled by the scan_cross_batch flag; the main idea is not to materialize the whole Query x Key score matrix (nor its gradient) at once (a rough sketch follows below).
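
For illustration only, here is a hedged PyTorch sketch of the second idea (it is not the repo's JAX scan_cross_batch implementation): attention is computed by streaming over key/value blocks with an online softmax, so the full Query x Key score matrix, and hence its gradient, is never materialized at once. Wrapping each block in torch.utils.checkpoint also mirrors point 1 in miniature.

import torch
from torch.utils.checkpoint import checkpoint


def blockwise_attention(q, k, v, block_size=1024, use_checkpoint=True):
    """q: (n_q, d); k, v: (n_kv, d). Returns the (n_q, d) attention output
    without ever allocating the full (n_q, n_kv) score matrix."""
    scale = q.shape[-1] ** -0.5
    out = torch.zeros_like(q)
    row_max = torch.full((q.shape[0], 1), float("-inf"), device=q.device, dtype=q.dtype)
    row_sum = torch.zeros((q.shape[0], 1), device=q.device, dtype=q.dtype)

    def process_block(k_blk, v_blk, out, row_max, row_sum):
        scores = (q @ k_blk.T) * scale                      # only (n_q, block_size)
        blk_max = scores.max(dim=-1, keepdim=True).values
        new_max = torch.maximum(row_max, blk_max)
        correction = torch.exp(row_max - new_max)           # rescale old accumulators
        probs = torch.exp(scores - new_max)
        out = out * correction + probs @ v_blk
        row_sum = row_sum * correction + probs.sum(dim=-1, keepdim=True)
        return out, new_max, row_sum

    for start in range(0, k.shape[0], block_size):
        k_blk, v_blk = k[start:start + block_size], v[start:start + block_size]
        if use_checkpoint:
            # Recompute this block's activations in the backward pass instead
            # of storing them -- the same trade-off as checkpointing memory layers.
            out, row_max, row_sum = checkpoint(
                process_block, k_blk, v_blk, out, row_max, row_sum, use_reentrant=False
            )
        else:
            out, row_max, row_sum = process_block(k_blk, v_blk, out, row_max, row_sum)
    return out / row_sum


# Example usage on dummy tensors (8k queries attending to 16k keys/values):
q = torch.randn(8192, 64, requires_grad=True)
k = torch.randn(16384, 64, requires_grad=True)
v = torch.randn(16384, 64, requires_grad=True)
blockwise_attention(q, k, v, block_size=1024).sum().backward()

With block_size much smaller than the key length, the peak score-matrix memory drops from O(n_q * n_kv) to O(n_q * block_size), which is the observation that FlashAttention-style implementations exploit.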
Force1ess commented 10 months ago

Thank you very much for your patient reply. It helped a lot.