
[ICML 2024 Spotlight] SIFT: Sparse Increment Fine-Tuning

This repository contains the implementation of the paper Sparse is Enough in Fine-tuning Pre-trained Large Language Models and introduces the general usage of SIFT for different needs.

Sparse is Enough in Fine-tuning Pre-trained Large Language Models
Weixi Song*, Zuchao Li*, Lefei Zhang, Hai Zhao, Bo Du
Paper: https://arxiv.org/abs/2312.11875

Contents

- Introduction
- Install
- Usage
- Citation

Introduction

(Figure: implementation overview)

In this work, we present a component-sparse and memory-efficient updating scheme called SIFT. Inspired by the memory-efficient SGD implementation in LOMO, we implement the sparse update by injecting hooks into the backward propagation. See our paper for more details. The main code of SIFT is in sift.py.
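As a rough illustration of the mechanism (not the code in sift.py), a backward hook on a parameter can capture only the selected gradient components; the attach_sparse_grad_hook helper below is a minimal, hypothetical sketch:

import torch

def attach_sparse_grad_hook(p, flat_idx):
    ## illustrative only: collect gradients of the selected components during backward
    flat_idx = flat_idx.to(p.device)
    sparse_grad = torch.zeros(flat_idx.numel(), device=p.device, dtype=p.dtype)

    def hook(grad):
        ## accumulate only the chosen entries of the dense gradient
        sparse_grad.add_(grad.reshape(-1)[flat_idx])
        ## return zeros so that no dense gradient is accumulated into p.grad
        return torch.zeros_like(grad)

    p.register_hook(hook)
    return sparse_grad

## e.g. track 1% of a weight's entries (chosen randomly here, just for the sketch)
w = torch.nn.Linear(128, 128).weight
idx = torch.randperm(w.numel())[: int(0.01 * w.numel())]
buf = attach_sparse_grad_hook(w, idx)

The actual implementation in sift.py additionally handles gradient accumulation, gradient checkpointing, and the optimizer-side bookkeeping.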

Through this method, for x% sparse updates, the memory consumption of the gradients and optimizer states is reduced to x% of the original. Combined with techniques such as mixed-precision training and gradient checkpointing, this makes it possible to fine-tune a 7B model on a single RTX 3090 24GB.

(Figure: memory consumption)
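As a rough back-of-the-envelope estimate (our illustrative assumptions, not measured numbers: fp16 weights, 1% sparse updates, fp32 AdamW states kept only for the updated components, activations and fp32 master copies ignored):

params = 7e9
sparse_rate = 0.01

weights_gb = params * 2 / 1e9                    ## fp16 weights              ~14.0 GB
grads_gb = params * sparse_rate * 2 / 1e9        ## sparse fp16 gradients     ~0.14 GB
optim_gb = params * sparse_rate * (4 + 4) / 1e9  ## fp32 AdamW m and v states ~0.56 GB

print(f"~{weights_gb + grads_gb + optim_gb:.1f} GB before activations")

The remaining budget on a 24GB card is left for activations, which gradient checkpointing helps keep small.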

We provide several use cases in natural language processing; SIFT can be applied to other areas in the same way. See exp for experiments on the GLUE benchmark and the instruction-tuning task. The experiments are built on the original repositories of Transformers, Alpaca, and MMLU. HumanEval evaluation is conducted with code-eval. Thanks to these great works.

(Figure: instruction-tuning results)
(Figure: GLUE results)

Install

git clone git@github.com:song-wx/SIFT.git
cd SIFT
pip install .

Please resolve any dependency issues as needed.

Usage

Note: The current implementation only supports training on a single GPU. If you are interested in multi-GPU training, please modify the code to fit your needs.

Basic usage

Step 1: After initializing your model, run the following code to specify the parameters that need to be updated sparsely. Set sparse_module and sparse_rate to customize the sparse training; you can also specify modules to be updated densely by setting exception.

## initialize your model
model = ...

## initialize SIFT
from sift import SIFT

sift = SIFT(model, sparse_rate=0.01,
            sparse_module=['q_proj', 'k_proj', 'v_proj', 'o_proj', 'down_proj', 'gate_proj', 'up_proj'],
            grad_acc=gradient_accumulation,                  ## gradient accumulation steps used in training
            gradient_checkpointing=gradient_checkpointing)   ## whether gradient checkpointing is enabled

## you can print the actual number of trainable parameters in SIFT
sift.print_trainable_parameters()

Step 2: Initialize the optimizer with the actual trainable parameters from sift.named_parameters_in_optimizer().

## example
import torch

no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in sift.named_parameters_in_optimizer() if not any(nd in n for nd in no_decay)],
        "weight_decay": weight_decay,
    },
    {
        "params": [p for n, p in sift.named_parameters_in_optimizer() if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]

optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=learning_rate)

Step 3: Run the training loop as usual with the model and optimizer.

## if using the Hugging Face Trainer
trainer = Trainer(model=model, optimizers=(optimizer, None), ...)
trainer.train()

## if using a bare training loop, it is the same as the normal training process
for i, batch in enumerate(dataloader):
    output = model(**batch)
    loss = ...
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

Customized usage

SIFT essentially creates an additional sparse parameter sparse_param for each parameter p that needs to be sparsely updated; it is represented by the indices sparse_param.idx and the values sparse_param.data. After initializing SIFT, you can get the sparse parameter sparse_param of a target parameter p with name n through the dict sift.sparse_mapping[n].
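For example, you can iterate over the mapping to inspect the sparse parameters (a minimal sketch, assuming sparse_mapping behaves like a standard dict keyed by parameter name, as described above):

## inspect the sparse parameter attached to each sparsely updated weight
for n, p in model.named_parameters():
    if n in sift.sparse_mapping:
        sparse_param = sift.sparse_mapping[n]
        print(n, p.shape, sparse_param.idx.shape, sparse_param.data.shape)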

Customize the selection of indexes

In our paper, we propose a gradient-based selection method based on our finding that the gradient distribution of the pre-trained model is quasi-sparse. We determine the indices as the components whose absolute gradients over the first few batches are in the top x%.

## take the top-k components by absolute gradient, then convert the flat indices
## back to multi-dimensional coordinates
sparse_idx = torch.flatten(abs(grad)).topk(sparse_param.train_num).indices.cpu().numpy()
sparse_param.idx = np.stack(np.unravel_index(sparse_idx, p.shape))

We compare the efficiency of this gradient-based method with LoRA and random selection under different trainable-parameter budgets. You can modify the code above in sift.py to customize your index selection, as sketched after the figure below.

(Figure: efficiency comparison)
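For instance, replacing the gradient-based rule with purely random selection could look like the following sketch (same variable names as above; illustrative only):

import numpy as np
import torch

## pick components uniformly at random instead of by gradient magnitude
rand_idx = torch.randperm(p.numel())[:sparse_param.train_num].numpy()
sparse_param.idx = np.stack(np.unravel_index(rand_idx, p.shape))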

Store in a memory-efficient way

Because SIFT merges sparse_param into the original parameter p inside the hook to ensure correct forward propagation (as in the following code), the final updated parameters are the original parameters p. If you want to store the result in a memory-efficient way, you can save only the components of p indexed by sparse_param.idx; otherwise, save the complete p.

## update the initial param sparsely: scatter the sparse values onto p at sparse_param.idx
delta = p.data + torch.sparse_coo_tensor(sparse_param.idx, sparse_param, p.shape).to(p)
p.data.copy_(delta)
## reset the sparse increment after merging
sparse_param.zero_()
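For example, saving only the updated components could look like the following sketch (the file name and dictionary layout are illustrative; to restore, load the pre-trained weights and write the values back at idx):

import torch

## save only the updated components of each sparsely fine-tuned parameter
params = dict(model.named_parameters())
to_save = {}
for n, sparse_param in sift.sparse_mapping.items():
    idx = sparse_param.idx                                 ## (ndim, k) indices of the updated components
    values = params[n].detach().cpu().numpy()[tuple(idx)]  ## updated values at those indices
    to_save[n] = {"idx": idx, "values": values}
torch.save(to_save, "sift_sparse_components.pt")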

Citation

@inproceedings{song2024sparse,
  title={Sparse is Enough in Fine-tuning Pre-trained Large Language Models},
  author={Weixi Song and Zuchao Li and Lefei Zhang and Hai Zhao and Bo Du},
  booktitle={Forty-first International Conference on Machine Learning},
  year={2024},
  url={https://openreview.net/forum?id=10hu2D3hAg}
}