Official implementation of Equivariant Deep Weight Space Alignment, ICML 2024
conda create --name deep-align python=3.9
conda activate deep-align
conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia
git clone https://github.com/AvivNavon/deep-align.git
cd deep-align
pip install -e .
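To confirm that the environment matches the versions pinned in the conda command, you can run a small check like the one below. It is a sketch and is not part of the repository. `check_pins` and `PINNED` are hypothetical names. It also assumes the conda packages expose standard Python package metadata, which `importlib.metadata` reads; this is usually the case.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions pinned in the conda install command above.
PINNED = {"torch": "2.0.1", "torchvision": "0.15.2", "torchaudio": "2.0.2"}


def check_pins(pins=PINNED):
    """Return {package: problem} for packages that are missing or mismatched."""
    problems = {}
    for pkg, expected in pins.items():
        try:
            installed = version(pkg)
        except PackageNotFoundError:
            problems[pkg] = "not installed"
            continue
        # Drop local build tags such as "+cu117" before comparing.
        base = installed.split("+", 1)[0]
        if base != expected:
            problems[pkg] = f"found {installed}, expected {expected}"
    return problems


if __name__ == "__main__":
    issues = check_pins()
    print("environment OK" if not issues else issues)
```

An empty result means every pinned package was found at the expected version.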
An introductory notebook on aligning MNIST MLPs with DEEP-ALIGN:
We have released the following datasets:
To run the MLP experiments, first download the data:
mkdir datasets
wget "https://www.dropbox.com/s/sv85hrjswaspok4/mnist_classifiers.zip"
unzip -q mnist_classifiers.zip -d datasets
Then generate the data splits:
python experiments/utils/data/generate_splits.py --data-root datasets/mnist_classifiers --save-path datasets/splits.json
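After the split step, you can check that each split is non-empty. The sketch below assumes a layout for `splits.json`: a top-level object that maps split names either to a list of entries or to a dict of equal-length columns. The actual format written by `generate_splits.py` may differ. `count_split_entries` and `summarize_splits` are hypothetical helpers.

```python
import json
from pathlib import Path


def count_split_entries(splits):
    """Count entries per split in an already-parsed splits object.

    Hypothetical layout: {split_name: [entries...]} or
    {split_name: {"column": [...], ...}} with equal-length columns.
    """
    counts = {}
    for name, entries in splits.items():
        if isinstance(entries, dict):
            # Column-style layout: use the length of the first column.
            entries = next(iter(entries.values()), [])
        counts[name] = len(entries)
    return counts


def summarize_splits(path):
    """Load a splits JSON file and return the entry count of each split."""
    return count_split_entries(json.loads(Path(path).read_text()))
```

For example, run `summarize_splits("datasets/splits.json")` from the repository root.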
Next, run the MLP experiment:
cd experiments/mlp_image_classifier
python trainer.py --data-path <path-to-splits.json> --image-data-path=<path-to-mnist-dataset> --no-wandb
Alternatively, log runs to Weights & Biases (wandb):
python trainer.py --data-path <path-to-splits.json> --image-data-path=<path-to-mnist-dataset> --wandb-entity=<wandb-entity> --wandb-project=<wandb-project>
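The MLP experiments address a specific symmetry. The hidden neurons of an MLP can be permuted without changing the function it computes. As a result, two networks can be far apart in weight space yet be functionally identical. Aligning them means finding the permutation that brings one network's weights closest to the other's.

The toy sketch below illustrates this for a one-hidden-layer MLP and solves the alignment by brute force over all permutations. It is illustrative only and is not DEEP-ALIGN, which learns to predict alignments. Brute force scales factorially with the hidden width, so it is only feasible for tiny networks. All function names are hypothetical.

```python
import itertools
import random


def mlp(params, x):
    """One-hidden-layer MLP: y = W2 @ relu(W1 @ x + b1) + b2."""
    W1, b1, W2, b2 = params
    h = [max(0.0, sum(w * xi for w, xi in zip(row, x)) + b) for row, b in zip(W1, b1)]
    return [sum(w * hi for w, hi in zip(row, h)) + b for row, b in zip(W2, b2)]


def permute_hidden(params, perm):
    """Reorder hidden neurons: rows of W1 and b1, columns of W2."""
    W1, b1, W2, b2 = params
    return ([W1[i] for i in perm], [b1[i] for i in perm],
            [[row[i] for i in perm] for row in W2], list(b2))


def flatten(params):
    W1, b1, W2, b2 = params
    return [v for row in W1 for v in row] + list(b1) + [v for row in W2 for v in row] + list(b2)


def sq_dist(a, b):
    """Squared L2 distance between two parameter sets."""
    return sum((u - v) ** 2 for u, v in zip(flatten(a), flatten(b)))


def brute_force_align(params_a, params_b):
    """Permutation of B's hidden units that minimizes the distance to A."""
    width = len(params_b[1])
    return min(itertools.permutations(range(width)),
               key=lambda p: sq_dist(params_a, permute_hidden(params_b, p)))


rng = random.Random(0)
d_in, width, d_out = 3, 4, 2
A = ([[rng.gauss(0, 1) for _ in range(d_in)] for _ in range(width)],
     [rng.gauss(0, 1) for _ in range(width)],
     [[rng.gauss(0, 1) for _ in range(width)] for _ in range(d_out)],
     [rng.gauss(0, 1) for _ in range(d_out)])
B = permute_hidden(A, [2, 0, 3, 1])  # same function, shuffled weights
x = [0.5, -1.0, 2.0]
print(all(abs(u - v) < 1e-9 for u, v in zip(mlp(A, x), mlp(B, x))))  # True
perm = brute_force_align(A, B)
print(perm, sq_dist(A, permute_hidden(B, perm)))  # (1, 3, 0, 2) 0.0
```

The recovered permutation is the inverse of the one used to shuffle `B`, so the aligned weights match `A` exactly.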
We use code from the following repositories:
If you find this code useful in your research, please consider citing:
@article{navon2023equivariant,
title={Equivariant Deep Weight Space Alignment},
author={Navon, Aviv and Shamsian, Aviv and Fetaya, Ethan and Chechik, Gal and Dym, Nadav and Maron, Haggai},
journal={arXiv preprint arXiv:2310.13397},
year={2023}
}