snimu / rebasin

Apply methods described in "Git Re-basin"-paper [1] to arbitrary models --- [1] Ainsworth et al. (https://arxiv.org/abs/2209.04836)
MIT License

rebasin

PyPI Version Wheel Python 3.8+ License

An implementation of the methods described in the "Git Re-Basin" paper by Ainsworth et al.

Can be applied to arbitrary models, without modification.

(Well, almost arbitrary models, see Limitations).



Installation

Requirements should be installed automatically, but one of them is graphviz, which you might have to install separately through your system's package manager (apt, brew, ...).

The following install instructions are taken directly from torchview's installation instructions.

Debian-based Linux distro (e.g. Ubuntu):

apt-get install graphviz

Windows:

choco install graphviz

macOS:

brew install graphviz

For more details, see torchview's installation instructions.

Then, install rebasin via pip:

pip install rebasin

Usage

Currently, only weight-matching is implemented as a method for rebasing, and only a simplified form of linear interpolation is implemented.

The following is a minimal example. For now, the documentation lives in the docstrings, though I intend to write proper documentation. PermutationCoordinateDescent and interpolation.LerpSimple are the main classes, besides MergeMany (see below).

from rebasin import PermutationCoordinateDescent
from rebasin import interpolation

model_a, model_b, train_dl = ...  # placeholders: your two models and a DataLoader
input_data = next(iter(train_dl))[0]

# Rebasin
pcd = PermutationCoordinateDescent(model_a, model_b, input_data)  # weight-matching
pcd.rebasin()  # Rebasin model_b towards model_a. Automatically updates model_b

# Interpolate
lerp = interpolation.LerpSimple(
    models=[model_a, model_b],
    devices=["cuda:0", "cuda:1"],  # Optional, defaults to cpu
    device_interp="cuda:2",  # Optional, defaults to cpu
    savedir="/path/to/save/interpolation"  # Optional, save all interpolated models
)
lerp.interpolate(steps=99)  # Interpolate 99 models between model_a and model_b

The MergeMany-algorithm is also implemented (though there will be interface-changes regarding the devices in the future):

from rebasin import MergeMany
from torch import nn

class ExampleModel(nn.Module):
    ...

model_a, model_b, model_c = ExampleModel(), ExampleModel(), ExampleModel()
train_dl = ...

# Merge
merge = MergeMany(
    models=[model_a, model_b, model_c],
    working_model=ExampleModel(),
    input_data=next(iter(train_dl))[0],
)
merged_model = merge.run()
# The merged model is also accessible through merge.working_model,
#   but only after merge.run() has been called.

Terminology

In this document, I will use the following terminology:

Limitations

Only some methods are implemented

For rebasin, only weight-matching is implemented via rebasin.PermutationCoordinateDescent.
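To make the idea of weight matching concrete, here is a toy sketch in plain Python. It is not how rebasin.PermutationCoordinateDescent is implemented; it only illustrates the objective for a single layer: find the ordering of model_b's hidden units whose weight rows have the largest total inner product with model_a's rows. The function name match_hidden_units and the brute-force search are assumptions made for illustration (a real implementation would solve this as an assignment problem rather than enumerate permutations).

```python
from itertools import permutations


def match_hidden_units(w_a, w_b):
    """Return the permutation p maximizing sum_i dot(w_a[i], w_b[p[i]]).

    w_a, w_b: lists of rows, one row of incoming weights per hidden unit.
    Brute force over all permutations, so only feasible for tiny layers.
    """
    n = len(w_a)

    def score(p):
        return sum(
            sum(x * y for x, y in zip(w_a[i], w_b[p[i]]))
            for i in range(n)
        )

    return max(permutations(range(n)), key=score)


# Hypothetical weights: w_b is roughly a shuffled, noisy copy of w_a.
w_a = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
w_b = [[0.1, 0.9], [0.9, 1.1], [1.0, 0.1]]

perm = match_hidden_units(w_a, w_b)
w_b_aligned = [w_b[j] for j in perm]  # rows of w_b reordered to line up with w_a
```

In a full network, the same permutation would also have to be applied to the bias of that layer and to the input dimension of the next layer, so that the model's output stays unchanged.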

For interpolation, only a simplified method of linear interpolation is implemented via rebasin.interpolation.LerpSimple.
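As a rough illustration of what linear interpolation between two models means, the following plain-Python sketch interpolates between two parameter dictionaries. It does not use LerpSimple or PyTorch; the name lerp_params, the dict-of-lists parameter format, and the choice to exclude both endpoints are assumptions made for this example.

```python
def lerp_params(params_a, params_b, steps):
    """Yield `steps` parameter dicts evenly spaced strictly between a and b."""
    for k in range(1, steps + 1):
        lam = k / (steps + 1)  # interpolation coefficient in (0, 1)
        yield {
            name: [
                (1 - lam) * a + lam * b
                for a, b in zip(params_a[name], params_b[name])
            ]
            for name in params_a
        }


# Hypothetical, tiny "models" with one weight vector and one bias each.
params_a = {"weight": [0.0, 2.0], "bias": [1.0]}
params_b = {"weight": [4.0, 0.0], "bias": [3.0]}

models = list(lerp_params(params_a, params_b, steps=3))
# models[1] is the midpoint: every parameter is the average of a and b.
```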

Limitations of the PermutationCoordinateDescent-class

The PermutationCoordinateDescent-class only permutes some Modules. Most modules should work, but others may behave unexpectedly. In this case, you need to add the module to rebasin/modules.py; make sure it is included in the initialize_module-function (preferably by putting it into the SPECIAL_MODULES-dict).

Additionally, the PermutationCoordinateDescent-class only works with nn.Modules, not functions. There is a requirement to have the permuted model produce the same output as the un-permuted Module, which is a pretty tight constraint. In some models, it isn't a problem at all, but especially in models with lots of short residual blocks, it may (but doesn't have to) be a problem. Where it is a problem, few to no parameters get permuted, which defeats the purpose of rebasin.
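The output-preservation requirement can be illustrated with a toy two-layer MLP in plain Python. This is a sketch, not code from rebasin: the helper names and weights are made up. It shows that reordering the hidden units leaves the output unchanged, as long as the rows of the first layer, its bias, and the columns of the second layer are all permuted consistently.

```python
def relu(v):
    return [max(0.0, x) for x in v]


def matvec(w, v):
    return [sum(wi * vi for wi, vi in zip(row, v)) for row in w]


def mlp(w1, b1, w2, x):
    """Two-layer MLP: w2 @ relu(w1 @ x + b1)."""
    h = relu([s + b for s, b in zip(matvec(w1, x), b1)])
    return matvec(w2, h)


# Hypothetical weights: 2 inputs, 3 hidden units, 2 outputs.
w1 = [[1.0, -1.0], [0.5, 2.0], [-1.0, 1.0]]
b1 = [0.0, -1.0, 0.5]
w2 = [[1.0, 2.0, 3.0], [-1.0, 0.0, 1.0]]

perm = [2, 0, 1]  # new hidden unit i is old hidden unit perm[i]
w1_p = [w1[j] for j in perm]                    # permute rows of layer 1
b1_p = [b1[j] for j in perm]                    # permute its bias the same way
w2_p = [[row[j] for j in perm] for row in w2]   # permute columns of layer 2

x = [1.0, 2.0]
out_original = mlp(w1, b1, w2, x)
out_permuted = mlp(w1_p, b1_p, w2_p, x)
```

If something sits between the two layers that ties the hidden dimension to another tensor (for example a residual connection), that other tensor has to be permuted in the same way, which is where the constraint becomes tight.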

For example, @tysam-code's hlb-gpt, a small but fast language model implementation, isn't permuted at all. Vision transformers like torchvision.models.vit_b_16 have only very few permutations applied to them. In general, transformer models don't work well, because they reshape the input-tensor, and directly follow that up with residual blocks. This means that almost nothing of the model can be permuted (a single Linear layer between the reshaping and the first residual block would fix that, but this isn't usually done...).

On the other hand, CNNs usually work very well.

If you are unsure, you can always print the model-graph! To do so, write:

from rebasin import PermutationCoordinateDescent

pcd = PermutationCoordinateDescent(...)
print(pcd.pinit.model_graph)  # pinit stands for "PermutationInitialization"

Results

For the full results, see rebasin-results (I don't want to upload a bunch of images to this repo, so the results are in their own repo).

The clearest results were produced on hlb-CIFAR10. For results on that model, see here.

Here is a little taste of the results for that model:

hlb-CIFAR10: losses and accuracies of the model

While PermutationCoordinateDescent doesn't fully eliminate the loss barrier, it does reduce it significantly, and, surprisingly, even more so for the accuracy barrier.
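For readers unfamiliar with the term, a loss barrier can be computed from the losses measured along the interpolation path. The sketch below uses one common definition, the highest loss on the path minus the average of the two endpoint losses; whether the linked results use exactly this definition is an assumption. The function name and the numbers are made up for illustration and are not measurements.

```python
def loss_barrier(losses):
    """Barrier along an interpolation path.

    losses[0] and losses[-1] are the losses of the two endpoint models;
    the entries in between belong to the interpolated models.
    """
    baseline = 0.5 * (losses[0] + losses[-1])
    return max(losses) - baseline


# Made-up losses, not taken from any experiment.
example_losses = [0.5, 1.0, 2.0, 1.0, 0.5]
barrier = loss_barrier(example_losses)
```

An accuracy barrier can be defined analogously, as the average endpoint accuracy minus the lowest accuracy on the path.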

You can also find results for the MergeMany-algorithm there.

Acknowledgements

Git Re-Basin:

Ainsworth, Samuel K., Jonathan Hayase, and Siddhartha Srinivasa. 
"Git re-basin: Merging models modulo permutation symmetries." 
arXiv preprint arXiv:2209.04836 (2022).

Link: https://arxiv.org/abs/2209.04836 (accessed on April 9th, 2023)

ImageNet:

I've used the ImageNet Data from the 2012 ILSVRC competition to evaluate the algorithms from rebasin on the torchvision.models.

Olga Russakovsky*, Jia Deng*, Hao Su, Jonathan Krause, Sanjeev Satheesh, 
Sean Ma, Zhiheng Huang, Andrej Karpathy, Aditya Khosla, Michael Bernstein, 
Alexander C. Berg and Li Fei-Fei. (* = equal contribution) 
ImageNet Large Scale Visual Recognition Challenge. arXiv:1409.0575, 2014

Paper (link) (Accessed on April 12th, 2023)

Torchvision models

For testing, I've used (or will use) the torchvision models (v0.15):

https://pytorch.org/vision/0.15/models.html

HLB-CIFAR10

For testing, I forked hlb-CIFAR10 by @tysam-code:

authors:
- family-names: "Balsam"
  given-names: "Tysam&"
title: "hlb-CIFAR10"
version: 0.4.0
date-released: 2023-02-12
url: "https://github.com/tysam-code/hlb-CIFAR10"

HLB-GPT

For testing, I also used hlb-gpt by @tysam-code:

authors:
  - family-names: "Balsam"
    given-names: "Tysam&"
title: "hlb-gpt"
version: 0.0.0
date-released: 2023-03-05
url: "https://github.com/tysam-code/hlb-gpt"

Other

My code took inspiration from the following sources:

I used the amazing library torchview to visualize the models: