ogkalu2 opened this issue 2 years ago
Hi @ogkalu2! The exact framework that the model runs in is not super important. For example, we have used our JAX code to align the weights of two PyTorch models in the past. The only important part is that you can load the weights into Python/JAX and that you have a correct PermutationSpec for your model.
In general, writing down the PermutationSpec will be the more challenging step. We are currently working on a pure-PyTorch version, including a "tracer" that can automatically generate the PermutationSpec for you, but this is not ready for release just yet.
Would be very cool to see if this works on Stable Diffusion models! Do let me know if you get it working!
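(For readers landing here: a PermutationSpec boils down to a mapping from each weight name to the permutation acting along each of its axes. The sketch below is illustrative only, using made-up layer and permutation names for a tiny 3-layer PyTorch MLP, not anything from a real checkpoint.)

```python
# Illustrative sketch of an axes_to_perm mapping. Each entry lists, per weight axis,
# the name of the permutation acting on it (None = that axis is never permuted).
mlp_axes_to_perm = {
    "layer0.weight": ("P_0", None),   # Linear weight is [out, in]; outputs permuted by P_0
    "layer0.bias":   ("P_0",),
    "layer1.weight": ("P_1", "P_0"),  # input axis must follow the previous layer's P_0
    "layer1.bias":   ("P_1",),
    "layer2.weight": (None, "P_1"),   # the model's final outputs stay in place
    "layer2.bias":   (None,),
}
```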
Hi @samuela, thanks for responding. Quite a bit's happened since the last comment.
I found a repo that had converted the code to PyTorch. I also actually got down to writing the permutation spec for Stable Diffusion today. I think I'm on the right track but I'm not too sure. For instance, I'm not quite sure what the correct p_in and p_out values should be. Running it right now gives a couple of different errors each time.
Sometimes I get something like
File "/content/drive/MyDrive/SD_rebasin/weight_matching.py", line 487, in weight_matching A += w_a @ w_b.T RuntimeError: mat1 and mat2 shapes cannot be multiplied (1280x11520 and 5760x1280)
or
"addmm_implcpu" not implemented for 'Half'
Can you take a quick look here and see what you think I might be doing wrong?
This looks to be an error in your permutation spec, I would try debugging what weight arrays that's occurring on.
This looks to be an error in your permutation spec, I would try debugging what weight arrays that's occurring on.
I'll try that, thanks.
I also get "RuntimeError: INDICES element is out of DATA bounds, id=256 axis_dim=256"
This looks to be an error in your permutation spec, I would try debugging what weight arrays that's occurring on.
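(A hedged debugging sketch for that, assuming the PyTorch port exposes a perm_to_axes dict like the original weight_matching.py: print the axis length every weight contributes to each permutation, so mismatches like 1280 vs 5760, or an index hitting axis_dim=256, stand out immediately. The function name is mine, not part of either repo.)

```python
# Sketch: sanity-check a PermutationSpec against a checkpoint's state dict.
# Every weight axis assigned to the same permutation name must have the same length.
def check_spec(perm_spec, state_dict):
    for p_name, weight_axes in perm_spec.perm_to_axes.items():
        sizes = {}
        for wk, axis in weight_axes:
            if wk not in state_dict:
                print(f"{p_name}: {wk} not found in state dict")
                continue
            sizes[f"{wk}[axis {axis}]"] = state_dict[wk].shape[axis]
        if len(set(sizes.values())) > 1:
            print(f"{p_name}: inconsistent axis sizes -> {sizes}")
```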
I've tried a couple of different tests now, removing everything besides a few lines of layers that were explicitly labelled in the state dict, to see if that was the issue and whether I could build from that.
Example here: https://imgur.com/a/IbAEfdP and https://imgur.com/a/o6SgMk8
But I still get those weird errors.
@samuela
I finally got it to run the final permutation function for the few explicit blocks/layers (I'll build from there once I see everything else works)!
I feel so close, but I've hit another wall on the apply permutation line. It's giving me "KeyError: betas".
The good news is that I'm pretty sure I know what's happening here. betas is the first key in Stable Diffusion's state dict. It seems to be stuck trying to apply the permutation to betas, but betas wasn't defined in the permutation_spec list of layers or blocks to alter. More importantly, it's not the only key in the state dict I thought best left undisturbed.
Is there any way I can get the apply permutation function to skip keys that it doesn't have altered values for?
Hmm, I'm not familiar with the Stable Diffusion architecture... is there a reason not to model permutations on the betas?
You could always add them to your PermutationSpec and just specify None for all the axes. That should make apply_permutation ignore them.
(Btw, since you're using the 3rd-party PyTorch implementation some things may be different! That's a different codebase.)
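(A one-line sketch of that pass-through idea, assuming axes_to_perm is whatever dict you feed to the spec builder: betas in the SD checkpoint is a 1-D buffer, so an all-None entry makes apply_permutation leave it alone.)

```python
# Sketch: axes_to_perm is the dict of spec entries being built by hand.
axes_to_perm = {}                  # ... hand-written UNet entries go here
axes_to_perm["betas"] = (None,)    # 1-D buffer; no axis gets permuted, so it passes through
```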
You can skip betas, it's not a layer of the actual architecture but a stored parameter used for sampling and generating timestep embeddings. Same goes for alphas_cumprod, sqrt_alphas, and a dozen others. Anything that isn't either in model.diffusion_model, first_stage_model or cond_stage_model is such a parameter.
Actually you can probably skip first_stage_model (VAE) and cond_stage_model (CLIP) as well, as those will probably be the same between two models. If not, it would still make more sense to spec them separately as they are indeed separate models that only interact with each other manually, i.e. they have no connection in the model graph.
UNet is the interesting part, and this resides in model.diffusion_model.*
EDIT: to clarify, these are the keys you can skip
betas
alphas_cumprod
alphas_cumprod_prev
sqrt_alphas_cumprod
sqrt_one_minus_alphas_cumprod
log_one_minus_alphas_cumprod
sqrt_recip_alphas_cumprod
sqrt_recipm1_alphas_cumprod
posterior_variance
posterior_log_variance_clipped
posterior_mean_coef1
posterior_mean_coef2
model_ema.decay
model_ema.num_updates
and these are the ones you can skip if you retain the same VAE and text encoder
cond_stage_model.*
first_stage_model.*
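(Sketch of putting that skip list into practice, assuming the hand-written UNet entries are already in axes_to_perm and that UNet keys live under model.diffusion_model. as in standard SD 1.x checkpoints: any key you don't want permuted just gets an all-None entry so it is carried through unchanged.)

```python
# Sketch: every key not covered by the UNet spec (schedule buffers, model_ema.*,
# and optionally first_stage_model.* / cond_stage_model.*) passes through untouched.
for key, tensor in state_dict.items():
    if key not in axes_to_perm:
        axes_to_perm[key] = (None,) * tensor.dim()   # 0-d buffers simply get an empty entry
```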
Hmm, I'm not familiar with the Stable Diffusion architecture... is there a reason not to model permutations on the betas?
You could always add them to your PermutationSpec and just specify None for all the axes. That should make apply_permutation ignore them.
(Btw, since you're using the 3rd-party PyTorch implementation some things may be different! That's a different codebase.)
Mostly 3 reasons:
I can't tell the layer type of all the layers in the state_dict, only most of them. So I have to skip some to test it out first. Also, some layers aren't really part of the architecture and will be the same between 2 models.
A few correctly identified layers don't work well for some reason. I get an error like
File "/content/drive/MyDrive/SD_rebasin/weight_matching.py", line 293, in weight_matching w_a = torch.moveaxis(w_a, axis, 0).reshape((n, -1)) RuntimeError: shape '[512, -1]' is invalid for input of size 768
For axes you mean the P_bgx and P_bgy values, right? So None and None then?
I have a question on that too. There are a lot of layers and I'm unsure how to correctly label them all. For the layers with 2 P values, do I just keep going sequentially? I reach P_bg50 or so that way.
For the layers with only one, how would that work exactly? It's a bit hard to tell when I need to go from, say, P_bg1 to P_bg2 and to P_bg3 and so on. The architecture is divided into 3 parts: the input blocks, the middle blocks and the output blocks. So I'm wondering, is it P_bg1, P_bg2, P_bg3 for those sets of blocks or something else?
You can skip betas, it's not a layer of the actual architecture but a stored parameter used for sampling and generating timestep embeddings. Same goes for alphas_cumprod, sqrt_alphas, and a dozen others. Anything that isn't either in model.diffusion_model, first_stage_model or cond_stage_model is such a parameter. Actually you can probably skip first_stage_model (VAE) and cond_stage_model (CLIP) as well, as those will probably be the same between two models. If not, it would still make more sense to spec them separately as they are indeed separate models that only interact with each other manually, i.e. they have no connection in the model graph. UNet is the interesting part, and this resides in model.diffusion_model.*
EDIT: to clarify, these are the keys you can skip
betas alphas_cumprod alphas_cumprod_prev sqrt_alphas_cumprod sqrt_one_minus_alphas_cumprod log_one_minus_alphas_cumprod sqrt_recip_alphas_cumprod sqrt_recipm1_alphas_cumprod posterior_variance posterior_log_variance_clipped posterior_mean_coef1 posterior_mean_coef2 model_ema.decay model_ema.num_updates
and these are the ones you can skip if you retain the same VAE and text encoder
cond_stage_model.* first_stage_model.*
Thanks. Lots of Dreambooth repos train the text encoder now too, so I won't skip them.
I have some uncertainty about the type of certain layers: the layers that have emb_layers, proj_in, proj_out, transformer_blocks, skip_connection, self_attn, mid.attn.
Do you have any idea? Are they norm, conv, dense or neither?
Keep in mind that the different Attention layer types might be either all conv or dense, but they are not just sequentially chained. Rather they do qkv queries and batch matrix mult on results. I'm not sure if that has any impact on the results of the permutations.
- emb_layers is SiLU + linear -> dense
- proj_in is Conv2d
- proj_out is Conv2d
- transformer_blocks is BasicTransformerBlock -> CrossAttention, LayerNorm, FeedForward
  - CrossAttention: all linear layers -> dense
  - FeedForward: GEGLU + linear -> dense
  - GEGLU: linear + gelu -> dense
- skip_connection is either Identity or Conv2d (Identity has no weights, so going by key, it's always Conv2d)
- self_attn: CLIPAttention -> dense
- mid.attn: AttnBlock -> Conv2d
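(To turn that mapping into spec entries, here is a hedged sketch of conv/dense/norm shorthands in the spirit of the helper lambdas in the original weight_matching.py, adapted to PyTorch weight layouts; the function names are illustrative, not the port's API.)

```python
# Sketch: PyTorch weight layouts are Linear [out, in], Conv2d [out, in, kH, kW],
# and GroupNorm/LayerNorm weight/bias are 1-D over channels. p_out / p_in are
# permutation names, or None for axes that must stay fixed.
def dense(name, p_out, p_in, bias=True):
    entry = {f"{name}.weight": (p_out, p_in)}
    if bias:
        entry[f"{name}.bias"] = (p_out,)
    return entry

def conv(name, p_out, p_in, bias=True):
    entry = {f"{name}.weight": (p_out, p_in, None, None)}  # kernel axes never permuted
    if bias:
        entry[f"{name}.bias"] = (p_out,)
    return entry

def norm(name, p):
    return {f"{name}.weight": (p,), f"{name}.bias": (p,)}
```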
Keep in mind that the different Attention layer types might be either all conv or dense, but they are not just sequentially chained. Rather they do qkv queries and batch matrix mult on results. I'm not sure if that has any impact on the results of the permutations.
Thanks for the response. It's helped a lot. Yes, I think I'm going to skip the attention layers, at least for the first go-around. As @samuela suggested, labeling the axes as None for betas worked for skipping it, so I just have to do them all now.
Forgot to ask, what are the time_embed (I'm assuming dense now), model.out.0 and out.2 (norm and conv I think) and .op (I think conv)?
unet time_embed: linear, silu, linear, so it's dense
unet out: GroupNorm, SiLU, Conv2d
op (which I guess you mean op of the unet downsampling block) is Conv2d
It's all readily available in the implementation, so I suggest you read it yourself to get a better understanding.
unet: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/openaimodel.py
unet attention: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/attention.py
vae: https://github.com/CompVis/stable-diffusion/blob/main/ldm/models/autoencoder.py
Just an update. At first I had trouble building it up. Some layers would work, most wouldn't. Most irritatingly, it felt inconsistent as to what would work without error and what wouldn't. Today I figured out the issue was the axis and the torch size of the layers. Not all layers can or should be connected by axis, and the torch size helps tell which ones can/should be. Anyway, I can get pretty much every layer to permute now. So I'll finish that and finally test this.
@ogkalu2 So the idea is to go pytorch -> TF -> rebasin -> pytorch? This sounds huge btw, thanks for doing it.
Also, you probably saw this, but there's a PyTorch version; I think you need to come up with a PermutationSpec yourself though: https://github.com/themrzmaster/git-re-basin-pytorch
@ogkalu2 So the idea is to go pytorch -> TF -> rebasin -> pytorch? This sounds huge btw, thanks for doing it.
I'm pretty much just using the PyTorch implementation now. The one you linked, I already knew about it. I ended up using JAX for flattening and unflattening the parameters, but that's about it.
No problem, it's my pleasure. Done with the unet; working on the text encoder. There's no doubt it'll run now, the question is just whether it merges as hoped. Fingers crossed for that.
@ogkalu2 Do you have a repository for this where I could take a look?
@lopho No. I wanted to finish things and see the results of a merged model before I uploaded anything to a repo.
@ogkalu2 Do you have a repository for this where I could take a look?
Although I did upload my first attempt here. A few things have changed since then to make it work, mostly the axes: https://imgur.com/a/DMnb7P8
But I added a bias option for the conv and added the dense emb layers in the easyblock.
Hi @samuela, would this run much faster on a GPU?
Yes, it should run quite a bit faster on a GPU since that will speed up the matrix multiplies, but the linear assignment problem solve still happens on the CPU, so I don't think the speedup you'd get would be anything too crazy... I've never tried running on CPU only.
Yes, it should run quite a bit faster on a GPU since that will speed up the matrix multiplies, but the linear assignment problem solve still happens on the CPU, so I don't think the speedup you'd get would be anything too crazy... I've never tried running on CPU only.
Ah I see. If I don't specify the device as CPU, I get
File "/notebooks/weight_matching.py", line 798, in weight_matching A += w_a @ w_b.T RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument mat2 in method wrapper_mm)
@lopho @affableroots @samuela
I'm done with the spec. I've not tested a merge yet, but that will take time, so I've uploaded it here for anyone to test as well.
https://github.com/ogkalu2/Merge-Stable-Diffusion-models-without-distortion
Hi @samuela, how many iterations does the weight matching typically run? I know max iterations is 100 but it doesn't usually go that high, I don't think?
It totally depends on the model and initialization. I've seen it take as few as 3 and as many as 50. It is guaranteed to terminate though, so don't worry it can't run forever!
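(For intuition, a rough sketch of why it must terminate, with hypothetical names: each sweep re-solves the permutations in random order and only counts strict improvements of a bounded objective, so the loop stops as soon as a full sweep yields no gain. The per-permutation values it prints, analogous to the log lines further down, are those gains, which is why most of them read 0.0.)

```python
import random

def weight_matching_loop(perm_names, resolve_perm, max_iter=100):
    """resolve_perm(name) -> objective gain from re-solving that permutation via
    linear assignment (hypothetical callback standing in for the LAP step)."""
    for it in range(max_iter):
        progress = False
        for name in random.sample(perm_names, len(perm_names)):  # random sweep order
            gain = resolve_perm(name)
            print(f"{it}/{name}: {gain}")       # roughly the "newL - oldL" style log lines
            progress = progress or gain > 1e-12
        if not progress:                        # a whole sweep with no improvement: done
            break
```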
I keep OOMing on 32GB RAM, any tips on what I can delete when, or maybe running the merge in parts?
Oh wow. I know it can't run forever, but the 1st SD iteration took ~12 hours so I was curious. Ah well. What do the newL - oldL values indicate exactly? I see most of them are 0.0.
I keep OOMing on 32GB RAM, any tips on what I can delete when, or maybe running the merge in parts?
Really? Huh. I'm just running on vast right now. You can't really run it in parts right now. As for what to delete, it's possible to skip some layers, but I honestly don't know exactly what I can skip yet. The VAE layers would be the first thing I'd remove, but I don't know beyond that. I'll look into it.
The OOM errors seem odd though. Do you actually get those errors on your console/terminal, or does your system freeze up or something?
Watching htop, I watch it OOM, and also, I get a Killed at the following step. If it matters, I'm testing just merging 2 simple 4GB .ckpts built off of sd-v1-4
...
0/P_model.diffusion_model.output_blocks.6.0_inner2: 0.0
0/P_bg308: 0.0
0/P_model.diffusion_model.middle_block.2_inner4: 0.0
0/P_bg353: 0.0
0/P_bg166: 0.0
0/P_bg180: 0.0
0/P_bg65: 0.0
0/P_bg78: 0.0
0/P_first_stage_model.encoder.mid.block_1_inner: 0.0
0/P_bg214: 0.0
0/P_first_stage_model.decoder.up.3.block.1_inner: 0.0
0/P_bg313: 0.0
0/P_model.diffusion_model.output_blocks.0.0_inner: 0.0
0/P_bg98: 0.0
0/P_bg359: 0.0
0/P_bg141: 0.0
0/P_bg264: 0.0
0/P_bg163: 0.0
Killed
EDIT: skipping the VAE makes sense, that's a good idea.
Oh I see. The test I have running has 2 Dreambooth models pruned to 2 GB. The bigger the models, the higher the RAM usage. I didn't realize 4 GB models were too much for 32 GB RAM systems currently. The problem is the linear sum assignment; it can only run on the CPU.
@samuela I have someone here who's running through this in minutes lol: https://github.com/ogkalu2/Merge-Stable-Diffusion-models-without-distortion/issues/1#issuecomment-1319102078
Anyway, I have a new problem now. The perm spec runs fine and the parameters get updated fine, but the line I previously wrote to save the model won't work. After defining the state dict(s) as state_a = model_a["state_dict"], I tried to save the model with torch.save({ "state_dict": state_b(updated_params) }, output_file)
but get hit with
"state_dict": state_b(updated_params) TypeError: 'dict' object is not callable
Hi @samuela, something seems to be up with the get_permuted_params function. Applying the permutation with it just seems to be biased towards whatever the selected model's parameters are.
So a merge with apply_permutation(permutation_spec, final_permutation, model_a state dict) just produces a model that is basically model a, and a merge with apply_permutation(permutation_spec, final_permutation, model_b state dict) just produces a model that is basically model b.
Any idea what the issue might be?
So a merge with apply_permutation(permutation_spec, final_permutation, model_a state dict) just produces a model that is basically model a, and a merge with apply_permutation(permutation_spec, final_permutation, model_b state dict) just produces a model that is basically model b.
Hi @ogkalu2, how are you measuring the difference between the permuted model and the original? Have you inspected the final_permutations to see if they are close to identity? Depending on the kind of fine-tuning you're doing, it could be reasonably expected that the optimal permutation is already close to identity. This is a convenient property of fine-tuning: you generally don't leave the pre-training basin, esp. with large models and small learning rates, see e.g. https://arxiv.org/abs/2109.01903, https://twitter.com/moyix/status/1581390268368302080, and so forth.
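(One way to check this, assuming the weight-matching output maps each permutation name to a 1-D tensor of indices: measure the fraction of indices that map to themselves. Values near 1.0 mean that permutation is essentially identity.)

```python
import torch

def identity_fraction(final_permutation):
    # fraction of indices each permutation leaves in place; ~1.0 means "already aligned"
    return {
        name: (perm.cpu() == torch.arange(len(perm))).float().mean().item()
        for name, perm in final_permutation.items()
    }
```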
Hi @ogkalu2! The exact framework that the model runs in is not super important. For example, we have used our JAX code to align the weights of two PyTorch models in the past. The only important part is that you can load the weights into Python/JAX and that you have a correct PermutationSpec for your model.
In general, writing down the PermutationSpec will be the more challenging step. We are currently working on a pure-PyTorch version, including a "tracer" that can automatically generate the PermutationSpec for you, but this is not ready for release just yet.
Would be very cool to see if this works on Stable Diffusion models! Do let me know if you get it working!
Hello, I wonder if there is an update on the automatic tracer in PyTorch for the permutation spec? Thanks!
Hi Samuel. Thank you for being willing to look at this. Basically, I'm trying to see if it is possible to merge Stable Diffusion models I've fine-tuned with Dreambooth using your method.
The first hurdle, of course, is that your implementation is not yet compatible with PyTorch as far as I know. But the PyTorch weights can be successfully converted to TensorFlow weights; this has been done before as well. I don't mind doing this if I have to.
The second hurdle will be successfully using your implementation on the TF models. I'm not sure how feasible this all is or how I would use your code on the SD TensorFlow weights.