An adaptation for Stable Diffusion of the merging method described in the paper "Git Re-Basin: Merging Models modulo Permutation Symmetries" (https://arxiv.org/abs/2209.04836).
Specifying a device (e.g. .\.venv\Scripts\python.exe .\SD_rebasin_merge.py --model_a .\standard_models\mix_in_1 --model_b .\standard_models\mix_in_2 --output .\standard_models\mixed --device cuda ) crashes the script:
Traceback (most recent call last):
  File ".\SD_rebasin_merge.py", line 23, in <module>
    final_permutation = weight_matching(permutation_spec, state_a, state_b)
  File ".\weight_matching.py", line 796, in weight_matching
    w_b = get_permuted_param(ps, perm, wk, params_b, except_axis=axis)
  File ".\weight_matching.py", line 773, in get_permuted_param
    w = torch.index_select(w, axis, perm[p].int())
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper__index_select)
I should note that I can't fully test the GPU path, as I am far from having enough VRAM to do the full merge on GPU. When running on CPU, memory usage seems to peak at over 60 GB.
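For what it's worth, the error message suggests that the permutation index passed to torch.index_select is still on the CPU while the weight tensor is on cuda:0. Below is a minimal sketch of a possible workaround, assuming that is the cause. The helper name permute_axis is hypothetical and is not part of weight_matching.py; it only illustrates moving the index to the weight's device before indexing. I have only run this on small tensors, not on a full merge.

```python
import torch


def permute_axis(w: torch.Tensor, axis: int, perm: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: reorder `w` along `axis` according to `perm`.

    The index is moved to the same device as `w` first, because
    torch.index_select requires both tensors to be on one device.
    """
    index = perm.to(device=w.device, dtype=torch.long)
    return torch.index_select(w, axis, index)


# Use the GPU if one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

w = torch.arange(6.0).reshape(2, 3).to(device)  # weight on the chosen device
perm = torch.tensor([2, 0, 1])                  # permutation left on the CPU

out = permute_axis(w, 1, perm)
print(out.cpu().tolist())
```

If this is indeed the problem, the equivalent change in get_permuted_param would presumably be to move perm[p] to w.device before the index_select call, but I have not verified that against the rest of the script.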