ogkalu2 / Merge-Stable-Diffusion-models-without-distortion

Adaptation of the merging method described in the paper "Git Re-Basin: Merging Models modulo Permutation Symmetries" (https://arxiv.org/abs/2209.04836) for Stable Diffusion
MIT License

RuntimeError: expected scalar type Float but found Half #5

Closed thojmr closed 1 year ago

thojmr commented 1 year ago

Hey there, I'm trying out this merge implementation for the first time but consistently run into this issue. It seems to happen for many different combinations of model merges, with both novel-based and sd-v1-4-based models.

I should also mention that I am using the "device id" pull request https://github.com/ogkalu2/Merge-Stable-Diffusion-models-without-distortion/pull/4, but even without that I get the same error.

Traceback (most recent call last):
  File "SD_rebasin_merge.py", line 23, in <module>
    final_permutation = weight_matching(permutation_spec, state_a, state_b, device)
  File "/mnt/c/Users/<redacted>/ML/apps/merge-stable-diffusion-models-without-distortion/weight_matching.py", line 800, in weight_matching
    A += w_a @ w_b.T
RuntimeError: expected scalar type Float but found Half

IdiotSandwichTheThird commented 1 year ago

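You have to download and use the 32-bit (fp32) versions of the models. The error means at least one of the checkpoints is stored in half precision (fp16).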

thojmr commented 1 year ago

Makes sense. It would be nice if the script just cast the weights to 32-bit automatically. Everyone who uses this script is going to run into the same problem.

Casting both operands to float32 in weight_matching.py, just before the matmul, makes it work on CPU at least:

        w_a = torch.moveaxis(w_a, axis, 0).reshape((n, -1)).to(torch.float32)
        w_b = torch.moveaxis(w_b, axis, 0).reshape((n, -1)).to(torch.float32)
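
For anyone who would rather not edit weight_matching.py, here is a minimal sketch of the same idea applied one step earlier: upcast every floating-point tensor in a loaded state dict to float32 before handing it to the merge. The helper name `to_float32` and the toy state dicts are made up for illustration and are not part of the repo; real checkpoints would be loaded with `torch.load` first.

```python
import torch

def to_float32(state_dict):
    """Return a new dict with floating-point tensors upcast to float32.

    Hypothetical helper, not part of the repo. Integer tensors and
    non-tensor entries are passed through unchanged.
    """
    return {
        key: value.float()
        if torch.is_tensor(value) and value.is_floating_point()
        else value
        for key, value in state_dict.items()
    }

# Toy stand-ins for two checkpoints: one stored in fp32, one in fp16.
state_a = {"layer.weight": torch.randn(4, 3)}
state_b = {"layer.weight": torch.randn(4, 3).half(), "step": torch.tensor(10)}

state_b32 = to_float32(state_b)

# Same shape of computation as the failing line in the traceback,
# now with both operands in float32.
w_a = state_a["layer.weight"]
w_b = state_b32["layer.weight"]
A = torch.zeros(4, 4)
A += w_a @ w_b.T
```

Upcasting doubles the memory taken by the fp16 weights, and it does not recover any precision already lost when the checkpoint was saved in half precision. It only makes the dtypes match so the matmul can run.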

thojmr commented 1 year ago

Latest merge #7 should have fixed this