This is the source code to reproduce the experiments for "Localizing Task Information for Improved Model Merging and Compression" by Ke Wang, Nikolaos Dimitriadis, Guillermo Ortiz-Jimenez, Francois Fleuret, and Pascal Frossard.
Our paper shows that task-specific knowledge is preserved after merging, and proposes a method named TALL masks to localize it. Based on TALL masks, we propose: 1) a compression scheme that uses TALL masks to recover the single-task fine-tuned performance for each task; 2) a merging algorithm that removes catastrophic and selfish weights to improve model merging performance.
You can also find more information on the project website.
To run the code, please install all its dependencies:
conda env create
conda activate tall-masks
We provide the checkpoints, as well as the generated task-specific masks used in the paper, in this link. Alternatively, you can download the checkpoints and masks by running the following script:
# model options --model {ViT-B-32,ViT-L-14}
# kind options --kind {checkpoints,tall_masks}
# use python download_checkpoints.py --help for more information
python download_checkpoints.py --model='ViT-B-32' --kind=checkpoints
The script downloads all the checkpoints for one model, corresponding to 40 files (a fine-tuned checkpoint and a classification head for each of 20 tasks). The script uses the gdown package to download the files. If you encounter any issues, please refer to the gdown documentation. A common issue is that the download quota is exceeded, in which case you can download the files manually from the Google Drive folder or modify your local cookies file as described in the gdown documentation.
Alternatively, the checkpoints can be downloaded from the HuggingFace repo nik-dim/tall_masks. See the snapshot_download documentation for more details.
from huggingface_hub import snapshot_download
# download the ViT-B-32 checkpoints including backbone, classification heads and tall masks
snapshot_download(repo_id="nik-dim/tall_masks", allow_patterns="*32*")
# download the ViT-B-16 checkpoints including backbone, classification heads and tall masks
snapshot_download(repo_id="nik-dim/tall_masks", allow_patterns="*16*")
# download the ViT-L-14 checkpoints including backbone, classification heads and tall masks
snapshot_download(repo_id="nik-dim/tall_masks", allow_patterns="*14*")
# download everything
snapshot_download(repo_id="nik-dim/tall_masks")
Most datasets being used should be downloaded automatically with torchvision or huggingface. For the datasets requiring manual preparation, please follow the instructions in this issue. Depending on the torchvision version, some issues might arise when downloading specific datasets like here or here. In this case, using a different torchvision version might solve the issue.
Below is pseudo-code showing how to use a TALL mask to localize the task-specific information in the multi-task vector and reconstruct the individual checkpoints.
To create a task vector, you will need a pre-trained checkpoint and a fine-tuned checkpoint:
from task_vectors import TaskVector
task_vector_A = TaskVector(pretrained_checkpoint, finetuned_checkpoint_A)
Create a multi-task vector:
multi_task_vector = task_vector_A + task_vector_B + task_vector_C
Construct tall mask:
tall_mask_A = task_vector_A.abs() > (multi_task_vector - task_vector_A).abs() * lambda
Reconstruct fine-tuned model with tall mask:
# the reconstructed finetuned_checkpoint_A has nearly the same performance as the original finetuned_checkpoint_A
reconstructed_finetuned_checkpoint_A = pretrained_checkpoint + multi_task_vector * tall_mask_A
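The element-wise operations above can be illustrated numerically. Below is a minimal toy sketch in plain Python, with lists standing in for weight tensors; the values and the threshold lambda are made up for illustration, while the real code operates on full model state dicts:

```python
# Toy TALL-mask construction on flat weight vectors (illustrative values only)
lam = 0.4  # mask threshold hyperparameter (lambda in the paper)

pretrained  = [0.0, 0.0, 0.0, 0.0]
finetuned_a = [1.0, 0.1, 0.0, -0.5]
finetuned_b = [0.0, 0.9, 0.2, 0.5]

# task vectors: difference between fine-tuned and pre-trained weights
tau_a = [f - p for f, p in zip(finetuned_a, pretrained)]
tau_b = [f - p for f, p in zip(finetuned_b, pretrained)]

# multi-task vector: sum of all task vectors
mtv = [a + b for a, b in zip(tau_a, tau_b)]

# TALL mask for task A: keep entries where task A's contribution dominates
# the contribution of all other tasks, scaled by lambda
mask_a = [abs(a) > abs(m - a) * lam for a, m in zip(tau_a, mtv)]

# reconstruct task A's checkpoint from the masked multi-task vector
recon_a = [p + m * int(k) for p, m, k in zip(pretrained, mtv, mask_a)]
```

With these toy numbers, the mask keeps the first and last entries, where task A's task vector clearly dominates, and zeroes out the entries dominated by task B.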
The script finetune.py can be used to reproduce the training protocol we used to fine-tune our models on all downstream tasks.
# Finetune on 2 GPUs
python finetune.py --model=ViT-B-32 --world-size=2
Evaluation is performed with Hydra; please modify model_location and data_location in config/config.yaml before evaluation.
# Evaluate with Task Arithmetic
python main.py model=ViT-B-32 method="sum"
# Evaluate with Ties-merging
python main.py model=ViT-B-32 method="ties" method.k=20
# Evaluate with Tall mask + Task Arithmetic (load tall masks from storage)
python main.py model=ViT-B-32 method="tall_mask" method.load_mask=True
# Evaluate with Tall mask + Task Arithmetic (construct tall masks from scratch)
python main.py model=ViT-B-32 method="tall_mask"
# Evaluate with Tall mask + Ties-merging (load tall masks from storage)
python main.py model=ViT-B-32 method="tall_mask" method.use_ties=True method.load_mask=True
# Evaluate with Tall mask + Ties-merging (construct tall masks from scratch)
python main.py model=ViT-B-32 method="tall_mask" method.use_ties=True
# Evaluate with Consensus Task Arithmetic
python main.py model=ViT-B-32 method="consensus" method.prun_thre_k=2
# Evaluate with Consensus Ties-merging
python main.py model=ViT-B-32 method="consensus" method.prun_thre_k=2 method.use_ties=True
Note that you can set a different number of tasks via num_tasks. The first num_tasks tasks are then selected from the list defined in src/utils/variables_and_paths.py. Alternatively, you can directly specify the tasks as a list of strings (e.g., DATASETS=[MNIST,Cars]). The results of the paper can be retrieved by setting num_tasks to 8, 14, and 20 for the corresponding experiments.
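For example, combining the overrides above (a sketch using the same Hydra override syntax as the evaluation commands; the task subset shown is hypothetical):

```shell
# Evaluate Task Arithmetic on the first 8 tasks from the predefined list
python main.py model=ViT-B-32 method="sum" num_tasks=8

# Evaluate on an explicitly chosen subset of tasks
python main.py model=ViT-B-32 method="sum" DATASETS=[MNIST,Cars]
```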
You can evaluate the performance of the fine-tuned weights on each single task by running
# Evaluate pre-trained models.
python eval_single_task.py --model=ViT-B-32 --finetuning-mode=none
# Evaluate non-linearly fine-tuned models.
python eval_single_task.py --model=ViT-B-32 --finetuning-mode=standard
The results are saved in the results/ folder.
If you find this code useful, please cite the following paper:
@inproceedings{wang2024localizing,
title={Localizing Task Information for Improved Model Merging and Compression},
author={Wang, Ke and
Dimitriadis, Nikolaos and
Ortiz{-}Jim{\'{e}}nez, Guillermo and
Fleuret, Fran\c{c}ois and
Frossard, Pascal},
booktitle={International Conference on Machine Learning},
year={2024}
}