ogkalu2 / Merge-Stable-Diffusion-models-without-distortion

Adaptation for Stable Diffusion of the merging method described in the paper "Git Re-Basin: Merging Models modulo Permutation Symmetries" (https://arxiv.org/abs/2209.04836).
MIT License
139 stars, 21 forks

specifying a custom device crashes the merge #3

Status: Open. Pyr-000 opened this issue 1 year ago.

Pyr-000 commented 1 year ago

Specifying a device crashes the script. For example, running

.\.venv\Scripts\python.exe .\SD_rebasin_merge.py --model_a .\standard_models\mix_in_1 --model_b .\standard_models\mix_in_2 --output .\standard_models\mixed --device cuda

fails with the following traceback:

Traceback (most recent call last):
  File ".\SD_rebasin_merge.py", line 23, in <module>
    final_permutation = weight_matching(permutation_spec, state_a, state_b)
  File ".\weight_matching.py", line 796, in weight_matching
    w_b = get_permuted_param(ps, perm, wk, params_b, except_axis=axis)
  File ".\weight_matching.py", line 773, in get_permuted_param
    w = torch.index_select(w, axis, perm[p].int())
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper__index_select)
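
The traceback suggests that the weight tensor w has been moved to cuda:0 while the permutation index perm[p] is still on the CPU; torch.index_select requires its input and index tensors to be on the same device. Below is a minimal sketch of one possible fix, assuming the permutation tensors are simply created on the CPU. The helper name permute_along_axis is hypothetical and is not part of the repository; it moves the index to the weight's device before indexing.

```python
import torch


def permute_along_axis(w: torch.Tensor, axis: int, perm: torch.Tensor) -> torch.Tensor:
    """Reorder w along `axis` according to `perm`.

    Hypothetical helper (not from the repository): torch.index_select needs
    `index` on the same device as `input`, so the permutation is moved to
    w.device (and cast to int64, which index_select accepts) first.
    """
    index = perm.to(device=w.device, dtype=torch.long)
    return torch.index_select(w, axis, index)


if __name__ == "__main__":
    # Use the GPU if one is present, otherwise fall back to the CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    w = torch.arange(6.0, device=device).reshape(2, 3)
    perm = torch.tensor([2, 0, 1])  # created on the CPU, as in the traceback
    print(permute_along_axis(w, 1, perm))
```

Applied to the failing line in get_permuted_param, the analogous change would be to pass perm[p].int().to(w.device) instead of perm[p].int(). This is untested against the repository, and other places in weight_matching.py that mix CPU and GPU tensors may need the same treatment.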

Pyr-000 commented 1 year ago

I should note that I can't fully test this, as I don't have nearly enough VRAM to run the full merge on the GPU. When running on the CPU, memory usage seems to peak at over 60 GB.
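Before attempting a GPU run, it can help to check how much memory the two models' tensors occupy on their own. The sketch below assumes the checkpoints are loaded as plain dicts of tensors (the names state_a and state_b are taken from the traceback; the helper state_dict_nbytes is hypothetical). It counts only the stored tensors, so it is a lower bound: the reported 60 GB peak also includes intermediates created during weight matching.

```python
import torch


def state_dict_nbytes(state: dict) -> int:
    """Total bytes held by the tensors in a state dict; non-tensor entries are ignored."""
    return sum(
        t.numel() * t.element_size()
        for t in state.values()
        if isinstance(t, torch.Tensor)
    )


if __name__ == "__main__":
    # Toy stand-ins for state_a / state_b; real checkpoints would be loaded instead.
    state_a = {"w": torch.zeros(1024, 1024), "b": torch.zeros(1024)}
    state_b = {"w": torch.zeros(1024, 1024, dtype=torch.float16)}
    total = state_dict_nbytes(state_a) + state_dict_nbytes(state_b)
    print(f"{total / 2**20:.2f} MiB")
```

Comparing this figure against the free memory reported by the GPU gives a quick sense of whether even the inputs would fit before any merging work starts.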