princeton-nlp / LLM-Shearing

[ICLR 2024] Sheared LLaMA: Accelerating Language Model Pre-training via Structured Pruning
https://arxiv.org/abs/2310.06694
MIT License

🦙 Sheared LLaMA: Accelerating Language Model Pre-training via Structured Pruning

🌟 ArXiv Preprint | Blog Post

Base models: Sheared-LLaMA-1.3B | Sheared-LLaMA-2.7B | Sheared-Pythia-160m
Pruned Models without Continued Pre-training: Sheared-LLaMA-1.3B-Pruned, Sheared-LLaMA-2.7B-Pruned
Instruction-tuned models: Sheared-LLaMA-1.3B-ShareGPT | Sheared-LLaMA-2.7B-ShareGPT

Thank you for your interest in our work! This is a joint work by Mengzhou Xia, Tianyu Gao, Zhiyuan Zeng, and Danqi Chen. Here, we provide the codebase for Sheared-LLaMA's pruning and continued pre-training algorithms :) We find that pruning strong base models is an extremely cost-effective way to obtain strong small-scale language models compared to pre-training them from scratch. The following graph shows that, given an existing Llama-2-7B model (pre-trained on 2T tokens), pruning it produces a model as strong as an OpenLLaMA model at roughly 3% of its pre-training cost.

[Teaser figure: pruning Llama-2-7B yields a model as strong as OpenLLaMA at ~3% of its pre-training cost]

Update

🔗 Quick Links

Brief Introduction

This codebase is built on MosaicML's amazing Composer package, which is specially designed and optimized for large language model pre-training. The entire implementation, including the pruning logic and the dynamic batch loading logic, is written as callback functions without touching the vanilla Composer trainer. Here's a concise overview of the folders within the codebase that are referenced throughout this README:

llmshearing/data: scripts for preparing data with MosaicML's Streaming package
llmshearing/datasets: a StreamingDataset extension that supports dynamic batch loading
llmshearing/callbacks: callbacks implementing pruning and dynamic batch loading
llmshearing/models: model implementations that support pruning masks
llmshearing/scripts: sample scripts for pruning and continued pre-training
llmshearing/utils: utilities for model conversion and correctness testing
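
To illustrate this callback-based design, here is a minimal, hypothetical sketch of a Composer callback (not one of the repo's actual classes): custom logic runs at trainer events such as batch_end while the Trainer itself stays untouched. Callbacks like this are passed to Composer's Trainer through its callbacks argument.

# Minimal sketch of the Composer callback pattern (hypothetical class, not the
# repo's actual pruning or dynamic-loading callback).
from composer.core import Callback, State
from composer.loggers import Logger

class BatchCounterCallback(Callback):
    """Runs custom logic at trainer events without modifying the Composer Trainer."""

    def batch_end(self, state: State, logger: Logger) -> None:
        # The repo's callbacks hook in at events like this to, e.g., update
        # pruning masks or adjust the data-domain sampling proportions.
        logger.log_metrics({"custom/batches_seen": state.timestamp.batch.value})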

Install Requirements

Step 1: To get started with this repository, you'll need to follow these installation steps. Before proceeding, make sure you have PyTorch and Flash Attention installed. You can do this via pip using the following commands:

pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118
pip install flash-attn==1.0.3.post0

Please note that Flash Attention version 2 is not currently supported and may require manual modifications to the model file.

Step 2: Then install the rest of the required packages:

cd llmshearing
pip install -r requirement.txt

Step 3: Finally, install the llmshearing package in editable mode to make it accessible for your development environment:

pip install -e .
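
As an optional sanity check (not part of the repository), you can verify that the key packages import correctly and that CUDA is visible:

# Optional environment check (hypothetical, not part of the repo).
import flash_attn   # noqa: F401  (fails here if Flash Attention v1 is missing)
import torch

import llmshearing  # noqa: F401  (fails here if the editable install did not work)

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("flash-attn and llmshearing imported successfully")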

Data Preparation

Please refer to llmshearing/data for details on how to prepare data with MosaicML's Streaming package.
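
For a rough idea of what such a dataset looks like, here is a minimal, hypothetical sketch of writing pre-tokenized samples into a local MDS directory with Streaming; the column name, dtype, and directory layout below are illustrative only, and the repo's own scripts in llmshearing/data define the actual format.

# Hypothetical sketch of writing pre-tokenized samples into a local MDS dataset
# with MosaicML Streaming. Field names and paths are illustrative only.
import numpy as np
from streaming import MDSWriter

columns = {"tokens": "bytes"}  # token ids stored as raw bytes per sample

with MDSWriter(out="mds_data/example_domain", columns=columns, compression="zstd") as writer:
    token_ids = np.arange(4096, dtype=np.uint16)  # stand-in for one tokenized document
    writer.write({"tokens": token_ids.tobytes()})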

Model Preparation

To utilize Hugging Face transformer models with Composer, you'll need to convert the model weights to the key format expected by Composer. Here's an example of how to convert the weights of the Hugging Face Llama-2-7B model into a compatible format for Composer:

# Define the Hugging Face model name and the output path
HF_MODEL_NAME=meta-llama/Llama-2-7b-hf
OUTPUT_PATH=models/Llama-2-7b-composer/state_dict.pt

# Create the necessary directory if it doesn't exist
mkdir -p $(dirname $OUTPUT_PATH)

# Convert the Hugging Face model to Composer key format
python3 -m llmshearing.utils.composer_to_hf save_hf_to_composer $HF_MODEL_NAME $OUTPUT_PATH

Additionally, you can use the following utility function to test the equivalence between the Hugging Face model and the converted Composer model:

MODEL_SIZE=7B
python3 -m llmshearing.utils.test_composer_hf_eq $HF_MODEL_NAME $OUTPUT_PATH $MODEL_SIZE

These functions exclusively work for LLaMA/LLaMA2 models. However, it should be straightforward to adapt them for use with other models such as Mistral-7B.

Sample Scripts for Pruning and Continued Pre-training

For pruning, you can reference an example script located in llmshearing/scripts/pruning.sh. In this script, you will need to make adjustments to incorporate data configurations, basic training configurations, pruning configurations, and dynamic batch loading configurations (see the Training Configurations section below).

Due to the relatively higher computational cost of pruning compared to continued pre-training, we halt training with the pruning objective after a specific number of steps (typically 3200 steps in all our experiments). Subsequently, we proceed with further pre-training of the pruned model. To ensure compatibility, it is necessary to convert the state dictionary keys of the model to align with a standard target model structure. Detailed instructions for this conversion can be found at Convert Pruned Model.

After completing the model conversion, you can continue with the pre-training of the pruned model. The process is similar to pre-training a standard model. To do this, you can refer to an example script located at llmshearing/scripts/continue_pretraining.sh. In this script, the pruning configurations are removed.

After training the model, you can use the conversion script to convert the Composer model into a Hugging Face transformers model. Please refer to the section Convert Composer Model to Hugging Face Model for more details.

Convert Pruned Model

Following the completion of training using llmshearing/scripts/pruning.sh, the saved models consist of the entire parameters of the source model, accompanied by a set of masks. We then act upon the masking variables by 1) removing the substructures where the masking variables are near zero, and 2) subsuming the masking variables into the model parameters by matrix-vector multiplication, resulting in a more compact model. Simultaneously, it becomes necessary to rename the weight keys so that they can be seamlessly loaded into a target model architecture, ensuring that the layer names are all consecutive.
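
Schematically, for a single linear layer these two operations amount to the following (an illustrative sketch, not the code behind prune_and_save_model); the actual conversion is run with the command below.

# Illustrative sketch (not the repo's prune_and_save_model): for one linear
# layer, drop the columns whose mask is near zero, then fold the surviving
# mask values into the remaining weights.
import torch

def prune_linear_columns(weight: torch.Tensor, z: torch.Tensor, threshold: float = 1e-3) -> torch.Tensor:
    """weight: (out_dim, in_dim) matrix; z: per-input-column mask values."""
    kept = z > threshold              # 1) remove near-zero substructures
    return weight[:, kept] * z[kept]  # 2) absorb the mask into the parameters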

MODEL_PATH=$MODEL_DIR/latest-rank0.pt
python3 -m llmshearing.utils.post_pruning_processing prune_and_save_model $MODEL_PATH

The pruned model will be saved in $(dirname $MODEL_PATH)/pruned-latest-rank0.pt.

Convert Composer Model to Hugging Face Model

After training, if you'd like to use Hugging Face transformers for inference or fine-tuning, you may opt to transform your Composer model into a Hugging Face model using the llmshearing/scripts/composer_to_hf.py script. Here's an example of how to use the script:

MODEL_PATH=$MODEL_DIR/latest-rank0.pt
OUTPUT_PATH=$MODEL_DIR/hf-latest_rank0
MODEL_CLASS=LlamaForCausalLM
HIDDEN_SIZE=2048
NUM_ATTENTION_HEADS=16
NUM_HIDDEN_LAYERS=24
INTERMEDIATE_SIZE=5504
MODEL_NAME=Sheared-Llama-1.3B

python3 -m llmshearing.utils.composer_to_hf save_composer_to_hf $MODEL_PATH $OUTPUT_PATH \
        model_class=${MODEL_CLASS} \
        hidden_size=${HIDDEN_SIZE} \
        num_attention_heads=${NUM_ATTENTION_HEADS} \
        num_hidden_layers=${NUM_HIDDEN_LAYERS} \
        intermediate_size=${INTERMEDIATE_SIZE} \
        num_key_value_heads=${NUM_ATTENTION_HEADS} \
        _name_or_path=${MODEL_NAME}

Please be aware that the parameter names mentioned here are tailored to Llama2's Hugging Face configurations and may differ when dealing with other model types.
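
As a quick sanity check (not part of the repo, and assuming the conversion wrote a standard Hugging Face model directory), the converted checkpoint can be loaded directly with transformers:

# Quick sanity check (hypothetical, not part of the repo): load the converted
# checkpoint with Hugging Face transformers and count its parameters.
from transformers import AutoModelForCausalLM

output_path = "path/to/hf-latest_rank0"  # the $OUTPUT_PATH directory from the script above
model = AutoModelForCausalLM.from_pretrained(output_path)
print(model.config)
print(f"{sum(p.numel() for p in model.parameters()) / 1e9:.2f}B parameters")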

Training Configurations

In this section, we provide an in-depth guide on configuring parameters within YAML configuration files for training. These configurations encompass several key aspects, including data setup, fundamental training settings, pruning settings, and dynamic data loading configurations.

Data configurations

Basic training configurations

The basic training configurations largely follow the original Composer package. For comprehensive details on these configurations, please refer to Composer's official documentation. Here are some key training parameters to take note of:

Due to computational constraints, an exhaustive hyperparameter search was not conducted, and there may exist better hyper-parameters for improved performance.

Pruning configurations

The pruning process allows pruning a source model to a specific target shape, and the script includes essential parameters such as:

The pruning-specific arguments are all grouped under model.l0_module:

These parameters allow you to configure and control the pruning process according to your specific requirements.
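
For reference, the 1.3B target shape corresponds to the values already used in the Composer-to-Hugging-Face conversion example above; the snippet below is a hypothetical illustration of those values only, and the actual YAML keys under model.l0_module may use different names.

# Hypothetical illustration of a 1.3B target shape, using the values from the
# conversion example above; the actual keys under model.l0_module may differ.
target_shape_1_3b = {
    "d_model": 2048,            # hidden size
    "n_heads": 16,              # attention heads
    "n_layers": 24,             # transformer layers
    "intermediate_size": 5504,  # MLP hidden size
}
print(target_shape_1_3b)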

Dynamic batch loading configurations

We extend Streaming's StreamingDataset in datasets/streaming_dataset.py to support loading data dynamically. The parameters for configuring dynamic batch loading are primarily defined within the DynamicLoadingCallback. Most of the following configurations can be specified in a YAML configuration file under the callbacks.data_loading section. Here's an explanation of each parameter:

The code is designed to exclusively accommodate local data and does not support remote streaming data. Additionally, it currently only functions with a single worker for the dataloader and does not offer prefetch support. In our testing, this restriction does not incur any additional compute overhead.
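
For intuition, here is a schematic sketch of the dynamic batch loading idea from the paper (a paraphrase, not the DynamicLoadingCallback code): every evaluation interval, domains whose loss is still above their reference loss are exponentially upweighted and the proportions are renormalized.

# Schematic sketch of dynamic batch loading (a paraphrase of the paper's rule,
# not the DynamicLoadingCallback implementation).
import numpy as np

def update_proportions(current: np.ndarray, losses: np.ndarray, ref_losses: np.ndarray) -> np.ndarray:
    """current: per-domain sampling proportions; losses/ref_losses: per-domain eval losses."""
    excess = np.maximum(losses - ref_losses, 0.0)  # only domains lagging their reference matter
    weights = current * np.exp(excess)             # exponentially upweight lagging domains
    return weights / weights.sum()                 # renormalize into a distribution

# Illustrative usage with three domains:
print(update_proportions(np.array([0.5, 0.3, 0.2]),
                         np.array([2.1, 2.4, 1.9]),
                         np.array([2.0, 2.0, 2.0])))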

Throughput

Here is the throughput of the pruning and continued pre-training steps on A100 80GB GPUs. Throughput is quantified in tokens processed per second. Please refer to the standard throughput of llm-foundry for comparison.

Setting             GPUs    Throughput per Device    Throughput
Pruning 7B             8                     1844         14750
Pre-training 3B       16                     4957         79306
Pre-training 1.3B     16                     8684        138945

Future Work

Source models: While large models are undoubtedly powerful and have the potential to become stronger in the near future, we believe that small-scale models (those with fewer than 7 billion parameters) have untapped potential. However, little effort has been dedicated to making small models stronger, and our work pushes towards this goal. A natural extension of this work is to extend the codebase to prune other strong source models.

To adapt the codebase to other models, one key component is to make sure that running the model with masks is equivalent to running the pruned model. We use llmshearing/utils/test_pruning.py to run such tests to ensure the correctness of the function prune_params in model files.
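
Conceptually, such a test boils down to comparing the outputs of the masked model and the pruned model on the same inputs. A hypothetical helper in that spirit (not the actual code in llmshearing/utils/test_pruning.py, and assuming the models return plain logit tensors) might look like:

# Hypothetical helper illustrating the kind of check test_pruning.py performs
# (not the actual test code): the masked source model and the pruned model
# should produce near-identical outputs on the same inputs.
import torch

def outputs_match(masked_model: torch.nn.Module, pruned_model: torch.nn.Module,
                  input_ids: torch.Tensor, atol: float = 1e-4) -> bool:
    with torch.no_grad():
        return torch.allclose(masked_model(input_ids), pruned_model(input_ids), atol=atol)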

Data Sources: Please keep in mind that the performance of the resulting model is contingent not only on the pruning algorithm and the base model but also on the quality of the data. In our experiments, we mainly worked with the RedPajama v1 data. However, here are some additional resources that could be considered for inclusion:

Bugs or Questions?

If you have any questions related to the code or the paper, feel free to email Mengzhou (mengzhou@princeton.edu). If you encounter any problems when using the code, or want to report a bug, you can open an issue. Please try to describe the problem in detail so we can help you better and more quickly!

Citation

Please cite our paper if you find the repo helpful in your work:

@article{xia2023sheared,
  title={Sheared llama: Accelerating language model pre-training via structured pruning},
  author={Xia, Mengzhou and Gao, Tianyu and Zeng, Zhiyuan and Chen, Danqi},
  journal={arXiv preprint arXiv:2310.06694},
  year={2023}
}