
Differentiable Adaptive Merging (DAM)

[Project figure]

Differentiable Adaptive Merging (DAM) automates the merging of multiple LLMs with unique capabilities, optimizing the balance between models for improved data efficiency and reduced computational costs. DAM outperforms traditional and evolutionary methods, offering a scalable solution for versatile AI systems. Extensive experiments validate DAM's superiority across various model merging scenarios.
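As a conceptual illustration only (not code from this repository), the core idea is that merging coefficients are ordinary trainable parameters, so they can be tuned by gradient descent on a data-driven objective. Every name and the toy objective below are invented:

```python
# Conceptual sketch: learn per-model merging coefficients by gradient descent.
import torch

# Pretend weights from two fine-tuned models (same shape as the base weight).
w_a = torch.randn(4, 4)
w_b = torch.randn(4, 4)

# One trainable coefficient per source model. (DAM itself learns far more
# granular coefficients; a single scalar per model keeps this sketch small.)
coeffs = torch.nn.Parameter(torch.tensor([0.5, 0.5]))
optimizer = torch.optim.Adam([coeffs], lr=1e-2)

# Toy objective standing in for the real data-driven loss.
x = torch.randn(8, 4)
target = x @ torch.randn(4, 4)

for _ in range(100):
    optimizer.zero_grad()
    merged = coeffs[0] * w_a + coeffs[1] * w_b   # differentiable merge
    loss = torch.nn.functional.mse_loss(x @ merged, target)
    loss.backward()
    optimizer.step()
```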

Steps to Run the Workflow

This repository contains the implementation for running the merging coefficient tuning process.

1. Create the Merged Model

First, create the merged model by running the merge.py script found in the dam folder. The resulting merged model will contain untrained coefficients.

In this step, we assign a trainable coefficient for each column of each model's layer norms, embedding layers, and linear layers as specified by the user. These coefficients will be optimized during the training process to achieve the best integration of the models.
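The sketch below illustrates the per-column idea for a single linear layer. It is a minimal sketch with invented class and variable names, not the repository's implementation, and which axis counts as a "column" here is an assumption:

```python
# Hypothetical sketch: a linear layer that merges several donor models' weights
# with one trainable coefficient per weight-matrix column and per donor.
import torch
import torch.nn as nn


class MergedLinear(nn.Module):
    def __init__(self, donor_weights):
        super().__init__()
        # donor_weights: list of frozen (out_features, in_features) tensors.
        self.donor_weights = [w.detach() for w in donor_weights]
        in_features = donor_weights[0].shape[1]
        # One coefficient per column and per donor, initialized uniformly.
        init = torch.full((len(donor_weights), in_features), 1.0 / len(donor_weights))
        self.coefficients = nn.Parameter(init)

    def merged_weight(self):
        # Scale each donor's columns by its coefficients, then sum over donors.
        return sum(c.unsqueeze(0) * w
                   for c, w in zip(self.coefficients, self.donor_weights))

    def forward(self, x):
        return x @ self.merged_weight().t()


# Usage: real 7B donors would supply the weights; random tensors stand in here.
donors = [torch.randn(16, 32) for _ in range(3)]
layer = MergedLinear(donors)
out = layer(torch.randn(4, 32))  # only layer.coefficients requires gradients
```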

Command:

python dam/merge.py mistralai/Mistral-7B-v0.1 augmxnt/shisa-gamma-7b-v1 WizardLM/WizardMath-7B-V1.1 arcee-train/Abel-7B-002-truncated-embeds --device cuda --output_path ./merged_model --repo_id arcee-train/[prefix]-untrained-merge

Arguments:

- Positional arguments: the base model ID followed by the IDs of the models to merge (here mistralai/Mistral-7B-v0.1 as the base and three fine-tuned models).
- --device: device used for the merge (e.g., cuda).
- --output_path: local directory where the merged model with untrained coefficients is saved.
- --repo_id: Hugging Face repository the untrained merged model is pushed to.

2. Prepare the Dataset

To prepare the dataset, navigate to the dam/data folder and run create_merge_dataset.py. This script builds a composite dataset from examples of the data used to train the models being merged, applies each model's template, and tokenizes the data. Optionally, it can also compute and save the top-K logits of the other models for use during training; computing logits beforehand is not required, as they can instead be generated on the fly during training.
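A rough sketch of what this preparation looks like, assuming the Hugging Face datasets and transformers libraries; the model and dataset IDs come from the command below, while the field names, example count, and record layout are illustrative guesses:

```python
# Illustrative sketch (assumed structure and field names, not the repository
# script): sample a few examples, apply the model's chat template, tokenize.
from datasets import load_dataset
from transformers import AutoTokenizer

model_id = "WizardLM/WizardMath-7B-V1.1"                  # one model to merge
dataset_name = "microsoft/orca-math-word-problems-200k"   # data associated with it
example_count = 8                                         # tiny number, for illustration
max_length = 2048

tokenizer = AutoTokenizer.from_pretrained(model_id)
rows = load_dataset(dataset_name, split=f"train[:{example_count}]")

records = []
for row in rows:
    # Field names vary per dataset; "question"/"answer" is a guess for this one,
    # and whether the tokenizer ships a chat template is also an assumption.
    messages = [
        {"role": "user", "content": row["question"]},
        {"role": "assistant", "content": row["answer"]},
    ]
    text = tokenizer.apply_chat_template(messages, tokenize=False)
    tokens = tokenizer(text, truncation=True, max_length=max_length)
    records.append({"model_id": model_id, "input_ids": tokens["input_ids"]})
# The real script repeats this for every model/dataset pair and pushes the
# combined result to the Hugging Face Hub (--dataset_id).
```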

Command:

python dam/data/create_merge_dataset.py --dataset_names "p1atdev/ichikara-instruction:20231115-1,microsoft/orca-math-word-problems-200k,meta-math/MetaMathQA" --model_ids "augmxnt/shisa-gamma-7b-v1,WizardLM/WizardMath-7B-V1.1,arcee-train/Abel-7B-002-truncated-embeds" --base_model_name mistralai/Mistral-7B-v0.1 --cache_dir /home/ec2-user/.cache/huggingface --compute_logits True --dataset_id arcee-train/[prefix]-combined-dataset --example_count 1729 --max_length 2048 --add_top_k_logits False

Arguments:

- --dataset_names: comma-separated Hugging Face dataset IDs to sample from (an optional configuration can follow a colon, e.g. p1atdev/ichikara-instruction:20231115-1).
- --model_ids: comma-separated IDs of the models being merged, in the same order as their datasets.
- --base_model_name: ID of the base model.
- --cache_dir: Hugging Face cache directory.
- --compute_logits: whether to compute and save the models' logits ahead of training (they can also be generated on the fly during training).
- --dataset_id: Hugging Face repository the combined dataset is pushed to.
- --example_count: number of examples to include.
- --max_length: maximum tokenized sequence length.
- --add_top_k_logits: whether to store only the top-K logits rather than the full logit vectors.
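The top-K option above keeps the saved logits small compared with storing full vocabulary-sized vectors. A minimal sketch of the idea follows; the value of K, the tensor shapes, and the storage layout are assumptions, not the script's actual format:

```python
# Sketch: keep only each position's K largest logits from one donor model.
import torch
from transformers import AutoModelForCausalLM

k = 50  # illustrative value
model = AutoModelForCausalLM.from_pretrained(
    "WizardLM/WizardMath-7B-V1.1", torch_dtype=torch.bfloat16, device_map="auto"
)

input_ids = torch.tensor([[1, 2, 3, 4]], device=model.device)  # placeholder token ids
with torch.no_grad():
    logits = model(input_ids).logits                  # (batch, seq_len, vocab_size)
top_values, top_indices = logits.topk(k, dim=-1)      # (batch, seq_len, k) each
# top_values and top_indices are what would be saved alongside each example.
```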

3. Run the Training

In this step, run the dam/train_dam.py script to train the coefficients. At the end of the training process, the optimized coefficients are folded into the base model architecture to produce the final merged model. The script also supports training on multiple GPUs.

Manual configurations are available at the top of the train_dam.py script.
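For intuition, the sketch below shows one plausible shape of such a training objective: a KL-style term pulling the merged model's predictions toward each donor model's saved top-K logits, plus a penalty, weighted by lambda_coef_similarity, on how similar the coefficient vectors are. This is an assumption-laden sketch with invented names, not the loss implemented in train_dam.py:

```python
# Hedged sketch of a single coefficient-training step (invented names; the
# actual objective and regularizers live in dam/train_dam.py).
import torch
import torch.nn.functional as F


def training_step(merged_logits, donor_topk_values, donor_topk_indices,
                  coefficients, lambda_coef_similarity=0.01):
    """merged_logits: (batch, seq, vocab) from the merged model.
    donor_topk_values / donor_topk_indices: per-donor (batch, seq, K) tensors.
    coefficients: (num_donors, num_columns) trainable coefficient matrix."""
    loss = merged_logits.new_zeros(())
    for values, indices in zip(donor_topk_values, donor_topk_indices):
        # Compare the merged model's distribution with each donor's, restricted
        # to the donor's top-K token ids (a KL-style distillation term).
        merged_on_topk = merged_logits.gather(-1, indices)
        loss = loss + F.kl_div(
            F.log_softmax(merged_on_topk, dim=-1),
            F.softmax(values, dim=-1),
            reduction="batchmean",
        )
    # Similarity penalty between per-donor coefficient vectors (one plausible
    # reading of lambda_coef_similarity).
    sims = []
    for i in range(len(coefficients)):
        for j in range(i + 1, len(coefficients)):
            sims.append(F.cosine_similarity(coefficients[i], coefficients[j], dim=0))
    loss = loss + lambda_coef_similarity * torch.stack(sims).mean()
    return loss
```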

Command:

python dam/train_dam.py --learning_rate 1e-3 --lambda_coef_similarity 0.01 --generate_logits_on_fly True --untrained_merged_model_name arcee-train/[your-model] --combined_hf_dataset_dir arcee-train/[prefix]-combined-dataset --cache_dir /home/ec2-user/.cache/huggingface --base_model_name mistralai/Mistral-7B-v0.1 --use_wandb True

Arguments:

- --learning_rate: learning rate for the coefficient optimization.
- --lambda_coef_similarity: weight of the coefficient-similarity regularization term.
- --generate_logits_on_fly: compute the other models' logits during training instead of loading precomputed ones from the dataset.
- --untrained_merged_model_name: ID of the untrained merged model created in step 1.
- --combined_hf_dataset_dir: ID of the combined dataset created in step 2.
- --cache_dir: Hugging Face cache directory.
- --base_model_name: ID of the base model.
- --use_wandb: whether to log training metrics to Weights & Biases.
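Once training finishes, the learned coefficients can be folded back into ordinary weight tensors, which is what makes the final checkpoint a plain model in the base architecture. A minimal sketch of that folding step (names and shapes are assumptions):

```python
# Sketch: fold trained per-column coefficients back into a plain weight tensor
# so the result is an ordinary model with no extra parameters.
import torch


def fold(donor_weights, coefficients):
    """donor_weights: list of (out, in) tensors; coefficients: (num_donors, in)."""
    merged = torch.zeros_like(donor_weights[0])
    for c, w in zip(coefficients, donor_weights):
        merged += c.unsqueeze(0) * w   # scale each column, then accumulate
    return merged

# The folded tensor simply replaces the corresponding weight in the base
# model's state dict, so the final checkpoint loads like any other model.
```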

Citation

We now have a paper you can cite for the DAM Method:

@article{gauthier2024merging,
  title={Merging in a Bottle: Differentiable Adaptive Merging (DAM) and the Path from Averaging to Automation},
  author={Gauthier-Caron, Thomas and Siriwardhana, Shamane and Stein, Elliot and Ehghaghi, Malikeh and Goddard, Charles and McQuade, Mark and Solawetz, Jacob and Labonne, Maxime},
  journal={arXiv preprint arXiv:2410.08371},
  year={2024}
}