gstoica27 / ZipIt

A framework for merging models solving different tasks with different initializations into one multi-task model without any additional training
MIT License

merging merged models #11

Closed nsaadati closed 11 months ago

nsaadati commented 1 year ago

Hey, thanks a lot for your well-written paper. I'm trying to figure out what happens when I merge models together, specifically when I merge already-merged models. Let's say I have four models, each a resnet20x8 trained on 20 classes from CIFAR100. What I want to do is merge two models at a time, and then merge the resulting merged models together. Can you help me understand how to do that?

By the way, I encountered an error when I tried merging the merged models. I also tweaked your code to merge five models at once, but I'm not quite sure how it works: does it go through all the models and compute the matrix, or does it merge two models and then merge the third one with the combined model? Your help would be greatly appreciated. This is the error I get when merging merged models:

Evaluating Pairs...: 0%| | 0/4 [00:00<?, ?it/s]Files already downloaded and verified 50000it [00:00, 896912.57it/s] 50000it [00:00, 883643.87it/s] Files already downloaded and verified 10000it [00:00, 906190.77it/s] 10000it [00:00, 912003.48it/s] Preparing Models: 0%| | 0/2 [00:04<?, ?it/s] Evaluating Pairs...: 0%| | 0/4 [00:06<?, ?it/s] Traceback (most recent call last): File "/home/exouser/.conda/envs/cf/lib/python3.8/runpy.py", line 194, in _run_module_as_main return _run_code(code, main_globals, None, File "/home/exouser/.conda/envs/cf/lib/python3.8/runpy.py", line 87, in _run_code exec(code, run_globals) File "/home/exouser/SZip/ZipIt/evaluation_scripts/zipit_concept_merging.py", line 143, in run_node_experiment( File "/home/exouser/SZip/ZipIt/evaluation_scripts/zipit_concept_merging.py", line 64, in run_node_experiment config = prepare_experiment_config(raw_config) File "/home/exouser/SZip/ZipIt/utils.py", line 822, in prepare_experiment_config 'models': prepare_models(config['model'], device=config['device']), File "/home/exouser/SZip/ZipIt/utils.py", line 779, in prepare_models return prepare_resnets(config, device) File "/home/exouser/SZip/ZipIt/utils.py", line 722, in prepare_resnets base_model.load_state_dict(base_sd) File "/home/exouser/.conda/envs/cf/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1482, in load_state_dict raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( RuntimeError: Error(s) in loading state_dict for ResNet: Missing key(s) in state_dict: "conv1.weight", "bn1.weight", "bn1.bias", "bn1.running_mean", "bn1.running_var", "layer1.0.conv1.weight", "layer1.0.bn1.weight", "layer1.0.bn1.bias", "layer1.0.bn1.running_mean", "layer1.0.bn1.running_var", "layer1.0.conv2.weight", "layer1.0.bn2.weight", "layer1.0.bn2.bias", "layer1.0.bn2.running_mean", "layer1.0.bn2.running_var", "layer1.1.conv1.weight", "layer1.1.bn1.weight", "layer1.1.bn1.bias", "layer1.1.bn1.running_mean", "layer1.1.bn1.running_var", "layer1.1.conv2.weight", "layer1.1.bn2.weight", "layer1.1.bn2.bias", "layer1.1.bn2.running_mean", "layer1.1.bn2.running_var", "layer1.2.conv1.weight", "layer1.2.bn1.weight", "layer1.2.bn1.bias", "layer1.2.bn1.running_mean", "layer1.2.bn1.running_var", "layer1.2.conv2.weight", "layer1.2.bn2.weight", "layer1.2.bn2.bias", "layer1.2.bn2.running_mean", "layer1.2.bn2.running_var", "layer2.0.conv1.weight", "layer2.0.bn1.weight", "layer2.0.bn1.bias", "layer2.0.bn1.running_mean", "layer2.0.bn1.running_var", "layer2.0.conv2.weight", "layer2.0.bn2.weight", "layer2.0.bn2.bias", "layer2.0.bn2.running_mean", "layer2.0.bn2.running_var", "layer2.0.shortcut.0.weight", "layer2.0.shortcut.1.weight", "layer2.0.shortcut.1.bias", "layer2.0.shortcut.1.running_mean", "layer2.0.shortcut.1.running_var", "layer2.1.conv1.weight", "layer2.1.bn1.weight", "layer2.1.bn1.bias", "layer2.1.bn1.running_mean", "layer2.1.bn1.running_var", "layer2.1.conv2.weight", "layer2.1.bn2.weight", "layer2.1.bn2.bias", "layer2.1.bn2.running_mean", "layer2.1.bn2.running_var", "layer2.2.conv1.weight", "layer2.2.bn1.weight", "layer2.2.bn1.bias", "layer2.2.bn1.running_mean", "layer2.2.bn1.running_var", "layer2.2.conv2.weight", "layer2.2.bn2.weight", "layer2.2.bn2.bias", "layer2.2.bn2.running_mean", "layer2.2.bn2.running_var", "layer3.0.conv1.weight", "layer3.0.bn1.weight", "layer3.0.bn1.bias", "layer3.0.bn1.running_mean", "layer3.0.bn1.running_var", "layer3.0.conv2.weight", "layer3.0.bn2.weight", "layer3.0.bn2.bias", "layer3.0.bn2.running_mean", "layer3.0.bn2.running_var", 
"layer3.0.shortcut.0.weight", "layer3.0.shortcut.1.weight", "layer3.0.shortcut.1.bias", "layer3.0.shortcut.1.running_mean", "layer3.0.shortcut.1.running_var", "layer3.1.conv1.weight", "layer3.1.bn1.weight", "layer3.1.bn1.bias", "layer3.1.bn1.running_mean", "layer3.1.bn1.running_var", "layer3.1.conv2.weight", "layer3.1.bn2.weight", "layer3.1.bn2.bias", "layer3.1.bn2.running_mean", "layer3.1.bn2.running_var", "layer3.2.conv1.weight", "layer3.2.bn1.weight", "layer3.2.bn1.bias", "layer3.2.bn1.running_mean", "layer3.2.bn1.running_var", "layer3.2.conv2.weight", "layer3.2.bn2.weight", "layer3.2.bn2.bias", "layer3.2.bn2.running_mean", "layer3.2.bn2.running_var", "linear.weight", "linear.bias".

nsaadati commented 1 year ago

This is the rest of the error message. The "Missing key(s) in state_dict" list is the same as the one above; it is followed by:
Unexpected key(s) in state_dict: "head_models.0.conv1.weight", "head_models.0.bn1.weight", "head_models.0.bn1.bias", "head_models.0.bn1.running_mean", "head_models.0.bn1.running_var", "head_models.0.bn1.num_batches_tracked", "head_models.0.layer1.0.conv1.weight", "head_models.0.layer1.0.bn1.weight", "head_models.0.layer1.0.bn1.bias", "head_models.0.layer1.0.bn1.running_mean", "head_models.0.layer1.0.bn1.running_var", "head_models.0.layer1.0.bn1.num_batches_tracked", "head_models.0.layer1.0.conv2.weight", "head_models.0.layer1.0.bn2.weight", "head_models.0.layer1.0.bn2.bias", "head_models.0.layer1.0.bn2.running_mean", "head_models.0.layer1.0.bn2.running_var", "head_models.0.layer1.0.bn2.num_batches_tracked", "head_models.0.layer1.1.conv1.weight", "head_models.0.layer1.1.bn1.weight", "head_models.0.layer1.1.bn1.bias", "head_models.0.layer1.1.bn1.running_mean", "head_models.0.layer1.1.bn1.running_var", "head_models.0.layer1.1.bn1.num_batches_tracked", "head_models.0.layer1.1.conv2.weight", "head_models.0.layer1.1.bn2.weight", "head_models.0.layer1.1.bn2.bias", "head_models.0.layer1.1.bn2.running_mean", "head_models.0.layer1.1.bn2.running_var", "head_models.0.layer1.1.bn2.num_batches_tracked", "head_models.0.layer1.2.conv1.weight", "head_models.0.layer1.2.bn1.weight", "head_models.0.layer1.2.bn1.bias", "head_models.0.layer1.2.bn1.running_mean", "head_models.0.layer1.2.bn1.running_var", "head_models.0.layer1.2.bn1.num_batches_tracked", "head_models.0.layer1.2.conv2.weight", "head_models.0.layer1.2.bn2.weight", "head_models.0.layer1.2.bn2.bias", "head_models.0.layer1.2.bn2.running_mean", "head_models.0.layer1.2.bn2.running_var", "head_models.0.layer1.2.bn2.num_batches_tracked", "head_models.0.layer2.0.conv1.weight", "head_models.0.layer2.0.bn1.weight", "head_models.0.layer2.0.bn1.bias", "head_models.0.layer2.0.bn1.running_mean", "head_models.0.layer2.0.bn1.running_var", "head_models.0.layer2.0.bn1.num_batches_tracked", "head_models.0.layer2.0.conv2.weight", "head_models.0.layer2.0.bn2.weight", "head_models.0.layer2.0.bn2.bias", "head_models.0.layer2.0.bn2.running_mean", "head_models.0.layer2.0.bn2.running_var", "head_models.0.layer2.0.bn2.num_batches_tracked", "head_models.0.layer2.0.shortcut.0.weight", "head_models.0.layer2.0.shortcut.1.weight", "head_models.0.layer2.0.shortcut.1.bias", "head_models.0.layer2.0.shortcut.1.running_mean", "head_models.0.layer2.0.shortcut.1.running_var", "head_models.0.layer2.0.shortcut.1.num_batches_tracked", "head_models.0.layer2.1.conv1.weight", "head_models.0.layer2.1.bn1.weight", "head_models.0.layer2.1.bn1.bias", "head_models.0.layer2.1.bn1.running_mean", "head_models.0.layer2.1.bn1.running_var", "head_models.0.layer2.1.bn1.num_batches_tracked", "head_models.0.layer2.1.conv2.weight", "head_models.0.layer2.1.bn2.weight", "head_models.0.layer2.1.bn2.bias", "head_models.0.layer2.1.bn2.running_mean", "head_models.0.layer2.1.bn2.running_var", "head_models.0.layer2.1.bn2.num_batches_tracked", "head_models.0.layer2.2.conv1.weight", "head_models.0.layer2.2.bn1.weight", "head_models.0.layer2.2.bn1.bias", "head_models.0.layer2.2.bn1.running_mean", "head_models.0.layer2.2.bn1.running_var", "head_models.0.layer2.2.bn1.num_batches_tracked", "head_models.0.layer2.2.conv2.weight", "head_models.0.layer2.2.bn2.weight", "head_models.0.layer2.2.bn2.bias", "head_models.0.layer2.2.bn2.running_mean", "head_models.0.layer2.2.bn2.running_var", "head_models.0.layer2.2.bn2.num_batches_tracked", "head_models.0.layer3.0.conv1.weight", "head_models.0.layer3.0.bn1.weight", 
"head_models.0.layer3.0.bn1.bias", "head_models.0.layer3.0.bn1.running_mean", "head_models.0.layer3.0.bn1.running_var", "head_models.0.layer3.0.bn1.num_batches_tracked", "head_models.0.layer3.0.conv2.weight", "head_models.0.layer3.0.bn2.weight", "head_models.0.layer3.0.bn2.bias", "head_models.0.layer3.0.bn2.running_mean", "head_models.0.layer3.0.bn2.running_var", "head_models.0.layer3.0.bn2.num_batches_tracked", "head_models.0.layer3.0.shortcut.0.weight", "head_models.0.layer3.0.shortcut.1.weight", "head_models.0.layer3.0.shortcut.1.bias", "head_models.0.layer3.0.shortcut.1.running_mean", "head_models.0.layer3.0.shortcut.1.running_var", "head_models.0.layer3.0.shortcut.1.num_batches_tracked", "head_models.0.layer3.1.conv1.weight", "head_models.0.layer3.1.bn1.weight", "head_models.0.layer3.1.bn1.bias", "head_models.0.layer3.1.bn1.running_mean", "head_models.0.layer3.1.bn1.running_var", "head_models.0.layer3.1.bn1.num_batches_tracked", "head_models.0.layer3.1.conv2.weight", "head_models.0.layer3.1.bn2.weight", "head_models.0.layer3.1.bn2.bias", "head_models.0.layer3.1.bn2.running_mean", "head_models.0.layer3.1.bn2.running_var", "head_models.0.layer3.1.bn2.num_batches_tracked", "head_models.0.layer3.2.conv1.weight", "head_models.0.layer3.2.bn1.weight", "head_models.0.layer3.2.bn1.bias", "head_models.0.layer3.2.bn1.running_mean", "head_models.0.layer3.2.bn1.running_var", "head_models.0.layer3.2.bn1.num_batches_tracked", "head_models.0.layer3.2.conv2.weight", "head_models.0.layer3.2.bn2.weight", "head_models.0.layer3.2.bn2.bias", "head_models.0.layer3.2.bn2.running_mean", "head_models.0.layer3.2.bn2.running_var", "head_models.0.layer3.2.bn2.num_batches_tracked", "head_models.0.linear.weight", "head_models.0.linear.bias", "head_models.1.conv1.weight", "head_models.1.bn1.weight", "head_models.1.bn1.bias", "head_models.1.bn1.running_mean", "head_models.1.bn1.running_var", "head_models.1.bn1.num_batches_tracked", "head_models.1.layer1.0.conv1.weight", "head_models.1.layer1.0.bn1.weight", "head_models.1.layer1.0.bn1.bias", "head_models.1.layer1.0.bn1.running_mean", "head_models.1.layer1.0.bn1.running_var", "head_models.1.layer1.0.bn1.num_batches_tracked", "head_models.1.layer1.0.conv2.weight", "head_models.1.layer1.0.bn2.weight", "head_models.1.layer1.0.bn2.bias", "head_models.1.layer1.0.bn2.running_mean", "head_models.1.layer1.0.bn2.running_var", "head_models.1.layer1.0.bn2.num_batches_tracked", "head_models.1.layer1.1.conv1.weight", "head_models.1.layer1.1.bn1.weight", "head_models.1.layer1.1.bn1.bias", "head_models.1.layer1.1.bn1.running_mean", "head_models.1.layer1.1.bn1.running_var", "head_models.1.layer1.1.bn1.num_batches_tracked", "head_models.1.layer1.1.conv2.weight", "head_models.1.layer1.1.bn2.weight", "head_models.1.layer1.1.bn2.bias", "head_models.1.layer1.1.bn2.running_mean", "head_models.1.layer1.1.bn2.running_var", "head_models.1.layer1.1.bn2.num_batches_tracked", "head_models.1.layer1.2.conv1.weight", "head_models.1.layer1.2.bn1.weight", "head_models.1.layer1.2.bn1.bias", "head_models.1.layer1.2.bn1.running_mean", "head_models.1.layer1.2.bn1.running_var", "head_models.1.layer1.2.bn1.num_batches_tracked", "head_models.1.layer1.2.conv2.weight", "head_models.1.layer1.2.bn2.weight", "head_models.1.layer1.2.bn2.bias", "head_models.1.layer1.2.bn2.running_mean", "head_models.1.layer1.2.bn2.running_var", "head_models.1.layer1.2.bn2.num_batches_tracked", "head_models.1.layer2.0.conv1.weight", "head_models.1.layer2.0.bn1.weight", "head_models.1.layer2.0.bn1.bias", 
"head_models.1.layer2.0.bn1.running_mean", "head_models.1.layer2.0.bn1.running_var", "head_models.1.layer2.0.bn1.num_batches_tracked", "head_models.1.layer2.0.conv2.weight", "head_models.1.layer2.0.bn2.weight", "head_models.1.layer2.0.bn2.bias", "head_models.1.layer2.0.bn2.running_mean", "head_models.1.layer2.0.bn2.running_var", "head_models.1.layer2.0.bn2.num_batches_tracked", "head_models.1.layer2.0.shortcut.0.weight", "head_models.1.layer2.0.shortcut.1.weight", "head_models.1.layer2.0.shortcut.1.bias", "head_models.1.layer2.0.shortcut.1.running_mean", "head_models.1.layer2.0.shortcut.1.running_var", "head_models.1.layer2.0.shortcut.1.num_batches_tracked", "head_models.1.layer2.1.conv1.weight", "head_models.1.layer2.1.bn1.weight", "head_models.1.layer2.1.bn1.bias", "head_models.1.layer2.1.bn1.running_mean", "head_models.1.layer2.1.bn1.running_var", "head_models.1.layer2.1.bn1.num_batches_tracked", "head_models.1.layer2.1.conv2.weight", "head_models.1.layer2.1.bn2.weight", "head_models.1.layer2.1.bn2.bias", "head_models.1.layer2.1.bn2.running_mean", "head_models.1.layer2.1.bn2.running_var", "head_models.1.layer2.1.bn2.num_batches_tracked", "head_models.1.layer2.2.conv1.weight", "head_models.1.layer2.2.bn1.weight", "head_models.1.layer2.2.bn1.bias", "head_models.1.layer2.2.bn1.running_mean", "head_models.1.layer2.2.bn1.running_var", "head_models.1.layer2.2.bn1.num_batches_tracked", "head_models.1.layer2.2.conv2.weight", "head_models.1.layer2.2.bn2.weight", "head_models.1.layer2.2.bn2.bias", "head_models.1.layer2.2.bn2.running_mean", "head_models.1.layer2.2.bn2.running_var", "head_models.1.layer2.2.bn2.num_batches_tracked", "head_models.1.layer3.0.conv1.weight", "head_models.1.layer3.0.bn1.weight", "head_models.1.layer3.0.bn1.bias", "head_models.1.layer3.0.bn1.running_mean", "head_models.1.layer3.0.bn1.running_var", "head_models.1.layer3.0.bn1.num_batches_tracked", "head_models.1.layer3.0.conv2.weight", "head_models.1.layer3.0.bn2.weight", "head_models.1.layer3.0.bn2.bias", "head_models.1.layer3.0.bn2.running_mean", "head_models.1.layer3.0.bn2.running_var", "head_models.1.layer3.0.bn2.num_batches_tracked", "head_models.1.layer3.0.shortcut.0.weight", "head_models.1.layer3.0.shortcut.1.weight", "head_models.1.layer3.0.shortcut.1.bias", "head_models.1.layer3.0.shortcut.1.running_mean", "head_models.1.layer3.0.shortcut.1.running_var", "head_models.1.layer3.0.shortcut.1.num_batches_tracked", "head_models.1.layer3.1.conv1.weight", "head_models.1.layer3.1.bn1.weight", "head_models.1.layer3.1.bn1.bias", "head_models.1.layer3.1.bn1.running_mean", "head_models.1.layer3.1.bn1.running_var", "head_models.1.layer3.1.bn1.num_batches_tracked", "head_models.1.layer3.1.conv2.weight", "head_models.1.layer3.1.bn2.weight", "head_models.1.layer3.1.bn2.bias", "head_models.1.layer3.1.bn2.running_mean", "head_models.1.layer3.1.bn2.running_var", "head_models.1.layer3.1.bn2.num_batches_tracked", "head_models.1.layer3.2.conv1.weight", "head_models.1.layer3.2.bn1.weight", "head_models.1.layer3.2.bn1.bias", "head_models.1.layer3.2.bn1.running_mean", "head_models.1.layer3.2.bn1.running_var", "head_models.1.layer3.2.bn1.num_batches_tracked", "head_models.1.layer3.2.conv2.weight", "head_models.1.layer3.2.bn2.weight", "head_models.1.layer3.2.bn2.bias", "head_models.1.layer3.2.bn2.running_mean", "head_models.1.layer3.2.bn2.running_var", "head_models.1.layer3.2.bn2.num_batches_tracked", "head_models.1.linear.weight", "head_models.1.linear.bias", "merged_model.conv1.weight", "merged_model.bn1.weight", "merged_model.bn1.bias", 
"merged_model.bn1.running_mean", "merged_model.bn1.running_var", "merged_model.bn1.num_batches_tracked", "merged_model.layer1.0.conv1.weight", "merged_model.layer1.0.bn1.weight", "merged_model.layer1.0.bn1.bias", "merged_model.layer1.0.bn1.running_mean", "merged_model.layer1.0.bn1.running_var", "merged_model.layer1.0.bn1.num_batches_tracked", "merged_model.layer1.0.conv2.weight", "merged_model.layer1.0.bn2.weight", "merged_model.layer1.0.bn2.bias", "merged_model.layer1.0.bn2.running_mean", "merged_model.layer1.0.bn2.running_var", "merged_model.layer1.0.bn2.num_batches_tracked", "merged_model.layer1.1.conv1.weight", "merged_model.layer1.1.bn1.weight", "merged_model.layer1.1.bn1.bias", "merged_model.layer1.1.bn1.running_mean", "merged_model.layer1.1.bn1.running_var", "merged_model.layer1.1.bn1.num_batches_tracked", "merged_model.layer1.1.conv2.weight", "merged_model.layer1.1.bn2.weight", "merged_model.layer1.1.bn2.bias", "merged_model.layer1.1.bn2.running_mean", "merged_model.layer1.1.bn2.running_var", "merged_model.layer1.1.bn2.num_batches_tracked", "merged_model.layer1.2.conv1.weight", "merged_model.layer1.2.bn1.weight", "merged_model.layer1.2.bn1.bias", "merged_model.layer1.2.bn1.running_mean", "merged_model.layer1.2.bn1.running_var", "merged_model.layer1.2.bn1.num_batches_tracked", "merged_model.layer1.2.conv2.weight", "merged_model.layer

gstoica27 commented 1 year ago

Hi,

Thanks! And for your interest in our work :)

Our code merges all models at once, rather than sequentially. It is possible to merge sequentially as you've described though, by first taking two base models and merging them (by defining their graphs, a merge object and calling transform), and then taking the now merged model and another base model and repeating the process (so creating new graphs of the merged model and base model, a new merge object and transform), etc...
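
For concreteness, here is a rough sketch of that sequential recipe. It is only illustrative: `graph_constructor` is a placeholder for whichever graph class in `graphs/` matches your architecture, and the exact `ModelMerge`/`transform` signatures may differ from what is shown here.

```python
# Illustrative sketch only: graph_constructor is a placeholder, and the
# ModelMerge/transform signatures may differ from the actual repo code.
from copy import deepcopy

def zip_pair(model_a, model_b, base_model, dataloader, device="cuda"):
    # Build a graph for each model being merged.
    graph_a = graph_constructor(deepcopy(model_a)).graphify()
    graph_b = graph_constructor(deepcopy(model_b)).graphify()
    # Create a merge object over both graphs and run the transform, which
    # computes the matching and writes the zipped weights into base_model.
    merge = ModelMerge(graph_a, graph_b, device=device)
    merge.transform(deepcopy(base_model), dataloader)
    # Keep only the resulting network for the next round of merging.
    return deepcopy(merge.merged_model)

# Round 1: merge the base models pairwise.
m12 = zip_pair(model1, model2, base_model, dataloader)
m34 = zip_pair(model3, model4, base_model, dataloader)
# Round 2: merge the two merged models.
m1234 = zip_pair(m12, m34, base_model, dataloader)
```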

Regarding the error, it is difficult to debug this fully without seeing your modified script, but it looks like you're saving the entire Merge object, rather than the merged model. So you're getting a mismatch of state dict keys because you're trying to load a different model into the base model. Assuming you're doing a full merge, you should instead load the "merged_model" weights into the base, so save and load the state_dict of "Merge.merged_model".
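
For example, a minimal sketch of that save/load step (the checkpoint filename is just illustrative):

```python
import torch

# Save the zipped network itself, not the whole Merge wrapper object.
torch.save(Merge.merged_model.state_dict(), "merged_resnet20x8.pt")

# Later: load it into a plain base model with the same architecture.
base_sd = torch.load("merged_resnet20x8.pt", map_location="cpu")
base_model.load_state_dict(base_sd)
```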

Hope this helps!

nsaadati commented 1 year ago

That fixed it, thank you so much. But I'm running into another issue: after merging the merged models, the accuracy is very low, around 0.05. Do you have any idea what might be the reason? I also wanted to know why it's crucial to have the same number of classes. Suppose I have one task with 80 classes and another with only 20. If I attempt to merge these tasks, I encounter an error. Do you happen to know how I can resolve this issue?

Files already downloaded and verified 50000it [00:00, 749614.68it/s] 50000it [00:00, 1361654.38it/s]
Files already downloaded and verified 10000it [00:00, 727634.58it/s] 10000it [00:00, 1350953.07it/s]
Preparing Models: 100%| 2/2 [00:07<00:00, 3.89s/it]
Resetting batch norm: 100%| 1000/1000 [00:12<00:00, 79.46it/s]
Resetting batch norm: 100%| 1000/1000 [00:13<00:00, 73.52it/s]
Forward Pass to Compute Merge Metrics: 100%| 1000/1000 [00:20<00:00, 49.34it/s]
Forward Pass to Compute Merge Metrics: 100%| 999/1000 [00:20<00:00, 50.30it/s]
/home/exouser/SZip/ZipIt/matching_functions.py:84: UserWarning: floordiv is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').
  col_idx = best_idx // sims.shape[1]
Computing transformations: 25%| 3/12 [00:02<00:08, 1.01it/s]
Resetting batch norm: 100%| 1000/1000 [00:22<00:00, 44.87it/s]
Evaluating Pairs...: 0%| 0/1 [01:22<?, ?it/s]
Traceback (most recent call last):
  File "/home/exouser/.conda/envs/cf/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/exouser/.conda/envs/cf/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/exouser/SZip/ZipIt/evaluation_scripts/zipit_concept_merging.py", line 143, in run_node_experiment(
  File "/home/exouser/SZip/ZipIt/evaluation_scripts/zipit_concept_merging.py", line 90, in run_node_experiment
    results = evaluate_model(experiment_config['eval_type'], Merge, config)
  File "/home/exouser/SZip/ZipIt/utils.py", line 609, in evaluate_model
    acc_overall, acc_avg, perclass_acc = evaluate_logits_alltasks(
  File "/home/exouser/SZip/ZipIt/utils.py", line 406, in evaluate_logits_alltasks
    splits = torch.tensor(splits).to(device)
ValueError: expected sequence of length 80 at dim 1 (got 20)

gstoica27 commented 1 year ago

Hey, for:

  1. You can save the merged model by calling save on Merge.merged_model.
  2. You're getting the 80-vs-20 error because the eval code assumes the label splits are all the same size, but in principle they don't have to be. You can treat the model outputs/labels as a list of tensors rather than the single stacked tensor used in the code (see the sketch below).
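
To illustrate point 2, here is a minimal, hypothetical sketch of handling ragged splits (the real change would go in evaluate_logits_alltasks in utils.py; the shapes below are made up):

```python
import torch

# Two tasks with different numbers of classes (e.g. 80 and 20 from CIFAR-100).
splits = [list(range(80)), list(range(80, 100))]

# torch.tensor(splits) fails because the rows have different lengths,
# so keep one index tensor per task instead of stacking them.
split_tensors = [torch.tensor(s) for s in splits]

logits = torch.randn(4, 100)  # dummy joint logits over all 100 classes
per_task_logits = [logits[:, s] for s in split_tensors]
per_task_preds = [t.argmax(dim=-1) for t in per_task_logits]
print([t.shape for t in per_task_logits])  # [torch.Size([4, 80]), torch.Size([4, 20])]
```
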
nsaadati commented 1 year ago

Thanks a lot. Did you see my new question about the merged-model accuracy? These are the results from merging the models a second time, and the accuracy is surprisingly low. Do you have any idea what might be the reason? image

gstoica27 commented 1 year ago

Interesting, how are you normalizing when applying a merge in this sequential setting? You need to make sure that every time you merge, you renormalize all the models involved in that merge. So for two models you need to divide their weights by 2, for three models you need to divide by 3, etc.

nsaadati commented 1 year ago

I did not change anything in the code for this part. Can you please tell me which part of the code is related to this?

gstoica27 commented 1 year ago

Sure - the repository currently does not support what I described above, but it can easily be adapted to do so. Normalization of the base models' weights happens right before they are added together to create the merged model (please see line 349 in model_merger.py for details). What you would have to do is change that line from uniformly averaging all the graph models' weights to instead weighting each graph model by the fraction of base models it is composed of.

For instance, to merge an already-merged model of two base models (call it graph 1) with one additional base model (call it graph 2), you would need the following:

merged_model = graph1 * 2/3 + graph2 * 1/3

With a similar process for merging more and more models sequentially.
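
As a sketch of that reweighting (plain state-dict arithmetic outside the repo's merge machinery; it assumes the two models' weights have already been aligned by the ZipIt transforms, and graph1_model/graph2_model are placeholder names):

```python
import torch

def weighted_combine(sd_a, sd_b, w_a, w_b):
    # Combine two aligned state dicts with the given weights, e.g. w_a = 2/3 for a
    # graph that already contains two base models and w_b = 1/3 for a single base model.
    out = {}
    for key in sd_a:
        if torch.is_floating_point(sd_a[key]):
            out[key] = w_a * sd_a[key] + w_b * sd_b[key]
        else:
            out[key] = sd_a[key]  # integer buffers such as num_batches_tracked
    return out

merged_sd = weighted_combine(graph1_model.state_dict(),
                             graph2_model.state_dict(), 2/3, 1/3)
```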