StupidTrees / SplitLLM

Split Learning Simulation Framework for LLMs
Apache License 2.0
large-language-models privacy-attacks split-learning

SplitLLM:
Split Learning Simulation Framework for LLMs

[![License](https://img.shields.io/github/license/StupidTrees/SplitLLM)](https://opensource.org/licenses/Apache-2.0) [![arXiv](https://img.shields.io/badge/arXiv-2409.00960-red.svg)](https://arxiv.org/abs/2409.00960)

SplitLLM is a Split Learning simulation framework designed for Large Language Models (LLMs).

Environment Setup

conda create -n sfl python=3.11
conda activate sfl
conda install pytorch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 pytorch-cuda=11.8 -c pytorch -c nvidia
pip install -r requirements.txt
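After installation, a quick check like the one below can confirm that the core packages are visible to the active environment. This is only a convenience sketch and not part of SplitLLM; the function name `check_environment` is made up for illustration.

```python
import importlib.util
import sys


def check_environment(packages=("torch", "torchvision", "torchaudio")):
    """Return a dict mapping each package name to True if it can be imported."""
    return {name: importlib.util.find_spec(name) is not None for name in packages}


if __name__ == "__main__":
    print("Python", sys.version.split()[0])
    for name, found in check_environment().items():
        print(f"{name}: {'found' if found else 'MISSING'}")
```

Run it inside the `sfl` environment; any package reported as missing points to a failed or skipped install step.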

Model Download

  1. Go to sfl/config and set dataset_cache_dir, model_download_dir, and model_cache_dir to your own paths.
  2. Run the following commands to download the models. Some repositories (for example, gated ones such as meta-llama) may require an authorized Hugging Face access token:
    cd experiments/script
    python model_download.py --repo_id meta-llama/Llama-2-7b-chat-hf
    python model_download.py --repo_id gpt2-large
    python model_download.py --repo_id THUDM/chatglm3-6b
    ...
    #python model_download.py --repo_id FacebookAI/roberta-large
    #python model_download.py --repo_id google-bert/bert-large-uncased
    #python model_download.py --repo_id google/flan-t5-base
    #python model_download.py --repo_id google/flan-ul2-base
    #python model_download.py --repo_id meta-llama/Meta-Llama-3-8B 
    #python model_download.py --repo_id lucyknada/microsoft_WizardLM-2-7B
    #python model_download.py --repo_id lmsys/vicuna-7b-v1.5
    #python model_download.py --repo_id tiiuae/falcon-7b-instruct
    #python model_download.py --repo_id Salesforce/codegen25-7b-instruct_P
    #python model_download.py --repo_id EleutherAI/gpt-j-6b
    #python model_download.py --repo_id google/flan-ul2
    #python model_download.py --repo_id google/vit-large-patch16-224
    #python model_download.py --repo_id bigscience/bloomz-560m
    #python model_download.py --repo_id state-spaces/mamba-1.4b-hf
    • (Optional) To use the parallelized Mamba implementation, install causal-conv1d and mamba-ssm:
      pip install "causal-conv1d>=1.2.0"
      pip install mamba-ssm

      Note that without these packages, the default Mamba implementation is sequential.
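The contents of `model_download.py` are not shown here. As a rough sketch of what a download step of this kind could look like, the snippet below uses `snapshot_download` from the `huggingface_hub` package. The helper names, the one-directory-per-repo layout, and reading the token from the `HF_TOKEN` environment variable are all assumptions for illustration, not the repository's actual behavior.

```python
import os
from pathlib import Path


def target_dir(model_download_dir, repo_id):
    """Hypothetical layout: one sub-directory per repo id under model_download_dir."""
    return Path(model_download_dir) / repo_id


def download_model(repo_id, model_download_dir):
    """Download every file of a Hugging Face model repo into a local directory."""
    # Imported lazily so target_dir() works even without huggingface_hub installed.
    from huggingface_hub import snapshot_download

    return snapshot_download(
        repo_id=repo_id,
        local_dir=str(target_dir(model_download_dir, repo_id)),
        # Gated repos need a token from an account that has been granted access.
        token=os.environ.get("HF_TOKEN"),
    )
```

For example, `download_model("gpt2-large", "/path/to/models")` would place the files under `/path/to/models/gpt2-large`.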

Download Datasets

cd $dataset_cache_dir
git clone https://huggingface.co/datasets/wikitext.git
git clone https://huggingface.co/datasets/piqa.git
git clone https://huggingface.co/datasets/HuggingFaceH4/CodeAlpaca_20K.git
git clone https://huggingface.co/datasets/knkarthick/dialogsum.git
git clone https://huggingface.co/datasets/gsm8k.git
git clone https://huggingface.co/datasets/imdb.git
git clone https://huggingface.co/datasets/Hello-SimpleAI/HC3-Chinese.git
git clone https://huggingface.co/datasets/frgfm/imagewoof.git
git clone https://huggingface.co/datasets/SetFit/qnli.git
git clone https://huggingface.co/datasets/linxinyuan/cola.git
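
Once the clones finish, a small check such as the following can list any dataset directory that is still missing. It is a sketch only: it assumes each `git clone` above creates a directory named after the repository (without `.git`) directly under `dataset_cache_dir`, and the helper name `missing_datasets` is hypothetical.

```python
from pathlib import Path

# Directory names created by the git clone commands above (repo name minus ".git").
EXPECTED_DATASETS = [
    "wikitext", "piqa", "CodeAlpaca_20K", "dialogsum", "gsm8k",
    "imdb", "HC3-Chinese", "imagewoof", "qnli", "cola",
]


def missing_datasets(cache_dir, expected=EXPECTED_DATASETS):
    """Return the expected dataset directories that are absent from cache_dir."""
    root = Path(cache_dir)
    return [name for name in expected if not (root / name).is_dir()]
```

Calling `missing_datasets("/path/to/dataset_cache_dir")` returns an empty list when every clone is in place.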

Run the Experiment Demo

Note that the script requires wandb to be installed and configured.

Split Learning Simulation

In SL, a model is divided into three parts:

Bottom Layers | Trunk Layers (Adapters) | Top Layers

where the Bottom Layers and Top Layers are the input and output ends of the model, and the Trunk Layers are the middle part.
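
To make the three-way split concrete, the sketch below partitions the layer indices of a model into bottom, trunk, and top ranges. The cut-point names `bottom_end` and `top_start` are hypothetical and chosen for illustration; they are not SplitLLM's configuration keys.

```python
def partition_layers(num_layers, bottom_end, top_start):
    """Split layer indices 0..num_layers-1 into bottom, trunk and top lists.

    bottom = [0, bottom_end), trunk = [bottom_end, top_start),
    top = [top_start, num_layers). All three parts must be non-empty.
    """
    if not 0 < bottom_end < top_start < num_layers:
        raise ValueError("need 0 < bottom_end < top_start < num_layers")
    indices = list(range(num_layers))
    return {
        "bottom": indices[:bottom_end],
        "trunk": indices[bottom_end:top_start],
        "top": indices[top_start:],
    }


# A 12-layer model with the first 2 layers as bottom and the last 2 as top.
parts = partition_layers(12, bottom_end=2, top_start=10)
print(parts["bottom"], parts["trunk"], parts["top"])
```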

To simulate Split Federated Learning (SFL), we do not physically split the model in code. Instead, we maintain the parameters of each part of the model independently, and we simulate client training serially, without actual client parallelism.

The simulation process is as follows:

  1. At the start of a round of federated learning, select Clients 0, 1, and 2.
  2. Load the model parameters of Client 0 (including bottom and top layers, and its corresponding trunk on the Server) from disk into the GPU model.
  3. Client 0 performs local training, updating all model parameters, following the process consistent with centralized learning.
  4. Once Client 0 completes training, save the model parameters to disk.
  5. Load the model parameters of Client 1 from disk into the GPU model.
  6. ...
  7. At the end of the federated learning round, aggregate the trunk parameters corresponding to all clients on disk to obtain the average trunk. Then, update the trunk parameters of all clients to the average trunk.
  8. Repeat steps 1-7.
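
The steps above can be sketched in a few lines of plain Python. This is a toy illustration under stated assumptions, not SplitLLM's implementation: parameters are floats in nested dicts rather than tensors, an in-memory `store` stands in for the disk, and the names `run_round`, `average_trunks`, and `toy_train` are hypothetical. The text does not say whether unselected clients also receive the averaged trunk; this sketch updates only the clients selected for the round.

```python
import copy


def average_trunks(trunks):
    """Element-wise mean of a list of {name: float} trunk dicts with identical keys."""
    n = len(trunks)
    return {name: sum(t[name] for t in trunks) / n for name in trunks[0]}


def run_round(store, selected, local_train):
    """Simulate one round: train selected clients serially, then average their trunks."""
    for cid in selected:
        model = copy.deepcopy(store[cid])  # steps 2 and 5: load this client's parameters
        local_train(cid, model)            # step 3: update bottom, trunk and top in place
        store[cid] = model                 # step 4: save the updated parameters
    avg = average_trunks([store[cid]["trunk"] for cid in selected])  # step 7: aggregate
    for cid in selected:
        store[cid]["trunk"] = dict(avg)    # step 7: every selected client gets the average
    return avg


def toy_train(cid, model):
    """Stand-in for local training: shift every parameter by a client-specific amount."""
    for part in model.values():
        for name in part:
            part[name] += cid + 1


store = {cid: {"bottom": {"w": 0.0}, "trunk": {"w": 0.0}, "top": {"w": 0.0}}
         for cid in range(3)}
avg = run_round(store, [0, 1, 2], toy_train)
print(avg)  # {'w': 2.0}
```

After the round, each client keeps its own bottom and top parameters, while all three share the same averaged trunk.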