samuela / git-re-basin

Code release for "Git Re-Basin: Merging Models modulo Permutation Symmetries"
https://arxiv.org/abs/2209.04836
MIT License

Re-basin and Stable Diffusion Tensor Flow weights #5

Open ogkalu2 opened 2 years ago

ogkalu2 commented 2 years ago

Hi Samuel. Thank you for being willing to look at this. Basically, I'm trying to see if it is possible to use your method to merge Stable Diffusion models I've fine-tuned with DreamBooth.

The first hurdle, of course, is that your implementation is not yet compatible with PyTorch, as far as I know. But the PyTorch weights can be successfully converted to TensorFlow weights; this has been done before. I don't mind doing this if I have to.

The second hurdle will be successfully using your implementation on the TF models. I'm not sure how feasible this all is, or how I would use your code on the SD TensorFlow weights.

samuela commented 2 years ago

Hi @ogkalu2! The exact framework that the model runs in is not super important. For example, we have used our JAX code to align the weights of two PyTorch models in the past. The only important part is that you can load the weights into Python/JAX and that you have a correct PermutationSpec for your model.

In general writing down the PermutationSpec will be the more challenging step. We are currently working on a pure-PyTorch version, including a "tracer" that can automatically generate the PermutationSpec for you but this is not ready for release just yet.

Would be very cool to see if this works on StableDiffusion models! Do let me know if you get it working!
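For readers new to the format, here is a minimal sketch of what a PermutationSpec can look like. It mirrors the two mappings the repo's PermutationSpec holds, as we understand them: `axes_to_perm` (each parameter key mapped to one permutation name, or None, per axis) and its inverse `perm_to_axes`. The `layers.{i}.weight` key names and the PyTorch `(out, in)` weight layout are assumptions for illustration. This is not the repo's own MLP spec.

```python
from collections import defaultdict
from dataclasses import dataclass


@dataclass
class PermutationSpec:
    perm_to_axes: dict  # permutation name -> list of (param key, axis) it acts on
    axes_to_perm: dict  # param key -> one permutation name (or None) per axis


def permutation_spec_from_axes_to_perm(axes_to_perm: dict) -> PermutationSpec:
    """Derive the inverse mapping so both lookups are available."""
    perm_to_axes = defaultdict(list)
    for key, axis_perms in axes_to_perm.items():
        for axis, perm in enumerate(axis_perms):
            if perm is not None:  # None = this axis is never permuted
                perm_to_axes[perm].append((key, axis))
    return PermutationSpec(perm_to_axes=dict(perm_to_axes), axes_to_perm=axes_to_perm)


def mlp_permutation_spec(num_hidden_layers: int) -> PermutationSpec:
    """Spec for a plain MLP whose Linear weights have PyTorch shape (out, in).

    Hypothetical key layout: layers.{i}.weight / layers.{i}.bias.
    The network's own inputs and outputs are never permuted (None).
    """
    axes_to_perm = {}
    for i in range(num_hidden_layers + 1):
        p_in = None if i == 0 else f"P_{i - 1}"
        p_out = None if i == num_hidden_layers else f"P_{i}"
        axes_to_perm[f"layers.{i}.weight"] = (p_out, p_in)
        axes_to_perm[f"layers.{i}.bias"] = (p_out,)
    return permutation_spec_from_axes_to_perm(axes_to_perm)


spec = mlp_permutation_spec(num_hidden_layers=2)
```

In this example, the hidden permutation P_0 acts on the rows of layers.0.weight and layers.0.bias and on the columns of layers.1.weight. That is the pattern the p_out/p_in pair encodes for each layer.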

ogkalu2 commented 2 years ago

Hi @samuela, thanks for responding. Quite a bit has happened since the last comment.

I found a repo that had ported the code to PyTorch. I also actually got down to writing the permutation spec for Stable Diffusion today. I think I'm on the right track, but I'm not too sure. For instance, I'm not quite sure what the correct p_in and p_out values should be. Running it right now gives a couple of different errors each time.

Sometimes I get something like

File "/content/drive/MyDrive/SD_rebasin/weight_matching.py", line 487, in weight_matching
    A += w_a @ w_b.T
RuntimeError: mat1 and mat2 shapes cannot be multiplied (1280x11520 and 5760x1280)

or

"addmm_impl_cpu_" not implemented for 'Half'

Can you take a quick look here and see what you think I might be doing wrong?

https://imgur.com/a/DMnb7P8

samuela commented 2 years ago

This looks to be an error in your permutation spec. I would try debugging which weight arrays it's occurring on.
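One way to follow that advice is sketched below. It assumes the spec is a plain dict that maps each key to a tuple of permutation names or None per axis, in the style of the repo's axes_to_perm. The sketch collects the size of every axis that each permutation touches and reports any permutation whose axes disagree. A single permutation reorders one set of units, so inconsistent sizes are a likely cause of the shape-mismatch, "INDICES element is out of DATA bounds" and "shape '[512, -1]' is invalid" errors quoted in this thread. `find_size_mismatches` is a hypothetical helper, not part of either repo.

```python
import numpy as np


def find_size_mismatches(axes_to_perm: dict, params: dict) -> dict:
    """Return {perm name: {(key, axis): size}} for every permutation whose axes disagree in size.

    A permutation reorders one set of units, so every axis it is attached to must
    have the same length. An empty result means the sizes are at least consistent.
    """
    sizes = {}
    for key, axis_perms in axes_to_perm.items():
        if key not in params:
            raise KeyError(f"spec names a key that is not in the checkpoint: {key}")
        shape = params[key].shape
        if len(axis_perms) > len(shape):
            raise ValueError(f"{key}: spec lists {len(axis_perms)} axes but shape is {shape}")
        for axis, perm in enumerate(axis_perms):
            if perm is not None:
                sizes.setdefault(perm, {})[(key, axis)] = shape[axis]
    return {p: s for p, s in sizes.items() if len(set(s.values())) > 1}


# Toy example: c.weight has 6 output rows, but the spec ties it to the 4-unit P_0.
params = {
    "a.weight": np.zeros((4, 3)),
    "b.weight": np.zeros((5, 4)),
    "c.weight": np.zeros((6, 4)),
}
spec = {"a.weight": ("P_0", None), "b.weight": (None, "P_0"), "c.weight": ("P_0", None)}
print(find_size_mismatches(spec, params))
```

The printed entry lists every (key, axis) that P_0 touches, which makes the odd one out (c.weight, axis 0) easy to spot.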

ogkalu2 commented 2 years ago

I'll try that, thanks.

I also get "RuntimeError: INDICES element is out of DATA bounds, id=256 axis_dim=256"

ogkalu2 commented 2 years ago

I've tried a couple of different tests now, removing everything except a few layers that were explicitly labelled in the state dict, to see if that was the issue and whether I could build up from there.

Examples here: https://imgur.com/a/IbAEfdP and https://imgur.com/a/o6SgMk8

But I still get those weird errors.

ogkalu2 commented 2 years ago

@samuela
I finally got it to run the final permutation function for the few explicit blocks/layers (I'll build from there once I see everything else works)!

I feel so close, but I've hit another wall on the apply permutation line. It's giving me "KeyError: 'betas'".

The good news is that I'm pretty sure I know what's happening here. betas is the first key in Stable Diffusion's state dict. It seems to be stuck trying to apply the permutation to betas, but betas wasn't defined in the permutation_spec list of layers or blocks to alter. More importantly, it's not the only key in the state dict that I thought was best left undisturbed.

Is there any way I can get the apply permutation function to skip keys that it doesn't have altered values for?

samuela commented 2 years ago

Hmm, I'm not familiar with the stable diffusion architecture... is there a reason not to model permutations on the betas?

You could always add them to your PermutationSpec and just specify None for all the axes. That should make apply_permutation ignore them.

(Btw, since you're using the 3rd-party pytorch implementation some things may be different! That's a different codebase.)
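As far as we can tell, the reference `apply_permutation` builds its output by looking up every key of the parameter dict in the spec. That is why a key such as `betas`, which is missing from the spec, raises a KeyError. Besides adding such keys with None axes, another option is a variant that passes unlisted keys through unchanged. The minimal NumPy sketch below shows that variant. The function names echo the repo's, but this is an illustrative re-implementation: it takes the axes_to_perm dict directly, and `np.take` stands in for whichever indexing op your framework uses.

```python
import numpy as np


def get_permuted_param(axes_to_perm, perm, key, params, except_axis=None):
    """Return params[key] with every permutation from the spec applied, except on except_axis."""
    w = params[key]
    for axis, p in enumerate(axes_to_perm[key]):
        if axis == except_axis or p is None:  # None: this axis is left alone
            continue
        w = np.take(w, perm[p], axis=axis)
    return w


def apply_permutation(axes_to_perm, perm, params):
    """Permute every parameter named in the spec; copy all other keys (e.g. betas) through unchanged."""
    return {
        k: get_permuted_param(axes_to_perm, perm, k, params) if k in axes_to_perm else v
        for k, v in params.items()
    }


params = {"betas": np.linspace(0.1, 0.3, 3), "fc.weight": np.arange(6.0).reshape(3, 2)}
spec = {"fc.weight": ("P_0", None)}  # betas is deliberately absent from the spec
perm = {"P_0": np.array([2, 0, 1])}
out = apply_permutation(spec, perm, params)
```

Listing `"betas": (None,)` in the spec, as suggested above, gives the same result, because every axis marked None is skipped.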

lopho commented 2 years ago

You can skip betas; it's not a layer of the actual architecture but a stored parameter used for sampling and generating timestep embeddings. The same goes for alphas_cumprod, sqrt_alphas, and a dozen others. Anything that isn't in model.diffusion_model, first_stage_model or cond_stage_model is such a parameter. Actually, you can probably skip first_stage_model (VAE) and cond_stage_model (CLIP) as well, since those will probably be the same between the two models. If not, it would still make more sense to spec them separately, as they are indeed separate models that only interact with each other manually, i.e. they have no connection in the model graph.

The UNet is the interesting part, and it resides in model.diffusion_model.*

EDIT: to clarify, these are the keys you can skip:

betas
alphas_cumprod
alphas_cumprod_prev
sqrt_alphas_cumprod
sqrt_one_minus_alphas_cumprod
log_one_minus_alphas_cumprod
sqrt_recip_alphas_cumprod
sqrt_recipm1_alphas_cumprod
posterior_variance
posterior_log_variance_clipped
posterior_mean_coef1
posterior_mean_coef2
model_ema.decay
model_ema.num_updates

and these are the ones you can skip if you retain the same VAE and text encoder

cond_stage_model.*
first_stage_model.*
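Below is a small sketch of how the list above might be applied when preparing a checkpoint. It returns the state-dict keys that should go through weight matching. The key names come from the comment above; the function name and the `keep_vae_and_text_encoder_fixed` switch are hypothetical.

```python
# Non-layer buffers listed above: schedule constants and EMA bookkeeping.
SKIP_KEYS = {
    "betas",
    "alphas_cumprod",
    "alphas_cumprod_prev",
    "sqrt_alphas_cumprod",
    "sqrt_one_minus_alphas_cumprod",
    "log_one_minus_alphas_cumprod",
    "sqrt_recip_alphas_cumprod",
    "sqrt_recipm1_alphas_cumprod",
    "posterior_variance",
    "posterior_log_variance_clipped",
    "posterior_mean_coef1",
    "posterior_mean_coef2",
    "model_ema.decay",
    "model_ema.num_updates",
}
# Sub-models that can be skipped when both checkpoints share the same VAE and text encoder.
SHARED_SUBMODEL_PREFIXES = ("cond_stage_model.", "first_stage_model.")


def keys_to_match(state_dict, keep_vae_and_text_encoder_fixed=True):
    """Return the state-dict keys that should go through weight matching, in original order."""
    prefixes = SHARED_SUBMODEL_PREFIXES if keep_vae_and_text_encoder_fixed else ()
    # str.startswith(()) is always False, so an empty tuple skips no sub-models.
    return [k for k in state_dict if k not in SKIP_KEYS and not k.startswith(prefixes)]
```

If the text encoder was also fine-tuned, pass `keep_vae_and_text_encoder_fixed=False` so that cond_stage_model.* stays in the matching.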
ogkalu2 commented 2 years ago

Hmm, I'm not familiar with the stable diffusion architecture... is there a reason not to model permutations on the betas?

You could always add them to your PermutationSpec and just specify None for all the axes. That should make apply_permutation ignore them.

(Btw, since you're using the 3rd-party pytorch implementation some things may be different! That's a different codebase.)

Mostly 3 reasons

File "/content/drive/MyDrive/SD_rebasin/weight_matching.py", line 293, in weight_matching
    w_a = torch.moveaxis(w_a, axis, 0).reshape((n, -1))
RuntimeError: shape '[512, -1]' is invalid for input of size 768

By axes, you mean the P_bgx and P_bgy values, right? So None and None then?

I have a question on that too. There are a lot of layers, and I'm unsure how to label them all correctly. For the layers with 2 P values, do I just keep going sequentially? I reach P_bg50 or so that way.

For the layers with only one, how would that work exactly? It's a bit hard to tell when I need to go from, say, P_bg1 to P_bg2 to P_bg3 and so on. The architecture is divided into 3 parts: the input blocks, the middle blocks and the output blocks. So I'm wondering, is it P_bg1, P_bg2, P_bg3 for those sets of blocks, or something else?
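On naming: a permutation name only needs to identify one set of channels, so sequential names like P_bg0, P_bg1, ... are fine. A layer with a single permuted axis (a bias or a norm scale) reuses the output permutation of the layer that feeds it. A new name is needed only where a layer produces a new set of channels. The sketch below shows small helpers in that spirit. It assumes PyTorch weight layouts (Conv2d `(out, in, kh, kw)`, Linear `(out, in)`). The helper names and the `block.*` keys are hypothetical.

```python
def conv(name, p_in, p_out, bias=True):
    """nn.Conv2d: weight (out_ch, in_ch, kh, kw), bias (out_ch,). Kernel axes are never permuted."""
    spec = {f"{name}.weight": (p_out, p_in, None, None)}
    if bias:
        spec[f"{name}.bias"] = (p_out,)
    return spec


def dense(name, p_in, p_out, bias=True):
    """nn.Linear: weight (out_features, in_features), bias (out_features,)."""
    spec = {f"{name}.weight": (p_out, p_in)}
    if bias:
        spec[f"{name}.bias"] = (p_out,)
    return spec


def norm(name, p):
    """Per-channel affine scale and shift: one axis, tied to the channels being normalized.

    Caveat: GroupNorm pools statistics over groups of adjacent channels, so a
    permutation that moves channels between groups is not an exact symmetry.
    """
    return {f"{name}.weight": (p,), f"{name}.bias": (p,)}


# Hypothetical chain: conv1 -> norm1 -> conv2 -> norm2 -> conv_out.
# Each norm reuses the name of the conv feeding it. With a residual add, the block's
# output would have to reuse the block input's name instead of getting a new one.
spec = {
    **conv("block.conv1", None, "P_bg0"),
    **norm("block.norm1", "P_bg0"),
    **conv("block.conv2", "P_bg0", "P_bg1"),
    **norm("block.norm2", "P_bg1"),
    **conv("block.conv_out", "P_bg1", None),
}
```

Under this scheme, the input/middle/output split does not change the naming; what matters is which tensors share a set of channels.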

ogkalu2 commented 2 years ago

Thanks. Lots of DreamBooth repos also train the text encoder now, so I won't skip them.

I'm unsure about the type of certain layers.

What about the layers whose keys contain emb_layers, proj_in, proj_out, transformer_blocks, skip_connection, self_attn or mid.attn?

Do you have any idea? Are they norm, conv, dense or neither?

lopho commented 2 years ago

Keep in mind that the different Attention layer types might be either all conv or dense, but they are not just sequentially chained. Rather they do qkv queries and batch matrix mult on results. I'm not sure if that has any impact on the results of the permutations.

ogkalu2 commented 2 years ago
  • emb_layers is SiLU + linear -> dense
  • proj_in is Conv2d
  • proj_out is Conv2d
  • transformer_blocks is BasicTransformerBlock -> CrossAttention, LayerNorm, FeedForward

    • CrossAttention: all linear layers -> dense

    • FeedForward: GEGLU + linear -> dense

    • GEGLU: linear + gelu -> dense

  • skip_connection: either Identity or Conv2d (Identity has no weights, so going by key, it's always Conv2d)
  • self_attn: CLIPAttention -> dense
  • mid.attn: AttnBlock -> Conv2d

Keep in mind that the different Attention layer types might be either all conv or dense, but they are not just sequentially chained. Rather they do qkv queries and batch matrix mult on results. I'm not sure if that has any impact on the results of the permutations.

Thanks for the response, it's helped a lot. Yes, I think I'm going to skip the attention layers, at least for the first go-around. As Samuel suggested, labelling the axes as None for betas worked for skipping it, so I just have to do them all now.

Forgot to ask: what are time_embed (I'm assuming dense now), model.out.0 and out.2 (norm and conv, I think) and .op (conv, I think)?

lopho commented 2 years ago

  • UNet time_embed: Linear, SiLU, Linear, so it's dense.
  • UNet out: GroupNorm, SiLU, Conv2d.
  • op (I guess you mean the op of the UNet downsampling block): Conv2d.

It's all readily available in the implementation, so I suggest you read it yourself to get a better understanding:
  • UNet: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/openaimodel.py
  • UNet attention: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/attention.py
  • VAE: https://github.com/CompVis/stable-diffusion/blob/main/ldm/models/autoencoder.py
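Here is a quick heuristic for the conv/dense/norm question, assuming PyTorch layouts: the rank of a layer's `.weight` tensor usually gives away its type. This is a hypothetical helper, not part of either repo, and it can misclassify. For example, embedding tables are also 2-D.

```python
import numpy as np


def layer_kind(weight_shape):
    """Guess a layer's type from the rank of its .weight tensor (PyTorch layouts)."""
    rank = len(weight_shape)
    if rank == 4:
        return "conv"  # nn.Conv2d: (out, in, kh, kw); 1x1 convs such as proj_in land here too
    if rank == 2:
        return "dense"  # nn.Linear: (out, in); note nn.Embedding tables are also 2-D
    if rank == 1:
        return "norm"  # GroupNorm / LayerNorm scale: (channels,)
    return "other"


def classify_weights(state_dict):
    """Map each layer prefix (key minus '.weight') to its guessed kind."""
    suffix = ".weight"
    return {
        k[: -len(suffix)]: layer_kind(tuple(v.shape))
        for k, v in state_dict.items()
        if k.endswith(suffix)
    }


# Toy state dict with made-up but plausible shapes.
toy = {
    "model.diffusion_model.input_blocks.0.0.weight": np.zeros((320, 4, 3, 3)),
    "model.diffusion_model.time_embed.0.weight": np.zeros((1280, 320)),
    "model.diffusion_model.out.0.weight": np.zeros((320,)),
    "model.diffusion_model.out.0.bias": np.zeros((320,)),
}
print(classify_weights(toy))
```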

ogkalu2 commented 2 years ago

Just an update. At first I had trouble building up the spec. Some layers would work, most wouldn't, and, most irritatingly, it felt inconsistent which ones would work without error and which wouldn't. Today I figured out the issue was the axes and the tensor sizes of the layers. Not all layers can or should be connected along an axis, and the tensor sizes help tell which ones can/should be. Anyway, I can get pretty much every layer to permute now, so I'll finish that and finally test this.

affableroots commented 2 years ago

@ogkalu2 So the idea is to go PyTorch -> TF -> rebasin -> PyTorch? This sounds huge, btw; thanks for doing it.

affableroots commented 2 years ago

Also, you probably saw this, but there's a PyTorch version, though I think you need to come up with a PermutationSpec: https://github.com/themrzmaster/git-re-basin-pytorch

ogkalu2 commented 2 years ago

I'm pretty much just using the PyTorch implementation now. I already knew about the one you linked. I ended up using JAX for flattening and unflattening the parameters, but that's about it.

No problem, it's my pleasure. I'm done with the UNet and working on the text encoder. There's no doubt it'll run now; the question is whether it merges as hoped. Fingers crossed.

lopho commented 2 years ago

@ogkalu2 Do you have a repository for this where I could take a look?

ogkalu2 commented 2 years ago

@lopho No. I wanted to finish things and see the results of a merged model before I uploaded anything to a repo.

ogkalu2 commented 2 years ago

Although I did upload my first attempt here. A few things have changed to make it work, mostly the axes: https://imgur.com/a/DMnb7P8

But I added a bias option for the conv and added the dense emb layers in the easyblock.

ogkalu2 commented 1 year ago

Hi @samuela, would this run much faster on a GPU?

samuela commented 1 year ago

Yes, it should run quite a bit faster on a GPU, since that will speed up the matrix multiplies. But the linear assignment problem solve still happens on the CPU, so I don't think the speedup you'd get would be anything too crazy... I've never tried running on CPU only.
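To see where the time goes, here is a sketch of a single permutation update in weight matching. It assumes NumPy arrays that have already been rearranged so the permuted axis comes first. Accumulating the similarity matrix A is the matrix-multiply part a GPU can accelerate. The assignment itself is solved by SciPy's `linear_sum_assignment`, which as far as we know is also what the reference code calls, and it runs on the CPU. `match_one_permutation` is a hypothetical name.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def match_one_permutation(ws_a, ws_b):
    """Solve for one permutation given every weight it touches.

    ws_a / ws_b: matching lists of arrays from models A and B, each already moved
    so the permuted axis is axis 0 (and with B's other permutations applied).
    Returns idx such that w_b[idx] lines up with w_a along axis 0.
    """
    n = ws_a[0].shape[0]
    A = np.zeros((n, n))
    for w_a, w_b in zip(ws_a, ws_b):
        # Similarity between unit i of A and unit j of B, summed over all weights.
        # These products are what a GPU would speed up.
        A += w_a.reshape(n, -1) @ w_b.reshape(n, -1).T
    # The assignment solve itself runs on the CPU.
    _, idx = linear_sum_assignment(A, maximize=True)
    return idx


rng = np.random.default_rng(0)
w_a = rng.normal(size=(5, 3, 2))
true_perm = rng.permutation(5)
w_b = np.empty_like(w_a)
w_b[true_perm] = w_a  # model B is model A with its units shuffled
idx = match_one_permutation([w_a], [w_b])
```

In this toy case B is an exact shuffle of A, so the recovered `idx` equals `true_perm` and `w_b[idx]` reproduces `w_a`.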

ogkalu2 commented 1 year ago

Ah, I see. If I don't specify the device as cpu, I get

File "/notebooks/weight_matching.py", line 798, in weight_matching
    A += w_a @ w_b.T
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument mat2 in method wrapper_mm)

ogkalu2 commented 1 year ago

@lopho @affableroots @samuela
I'm done with the spec. I haven't tested a merge yet, and that will take time, so I've uploaded it here for anyone to test as well: https://github.com/ogkalu2/Merge-Stable-Diffusion-models-without-distortion

ogkalu2 commented 1 year ago

Hi @samuela, how many iterations does the weight matching typically run? I know max iterations is 100, but it doesn't usually go that high, I don't think?

samuela commented 1 year ago

It totally depends on the model and initialization. I've seen it take as few as 3 and as many as 50. It is guaranteed to terminate though, so don't worry it can't run forever!
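Below is a minimal sketch of the full coordinate-descent loop. It loosely follows the structure of the repo's weight matching as we understand it, but it is a NumPy re-implementation, not the repo's code. Each update records newL - oldL, which is how much re-solving that one permutation (with all others held fixed) increased the alignment objective. The `iteration/P_name: value` log lines later in this thread appear to be this quantity; 0.0 means the permutation was already optimal given the others. Every update can only raise the objective, and there are finitely many permutations. The loop therefore stops after a full pass with no improvement, or at `max_iter`.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def weight_matching(axes_to_perm, params_a, params_b, max_iter=100, seed=0):
    """Re-solve one permutation at a time until a full pass makes no progress.

    Returns (perm, history); history holds (iteration, perm name, newL - oldL).
    """
    perm_to_axes = {}
    for key, axis_perms in axes_to_perm.items():
        for axis, p in enumerate(axis_perms):
            if p is not None:
                perm_to_axes.setdefault(p, []).append((key, axis))
    names = list(perm_to_axes)
    sizes = {}
    for p, axes in perm_to_axes.items():
        key, axis = axes[0]
        sizes[p] = params_a[key].shape[axis]
    perm = {p: np.arange(n) for p, n in sizes.items()}  # start from the identity
    rng = np.random.default_rng(seed)
    history = []

    def permuted_b(key, except_axis):
        # Model B's tensor with the current permutations applied on every other axis.
        w = params_b[key]
        for axis, p in enumerate(axes_to_perm[key]):
            if p is not None and axis != except_axis:
                w = np.take(w, perm[p], axis=axis)
        return w

    for iteration in range(max_iter):
        progress = False
        for i in rng.permutation(len(names)):  # visit permutations in random order
            p = names[i]
            n = sizes[p]
            A = np.zeros((n, n))
            for key, axis in perm_to_axes[p]:
                w_a = np.moveaxis(params_a[key], axis, 0).reshape(n, -1)
                w_b = np.moveaxis(permuted_b(key, axis), axis, 0).reshape(n, -1)
                A += w_a @ w_b.T
            _, new_idx = linear_sum_assignment(A, maximize=True)
            old_l = A[np.arange(n), perm[p]].sum()
            new_l = A[np.arange(n), new_idx].sum()
            history.append((iteration, p, float(new_l - old_l)))
            progress = progress or new_l > old_l + 1e-12
            perm[p] = new_idx
        if not progress:
            break
    return perm, history


# Toy check: model B is model A with its 4 hidden units shuffled.
rng = np.random.default_rng(1)
params_a = {
    "l0.weight": rng.normal(size=(4, 3)),
    "l0.bias": rng.normal(size=4),
    "l1.weight": rng.normal(size=(2, 4)),
}
q = np.array([2, 0, 3, 1])
params_b = {
    "l0.weight": params_a["l0.weight"][q],
    "l0.bias": params_a["l0.bias"][q],
    "l1.weight": params_a["l1.weight"][:, q],
}
spec = {"l0.weight": ("P_0", None), "l0.bias": ("P_0",), "l1.weight": (None, "P_0")}
perm, history = weight_matching(spec, params_a, params_b)
```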

affableroots commented 1 year ago

I keep OOMing on 32GB RAM. Any tips on what I can delete and when, or maybe on running the merge in parts?

ogkalu2 commented 1 year ago

Oh wow. I know it can't run forever, but the first SD iteration took ~12 hours, so I was curious. Ah well. What do the NewL - OldL values indicate exactly? I see most of them are 0.0.

ogkalu2 commented 1 year ago

I keep OOMing on 32GB RAM, any tips on what I can delete when, or maybe running the merge in parts?

Really? Huh. I'm just running on Vast right now. You can't really run it in parts at the moment. As for what to delete, it's possible to skip some layers, but I honestly don't know exactly what I can skip yet. The VAE layers would be the first thing I'd remove, but I don't know beyond that. I'll look into it.

The OOM errors seem odd, though. Do you actually get those errors in your console/terminal, or does your system freeze up or something?

affableroots commented 1 year ago

Watching htop, I can see it OOM, and I also get a "Killed" at the following step. If it matters, I'm just testing merging two simple 4GB .ckpts built off of sd-v1-4:

...
0/P_model.diffusion_model.output_blocks.6.0_inner2: 0.0
0/P_bg308: 0.0
0/P_model.diffusion_model.middle_block.2_inner4: 0.0
0/P_bg353: 0.0
0/P_bg166: 0.0
0/P_bg180: 0.0
0/P_bg65: 0.0
0/P_bg78: 0.0
0/P_first_stage_model.encoder.mid.block_1_inner: 0.0
0/P_bg214: 0.0
0/P_first_stage_model.decoder.up.3.block.1_inner: 0.0
0/P_bg313: 0.0
0/P_model.diffusion_model.output_blocks.0.0_inner: 0.0
0/P_bg98: 0.0
0/P_bg359: 0.0
0/P_bg141: 0.0
0/P_bg264: 0.0
0/P_bg163: 0.0
Killed

EDIT: skipping the vae makes sense, that's a good idea.

ogkalu2 commented 1 year ago

Oh, I see. The test I have running uses 2 DreamBooth models pruned to 2GB. The bigger the models, the higher the RAM usage. I didn't realize 4GB models were currently too much for 32GB RAM systems. The problem is the linear sum assignment; it can only run on the CPU.

ogkalu2 commented 1 year ago

@samuela Someone here is running through this in minutes, lol: https://github.com/ogkalu2/Merge-Stable-Diffusion-models-without-distortion/issues/1#issuecomment-1319102078

Anyway, I have a new problem now. The perm spec runs fine and the parameters get updated fine, but the line I wrote previously to save the model won't work. After defining the state dict(s) as state_a = model_a["state_dict"], I tried to save the model with torch.save({ "state_dict": state_b(updated_params) }, output_file)

but get hit with

    "state_dict": state_b(updated_params)
TypeError: 'dict' object is not callable
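The TypeError occurs because `state_b` is a dict, and `state_b(updated_params)` tries to call it. Assuming the intent was to overwrite model B's entries with the updated tensors, a minimal sketch follows. `build_checkpoint` is a hypothetical helper; `state_b`, `updated_params` and `output_file` are the names from the comment above.

```python
def build_checkpoint(state_b: dict, updated_params: dict) -> dict:
    """Merge the updated tensors into model B's state dict and wrap it for saving."""
    merged = dict(state_b)  # shallow copy, so state_b itself is left untouched
    merged.update(updated_params)  # dict.update, not state_b(updated_params)
    return {"state_dict": merged}


# Then, with PyTorch:
# torch.save(build_checkpoint(state_b, updated_params), output_file)
```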

ogkalu2 commented 1 year ago

Hi @samuela, something seems to be up with the get_permuted_params function. Applying the permutation with it just seems to bias the result toward whichever model's parameters are passed in.

So a merge with apply_permutation(permutation_spec, final_permutation, model_a state dict) just produces a model that is basically model a, and a merge with apply_permutation(permutation_spec, final_permutation, model_b state dict) just produces a model that is basically model b.

Any idea what the issue might be?
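One thing worth checking here: permuting a model's hidden units consistently gives a functionally equivalent network. Applying `final_permutation` to a single model is therefore expected to return essentially that same model. In the paper, model B is permuted into alignment with model A, and the two parameter sets are then interpolated. The minimal sketch below covers that last step. It assumes both state dicts hold NumPy arrays with identical keys, and that `params_b_aligned` is the result of applying the permutation to model B only. The function names and the `lam` parameter are hypothetical. The toy MLP at the end shows the equivalence point.

```python
import numpy as np


def interpolate(params_a, params_b_aligned, lam=0.5):
    """Return (1 - lam) * A + lam * B for every key; lam=0 gives model A, lam=1 gives aligned B."""
    return {k: (1 - lam) * params_a[k] + lam * params_b_aligned[k] for k in params_a}


def mlp(params, x):
    """Toy 2-layer ReLU MLP with PyTorch-style (out, in) weights."""
    h = np.maximum(params["l0.weight"] @ x + params["l0.bias"], 0.0)
    return params["l1.weight"] @ h


# Permuting the hidden units (rows of l0, columns of l1) leaves the function unchanged,
# which is why a permutation applied to one model alone "does nothing" to its outputs.
rng = np.random.default_rng(0)
p = {"l0.weight": rng.normal(size=(4, 3)), "l0.bias": rng.normal(size=4), "l1.weight": rng.normal(size=(2, 4))}
q = np.array([2, 0, 3, 1])
p_perm = {"l0.weight": p["l0.weight"][q], "l0.bias": p["l0.bias"][q], "l1.weight": p["l1.weight"][:, q]}
x = rng.normal(size=3)
same_output = np.allclose(mlp(p, x), mlp(p_perm, x))
```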

samuela commented 1 year ago

Hi @ogkalu2, how are you measuring the difference between the permuted model and the original? Have you inspected the final_permutations to see if they are close to the identity? Depending on the kind of fine-tuning you're doing, it could reasonably be expected that the optimal permutation is already close to the identity. This is a convenient property of fine-tuning: you generally don't leave the pre-training basin, esp. with large models and small learning rates; see e.g. https://arxiv.org/abs/2109.01903, https://twitter.com/moyix/status/1581390268368302080, and so forth.
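A quick way to run the identity check suggested above, assuming `final_permutation` maps each permutation name to an index array (CPU tensors also work through `np.asarray`). The sketch reports the fraction of fixed points per permutation; 1.0 means that permutation is exactly the identity. `identity_fraction` is a hypothetical helper.

```python
import numpy as np


def identity_fraction(final_permutation: dict) -> dict:
    """For each permutation, the fraction of indices mapped to themselves (1.0 == identity)."""
    return {
        name: float(np.mean(np.asarray(p) == np.arange(len(p))))
        for name, p in final_permutation.items()
    }
```

Values near 1.0 across most permutations would be consistent with the fine-tuned models never leaving the pre-training basin.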

liruiw commented 1 year ago

Hello, is there an update on the automatic PyTorch tracer for the permutation spec? Thanks!