We introduce Fira, a plug-and-play memory-efficient training framework for LLMs.
Different from LoRA and GaLore, we enable training with full-rank gradients of full-rank weights, constituting the first attempt to achieve full-rank training consistently under the low-rank constraint.
Our method is easy to implement, relying essentially on just two lines of equations.
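For intuition, here is a simplified sketch of how such an update could look, combining a GaLore-style low-rank step with a norm-scaled residual correction. This is our illustrative reading, not the reference implementation; the names (`fira_style_update`, `phi`, `alpha`) are placeholders, and the exact equations are given in the paper.

```python
import torch

def fira_style_update(G, P, phi, alpha=1.0):
    """Illustrative sketch: low-rank optimizer step plus a norm-scaled residual.

    G:     full-rank gradient of a weight matrix, shape (m, n)
    P:     projection matrix with (approximately) orthonormal columns, shape (m, r)
    phi:   the optimizer's gradient transform (e.g., Adam's moment update)
    alpha: weight of the residual correction (placeholder, mirrors --alpha below)
    """
    R = P.T @ G                   # low-rank gradient, shape (r, n)
    phi_R = phi(R)                # optimizer-adapted low-rank gradient
    low_rank_update = P @ phi_R   # GaLore-style update mapped back to full rank
    residual = G - P @ R          # gradient component outside the low-rank subspace
    # Norm-based scaling: reuse the ratio by which the optimizer rescales the
    # low-rank gradient as a proxy for rescaling the residual.
    scale = phi_R.norm() / (R.norm() + 1e-8)
    return low_rank_update + alpha * scale * residual
```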
Install the Fira optimizer:
```bash
pip install fira
```
Then use it as a drop-in optimizer:
```python
from fira import FiraAdamW, divide_params
param_groups = divide_params(model, target_modules_list=["attn", "mlp"], rank=8)
optimizer = FiraAdamW(param_groups, lr=learning_rate)
```
Please add the names of the modules for which Fira should be enabled to `target_modules_list` (substrings are accepted).
We also provide a quick-start tutorial for the Fira optimizer; you can find it in `./quick_start`.
In Fira, Adam is used by default with `weight_decay=0`.
If you want to enable weight decay for AdamW, set it as follows:
```python
optimizer = FiraAdamW(param_groups, lr=learning_rate, weight_decay=0.01)
```
You can also adjust the learning rate for different tasks; the recommended range is $10^{-5}$ to $10^{-2}$.
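For reference, a minimal training-loop sketch (assuming `FiraAdamW` follows the standard PyTorch optimizer interface; the model and dataloader here are placeholders):

```python
from fira import FiraAdamW, divide_params

# `model` is assumed to be a Hugging Face-style causal LM and `dataloader`
# to yield tokenized batches with labels; both are placeholders.
param_groups = divide_params(model, target_modules_list=["attn", "mlp"], rank=8)
optimizer = FiraAdamW(param_groups, lr=1e-4)

model.train()
for batch in dataloader:
    loss = model(**batch).loss   # forward pass; HF models return .loss when labels are provided
    loss.backward()              # full-rank gradients are computed as usual
    optimizer.step()             # Fira handles the low-rank projection and scaling internally
    optimizer.zero_grad()
```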
`./pre_training_c4` includes the code for pre-training LLaMA models on the C4 dataset.
```bash
cd pre_training_c4
pip install -r requirements.txt
```
Our experiment scripts are validated on Python 3.9 with PyTorch 2.2.2.
The `./pre_training_c4/torchrun_main.py` script is used for pre-training LLaMA models on the C4 dataset.
The `./pre_training_c4/scripts` directory stores the benchmark scripts across different LLaMA model sizes (60M, 130M, 350M, 1B, 7B).
For instance, to pre-train a 60M model on the C4 dataset, execute the following command:
```bash
# LLaMA-60M, Fira-Adam, 1 A100, 1 Node
torchrun --standalone --nproc_per_node 1 torchrun_main.py \
    --model_config llama_configs/llama_60m.json \
    --lr 0.01 \
    --alpha 0.25 \
    --rank 128 \
    --update_proj_gap 200 \
    --batch_size 256 \
    --total_batch_size 512 \
    --num_training_steps 10000 \
    --warmup_steps 1000 \
    --weight_decay 0 \
    --dtype bfloat16 \
    --eval_every 1000 \
    --optimizer fira_adamw
```
This script loads the C4 dataset directly from Hugging Face, so please ensure a stable internet connection. Alternatively, you can refer to the tutorials in `./download_use_c4` for using a local dataset.
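If you prefer to prepare the data yourself, one common approach (not necessarily what the `./download_use_c4` tutorial does) is to stream or partially download C4 with the Hugging Face `datasets` library:

```python
from datasets import load_dataset

# Stream C4 directly from the Hugging Face Hub without downloading it in full.
train_data = load_dataset("allenai/c4", "en", split="train", streaming=True)

# Or download specific shards once for offline use; the shard name below is an example,
# and how torchrun_main.py expects local data may differ (see ./download_use_c4).
local_data = load_dataset(
    "allenai/c4",
    data_files={"train": "en/c4-train.00000-of-01024.json.gz"},
    split="train",
)
```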
`./fine_tuning` includes the code for fine-tuning LLaMA-7B with Fira.
```bash
cd fine_tuning
pip install -r requirements.txt
```
Download the commonsense 170k fine-tuning dataset from LLM-Adapters, then place it at `./fine_tuning/commonsense_170k.json`.
Download the full dataset directory from LLM-Adapters, then place it at `./fine_tuning/dataset`.
`./finetune.py` is used for fine-tuning LLaMA-7B on the commonsense reasoning tasks.
`./commonsense_evaluate.py` is used for evaluating the fine-tuned LLaMA-7B model on the 8 sub-tasks of the commonsense reasoning benchmark.
For instance, to fine-tune LLaMA-7B with Fira on the commonsense reasoning tasks using a single GPU, execute the following command:
```bash
# LLaMA-7B, Fira-Adam, 1 4090
CUDA_VISIBLE_DEVICES=0 python finetune.py \
    --base_model 'yahma/llama-7b-hf' \
    --data_path 'commonsense_170k.json' \
    --output_dir './result/fira' \
    --batch_size 16 \
    --micro_batch_size 4 \
    --num_epochs 3 \
    --learning_rate 1e-4 \
    --cutoff_len 256 \
    --val_set_size 120 \
    --adapter_name lora \
    --lora_r 32 \
    --lora_alpha 64 \
    --use_gradient_checkpointing \
    --target_modules '["q_proj", "k_proj", "v_proj", "up_proj", "down_proj"]' \
    --save_step 15000 \
    --eval_step 1000 \
    --optimizer_name fira_adamw
```
For instance, to evaluate the fine-tuned LLaMA-7B model on the BoolQ sub-task:
```bash
# LLaMA-7B, Fira-Adam, 1 4090
CUDA_VISIBLE_DEVICES=0 python commonsense_evaluate.py \
    --model LLaMA-7B \
    --adapter LoRA \
    --dataset boolq \
    --batch_size 1 \
    --base_model 'yahma/llama-7b-hf' \
    --lora_weights './result/fira' | tee -a './result/fira/boolq.txt'
```
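To sweep all 8 sub-tasks, you can wrap this command in a small loop, for example as below; the sub-task names follow the LLM-Adapters dataset directory and may need adjusting to your local `./fine_tuning/dataset` layout.

```python
import os
import subprocess

# Sub-task names as commonly used in LLM-Adapters; adjust to match your local dataset directory.
subtasks = ["boolq", "piqa", "social_i_qa", "hellaswag", "winogrande",
            "ARC-Challenge", "ARC-Easy", "openbookqa"]

env = {**os.environ, "CUDA_VISIBLE_DEVICES": "0"}
for task in subtasks:
    subprocess.run([
        "python", "commonsense_evaluate.py",
        "--model", "LLaMA-7B",
        "--adapter", "LoRA",
        "--dataset", task,
        "--batch_size", "1",
        "--base_model", "yahma/llama-7b-hf",
        "--lora_weights", "./result/fira",
    ], check=True, env=env)
```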
To further substantiate our findings about the scaling factor, we conduct a quantitative analysis of scaling-factor similarity between low-rank and full-rank LLM training. Specifically, we assess the similarity at both the matrix and column level while pre-training LLaMA models ranging from 60M to 1B parameters, averaged over 10,000 steps.
| Size | Matrix-level Spearman | P-value | Matrix-level Kendall | P-value | Column-level Spearman | P-value | Column-level Kendall | P-value |
|---|---|---|---|---|---|---|---|---|
| 60M | 0.9972 | 2e-62 | 0.9662 | 7e-26 | 0.9372 | 0.0 | 0.7942 | 0.0 |
| 130M | 0.9925 | 2e-76 | 0.9409 | 9e-37 | 0.8698 | 0.0 | 0.6830 | 0.0 |
| 350M | 0.9770 | 3e-113 | 0.8848 | 5e-65 | 0.9091 | 0.0 | 0.7400 | 0.0 |
| 1B | 0.9469 | 1e-83 | 0.8249 | 1e-56 | 0.8331 | 0.0 | 0.6513 | 0.0 |
Spearman and Kendall correlation coefficients range from -1 to +1: +1 signifies a perfect positive correlation and -1 a perfect negative correlation. Generally, a p-value below 0.05 indicates a significant correlation. As shown in the table above, both Spearman and Kendall coefficients indicate a strong positive relationship at the matrix and column levels across all LLaMA model sizes, with all p-values below 0.05.
Therefore, it is likely that the observed behavior is an inherent feature of LLM training, manifesting across a broad range of scenarios. This insight provides a robust experimental basis for our proposed norm-based scaling in Fira and helps explain its effectiveness. Code for this analysis is provided in `./similarity`.
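For illustration, correlations like those in the table can be computed with `scipy.stats`; the sketch below is a simplified stand-in for the code in `./similarity`, and the input file names are placeholders.

```python
import numpy as np
from scipy.stats import spearmanr, kendalltau

# Placeholder inputs: per-matrix (or per-column) scaling factors collected from a
# low-rank and a full-rank training run, averaged over training steps.
low_rank_sf = np.load("scaling_factors_low_rank.npy")
full_rank_sf = np.load("scaling_factors_full_rank.npy")

rho, rho_p = spearmanr(low_rank_sf, full_rank_sf)
tau, tau_p = kendalltau(low_rank_sf, full_rank_sf)
print(f"Spearman: {rho:.4f} (p={rho_p:.1e})  Kendall: {tau:.4f} (p={tau_p:.1e})")
```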
This implementation is based on code from several repositories.
If you find Fira useful, please cite:
```bibtex
@article{chen2024firaachievefullranktraining,
  title={Fira: Can We Achieve Full-rank Training of LLMs Under Low-rank Constraint?},
  author={Xi Chen and Kaituo Feng and Changsheng Li and Xunhao Lai and Xiangyu Yue and Ye Yuan and Guoren Wang},
  journal={arXiv},
  year={2024},
}
```