ogkalu2 / Merge-Stable-Diffusion-models-without-distortion

Adaptation of the merging method described in the paper "Git Re-Basin: Merging Models modulo Permutation Symmetries" (https://arxiv.org/abs/2209.04836) for Stable Diffusion
MIT License

Changed to half precision (temporary). Use GPU for matrix multiplication. #7

Closed: vaguenebula closed this 1 year ago

vaguenebula commented 1 year ago

Removed the jax dependency. It runs very fast, merging in a few minutes. I removed jax to keep the dependencies simpler; please let me know your thoughts. Either way, matrix multiplication with CUDA is much faster (especially when multiplying 10000-by-10000 arrays). I also skipped one of the layers, which is too big to compute the linear sum assignment for: skip("cond_stage_model.transformer.text_model.embeddings.position_ids", None, None). Most likely not all of these changes are perfect, so just tell me if I made any major mistakes, as I'm still a noob at this sort of stuff. I tested the merged model, and it works as expected. I can send example images if you want.
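The "use GPU for matrix multiplication" change can be sketched roughly as below. This is not the PR's code: similarity_on_device is a hypothetical helper, and it assumes both weights were already reshaped to (n, -1) with the permuted axis first, as in Git Re-Basin weight matching. It runs the product in half precision on CUDA when a GPU is present, falls back to float32 on the CPU, and hands the result back on the CPU.

```python
import torch


def similarity_on_device(w_a: torch.Tensor, w_b: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: compute w_a @ w_b.T on the GPU when one is available.

    w_a and w_b are assumed to be shaped (n, k). The product is returned on the
    CPU in float32 so it can be added to a CPU-resident similarity matrix.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    # Half precision only on CUDA; CPU half matmul is missing in some PyTorch versions.
    dtype = torch.float16 if use_cuda else torch.float32
    prod = torch.matmul(w_a.to(device, dtype), w_b.to(device, dtype).T)
    return prod.float().cpu()


n = 4
A = torch.zeros((n, n), dtype=torch.float32)  # similarity matrix kept on the CPU
w_a, w_b = torch.randn(n, 8), torch.randn(n, 8)
A += similarity_on_device(w_a, w_b)
```

Keeping the accumulator on the CPU in float32 means only the individual products occupy VRAM, while the running sum does not lose precision.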

vaguenebula commented 1 year ago

Maybe precision can be one of the arguments, in case somebody doesn't have the resources to merge models at full precision.
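A precision argument could look something like the sketch below. The --precision flag and the PRECISIONS mapping are hypothetical names for illustration, not options the script currently has.

```python
import argparse

import torch

# Hypothetical mapping from a command-line choice to a torch dtype.
PRECISIONS = {"fp16": torch.float16, "fp32": torch.float32}


def parse_args(argv=None):
    parser = argparse.ArgumentParser(description="Re-basin merge (illustrative options only)")
    parser.add_argument(
        "--precision",
        choices=sorted(PRECISIONS),
        default="fp32",
        help="dtype used for merging; fp16 needs less memory at some cost in accuracy",
    )
    return parser.parse_args(argv)


args = parse_args(["--precision", "fp16"])
dtype = PRECISIONS[args.precision]
```

Defaulting to fp32 keeps full precision for people who have the memory, while fp16 stays opt-in.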

IdiotSandwichTheThird commented 1 year ago

What were the RAM/VRAM requirements you saw when merging with these changes?

vaguenebula commented 1 year ago

Around 5 GB of VRAM and 20 GB of RAM. I just noticed that I forgot to convert both models to half precision. That's an easy fix: simply cast both to half precision.
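Casting both models could be sketched like this, assuming each model has already been loaded as a plain dict of tensors (to_half and the key names are illustrative). Only floating-point tensors are cast, so integer buffers such as a position_ids index tensor keep their dtype.

```python
import torch


def to_half(state_dict: dict) -> dict:
    """Cast floating-point tensors to float16 and leave everything else untouched."""
    return {
        k: v.half() if torch.is_tensor(v) and v.is_floating_point() else v
        for k, v in state_dict.items()
    }


sd_a = {"w": torch.randn(2, 2), "position_ids": torch.arange(4).unsqueeze(0)}
sd_b = {"w": torch.randn(2, 2), "position_ids": torch.arange(4).unsqueeze(0)}
sd_a, sd_b = to_half(sd_a), to_half(sd_b)  # cast both models, not just one
```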

thojmr commented 1 year ago

Trying this PR now, but I'm running into this:

    A += torch.matmul(w_a, w_b).cpu()
    RuntimeError: expected scalar type Float but found Half

vaguenebula commented 1 year ago

Sorry, that's because one of the two is full precision and the other is half precision. This might fix the issue for now:

    A += torch.matmul(w_a.half(), w_b).cpu()
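The mismatch can also be handled without guessing which operand is half: promote both operands to a common dtype, multiply, then cast the product to the accumulator's dtype. A minimal sketch (accumulate is a hypothetical helper); it assumes the operands live on a device that supports the promoted dtype, since half-precision matmul is not available on the CPU in every PyTorch release.

```python
import torch


def accumulate(A: torch.Tensor, w_a: torch.Tensor, w_b: torch.Tensor) -> None:
    """Hypothetical helper: add w_a @ w_b into A in place, reconciling dtypes first."""
    # Promote to the wider of the two dtypes so neither operand loses precision.
    common = torch.promote_types(w_a.dtype, w_b.dtype)
    prod = torch.matmul(w_a.to(common), w_b.to(common))
    # Move and cast the product to match A before the in-place add.
    A += prod.to(device=A.device, dtype=A.dtype)


A = torch.zeros((2, 2), dtype=torch.float32)
w_a = torch.ones((2, 3), dtype=torch.float32)
w_b = torch.ones((3, 2), dtype=torch.float64)  # mismatched on purpose
accumulate(A, w_a, w_b)
```

A plain torch.matmul(w_a, w_b) on these inputs would raise a dtype-mismatch RuntimeError; the helper avoids that regardless of which side is wider.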

vaguenebula commented 1 year ago

Btw, I tested the merged model against one merged with a weighted sum, and the results are strange. They're not bad or anything, but they seem skewed toward one model.
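For reference, a weighted-sum merge is usually just a per-tensor linear interpolation. A minimal sketch, assuming alpha is the weight given to model B (weighted_sum is a hypothetical helper, not part of this repo):

```python
import torch


def weighted_sum(sd_a: dict, sd_b: dict, alpha: float = 0.5) -> dict:
    """Hypothetical helper: (1 - alpha) * A + alpha * B for each shared float tensor."""
    merged = {}
    for k, a in sd_a.items():
        b = sd_b.get(k)
        if b is not None and a.is_floating_point() and a.shape == b.shape:
            merged[k] = (1 - alpha) * a + alpha * b.to(a.dtype)
        else:
            merged[k] = a  # non-float or unmatched entries are copied from model A
    return merged
```

Unlike re-basin, this does not permute model B first, so units are averaged by position rather than by function.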

thojmr commented 1 year ago

Works now. It's about twice as fast per iteration: from 4 min down to 1.5-2 min.

vaguenebula commented 1 year ago

OK, so this doesn't work. I might have changed something important: I'm getting model B back as my merged model.

vaguenebula commented 1 year ago

Not identical, but HEAVILY skewed toward model B.
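One rough way to put a number on "skewed toward model B" is to compare, in weight space, how far the merged weights are from each model. A sketch (skew_toward_b is a hypothetical diagnostic): for a linear merge (1 - t)A + tB with 0 <= t <= 1 it returns t, so 0.0 means model A and 1.0 means model B. For a re-basin merge, B is permuted before averaging, so this only says how close the result is to B's weights as stored, not how it behaves.

```python
import torch


def skew_toward_b(merged: dict, sd_a: dict, sd_b: dict) -> float:
    """Hypothetical diagnostic: 0.0 if merged equals A, 1.0 if it equals B (weight space).

    Distances are accumulated over float tensors present in all three dicts with
    matching shapes; for a linear merge (1 - t) * A + t * B the result is t.
    """
    sq_a = sq_b = 0.0
    for k, m in merged.items():
        a, b = sd_a.get(k), sd_b.get(k)
        if a is None or b is None or not m.is_floating_point():
            continue
        if a.shape != m.shape or b.shape != m.shape:
            continue
        sq_a += torch.linalg.vector_norm((m - a).float()).item() ** 2
        sq_b += torch.linalg.vector_norm((m - b).float()).item() ** 2
    d_a, d_b = sq_a ** 0.5, sq_b ** 0.5
    return d_a / (d_a + d_b) if d_a + d_b > 0 else 0.0
```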

thojmr commented 1 year ago

That's the norm right now. It always gives me model B with tiny changes, even before this PR.

Long discussion here: https://github.com/ogkalu2/Merge-Stable-Diffusion-models-without-distortion/issues/1

vaguenebula commented 1 year ago

OK, I might look into that later. At least there's a way to test it quickly now (if my code is actually decent).

7Tenku commented 1 year ago

Running this, I am getting: "dot" not implemented for 'half'
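The source doesn't show the failing line, so this is an assumption: if it computes a matching objective with torch.vdot or torch.dot on float16 CPU tensors, that would raise this error, because some PyTorch versions have no float16 dot kernel on the CPU. Upcasting just for the score avoids it. matching_score is a hypothetical helper:

```python
import torch


def matching_score(A: torch.Tensor, perm: torch.Tensor) -> float:
    """Hypothetical helper: sum of A[i, perm[i]], i.e. <A, P> for permutation matrix P.

    Computed in float32 so no float16 dot kernel is needed on the CPU.
    """
    n = A.shape[0]
    P = torch.eye(n)[perm]  # row i has a 1 in column perm[i]; float32 by default
    return torch.vdot(A.float().flatten(), P.flatten()).item()
```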

vaguenebula commented 1 year ago

Does it tell you what line the error is on?

7Tenku commented 1 year ago

Line 806

vaguenebula commented 1 year ago

Try replacing line 806 with:

    assert (torch.tensor(ri, dtype=torch.float16) == torch.arange(len(ri), dtype=torch.float16)).all()

Not sure if this is going to work, but it's worth a try.

vaguenebula commented 1 year ago

This isn't going to work, never mind. I'll try to fix it later.
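Comparing indices in float16 is lossy in general, because float16 represents every integer exactly only up to 2048. A sketch of keeping that check exact, assuming A is the square similarity matrix the row and column indices come from (solve_assignment is a hypothetical helper): convert A to float32 for SciPy's linear_sum_assignment and compare the returned indices as integers.

```python
import numpy as np
import torch
from scipy.optimize import linear_sum_assignment


def solve_assignment(A: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: column permutation maximizing sum(A[i, ci[i]]).

    SciPy works on NumPy float arrays, so A is upcast to float32 on the CPU first.
    """
    ri, ci = linear_sum_assignment(A.float().cpu().numpy(), maximize=True)
    # For a square matrix SciPy returns ri == 0..n-1; compare as integers, no rounding.
    assert np.array_equal(ri, np.arange(len(ri)))
    return torch.as_tensor(ci, dtype=torch.long)
```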

7Tenku commented 1 year ago

Okay, so on lines 806 and 807 I removed the .half at the end, and I changed line 794 from A = torch.zeros((n, n), dtype=torch.float16) to A = torch.zeros((n, n), dtype=torch.float). It was actually running, but now I am getting:

    Traceback (most recent call last):
      File "C:\Users\lx\Desktop\Merge-Stable-Diffusion-models-without-distortion-main\Merge-Stable-Diffusion-models-without-distortion-main\SD_rebasin_merge.py", line 27, in
        for axis, p in enumerate(permutation_spec.axes_to_perm[a]):
    KeyError: 'state_dict'
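One plausible reading of that KeyError (an assumption, since SD_rebasin_merge.py isn't shown here) is that a key literally named 'state_dict' ended up in the dict being iterated, meaning the checkpoint's weights were still wrapped one level deeper than expected. A minimal sketch of unwrapping them; unwrap_state_dict and load_weights are hypothetical helpers:

```python
import torch


def unwrap_state_dict(ckpt):
    """Hypothetical helper: peel off any {"state_dict": ...} wrappers around the weights."""
    while isinstance(ckpt, dict) and "state_dict" in ckpt:
        ckpt = ckpt["state_dict"]
    return ckpt


def load_weights(path: str) -> dict:
    # map_location="cpu" avoids needing a GPU just to read the file.
    return unwrap_state_dict(torch.load(path, map_location="cpu"))
```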

vaguenebula commented 1 year ago

What models are you trying to merge?

7Tenku commented 1 year ago

I want to add that I don't know any programming language, and I am trying to fix it based on what seems logical to me.

7Tenku commented 1 year ago

A NAI model that I merged in the webui, and an sd-v1-4-wd12 merge.

vaguenebula commented 1 year ago

Try merging some different models to see if it works in general.

vaguenebula commented 1 year ago

Try to deduce which model is the issue as well. Frankly, I'm new to this machine learning stuff, so my best guess is that it has something to do with the architecture of the NAI model. Either that, or there's something wrong with the code.

7Tenku commented 1 year ago

Instead of the NAI merge I used SD 1.5, but I had to skip 2 layers. It did save a model this time.
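The thread doesn't say which 2 layers had to be skipped. One way to find candidates is to compare the two checkpoints' keys and shapes; mismatched_keys is a hypothetical helper, and its output is only a starting point for a skip list, not a definitive one.

```python
import torch


def mismatched_keys(sd_a: dict, sd_b: dict) -> list:
    """Hypothetical helper: keys present in only one model or with differing shapes."""
    bad = []
    for k in sorted(set(sd_a) | set(sd_b)):
        a, b = sd_a.get(k), sd_b.get(k)
        if a is None or b is None or tuple(a.shape) != tuple(b.shape):
            bad.append(k)
    return bad


sd_a = {"x": torch.zeros(2), "y": torch.zeros(3), "z": torch.zeros(1)}
sd_b = {"x": torch.zeros(2), "y": torch.zeros(4)}
print(mismatched_keys(sd_a, sd_b))
```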

vaguenebula commented 1 year ago

I'm going to add an option for fp16, so people can choose to merge at full precision if they want to.

vaguenebula commented 1 year ago

Currently trying different things to see if I can fix the actual merging as well.

ogkalu2 commented 1 year ago

@vaguenebula I only had the jax dependencies in there yesterday, to test whether they were the reason behind the bias toward model B. I've already removed them.

vaguenebula commented 1 year ago

Sounds good. Did you figure out what's causing the output to basically just be model B?

ogkalu2 commented 1 year ago

Not really. Perhaps the rebasin method just isn't cut out for generative models. The NewL - OldL values should not be 0.0 for most of the layers, but as it stands, only about 3 layers make any progress in iteration 0 before the loop aborts. Forcing more iterations does nothing.
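To make the NewL - OldL check concrete, here is a simplified sketch of the weight-matching loop from the Git Re-Basin paper for the single hidden layer of a 2-layer MLP. It is not this repo's Stable Diffusion code, and match_hidden_units is a hypothetical helper. With only one permutation, the similarity matrix A doesn't depend on the current permutation, so the loop makes progress once and then stops. In a full network, each permutation's A depends on the neighboring permutations, which is why several passes are normally needed and why a gain of 0.0 on most layers means little is being matched.

```python
import numpy as np
import torch
from scipy.optimize import linear_sum_assignment


def match_hidden_units(w1_a, w2_a, w1_b, w2_b, max_iter=10):
    """Hypothetical helper: simplified weight matching for one hidden layer.

    w1_*: (hidden, in) first-layer weights; w2_*: (out, hidden) second-layer weights.
    Returns perm such that w1_b[perm] and w2_b[:, perm] line up with model A.
    """
    n = w1_a.shape[0]
    idx = torch.arange(n)
    perm = torch.arange(n)  # start from the identity permutation
    for it in range(max_iter):
        # Similarity of hidden unit i in A to hidden unit j in B, summed over every
        # weight that touches the hidden axis (rows of W1 and columns of W2).
        A = w1_a @ w1_b.T + w2_a.T @ w2_b
        ri, ci = linear_sum_assignment(A.numpy(), maximize=True)
        assert np.array_equal(ri, np.arange(n))
        new_perm = torch.as_tensor(ci, dtype=torch.long)
        old_l = A[idx, perm].sum().item()
        new_l = A[idx, new_perm].sum().item()
        print(f"iteration {it}: newL - oldL = {new_l - old_l:.4f}")
        perm = new_perm
        if not new_l > old_l + 1e-12:
            break  # no progress in this pass, so stop
    return perm


torch.manual_seed(0)
w1_a, w2_a = torch.randn(16, 64), torch.randn(8, 16)
p = torch.randperm(16)
w1_b, w2_b = w1_a[p], w2_a[:, p]  # model B: the same network with hidden units shuffled
perm = match_hidden_units(w1_a, w2_a, w1_b, w2_b)
```

Because model B here is an exact shuffle of model A, the first pass recovers the permutation with a positive gain and the second pass reports 0.0, which is the expected way for the loop to terminate.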