ogkalu2 / Merge-Stable-Diffusion-models-without-distortion

Adaptation of the merging method described in the paper "Git Re-Basin: Merging Models modulo Permutation Symmetries" (https://arxiv.org/abs/2209.04836) for Stable Diffusion
MIT License
139 stars · 21 forks

KeyError: 'cond_stage_model.transformer.text_model.embeddings.position_ids' #28

Status: Open · R-N opened this issue 1 year ago

R-N commented 1 year ago

Error:

/content/merger
Using half precision

    ---------------------
         ITERATION 1
    ---------------------

new alpha = 0.045

Traceback (most recent call last):
  File "SD_rebasin_merge.py", line 55, in <module>
    theta_0 = {key: (1 - (new_alpha)) * theta_0[key] + (new_alpha) * value for key, value in theta_1.items() if "model" in key and key in theta_1}
  File "SD_rebasin_merge.py", line 55, in <dictcomp>
    theta_0 = {key: (1 - (new_alpha)) * theta_0[key] + (new_alpha) * value for key, value in theta_1.items() if "model" in key and key in theta_1}
KeyError: 'cond_stage_model.transformer.text_model.embeddings.position_ids'

Code:

!pip install pytorch-lightning torch==1.11.0+cu113 torchvision==0.12.0+cu113 --extra-index-url https://download.pytorch.org/whl/cu113

!git clone https://github.com/ogkalu2/Merge-Stable-Diffusion-models-without-distortion merger
%cd /content/merger

#download models from hf
!curl ...

model_a_path = "novelai.ckpt"
model_b_path = "sd_1.5.ckpt"
output_name = "NAI_f222_0.45ws.ckpt"
alpha = 0.45
device = "cuda"

%cd /content/merger
!python SD_rebasin_merge.py --model_a {model_a_path} --model_b {model_b_path} --output {output_name} --alpha {alpha} --device {device}

RAM usage: (screenshot attached)

R-N commented 1 year ago

I've tried adding the key to the skipped layers as mentioned here. That's in weight_matching.py, right?

  return permutation_spec_from_axes_to_perm({
     #Skipped Layers
     **skip("betas", None, None),
     **skip("alphas_cumprod", None, None),
     **skip("alphas_cumprod_prev", None, None),
     **skip("sqrt_alphas_cumprod", None, None),
     **skip("sqrt_one_minus_alphas_cumprod", None, None),
     **skip("log_one_minus_alphas_cumprods", None, None),
     **skip("sqrt_recip_alphas_cumprod", None, None),
     **skip("sqrt_recipm1_alphas_cumprod", None, None),
     **skip("posterior_variance", None, None),
     **skip("posterior_log_variance_clipped", None, None),
     **skip("posterior_mean_coef1", None, None),
     **skip("posterior_mean_coef2", None, None),
     **skip("log_one_minus_alphas_cumprod", None, None),
     **skip("model_ema.decay", None, None),
     **skip("model_ema.num_updates", None, None),
     **skip("cond_stage_model.transformer.text_model.embeddings.position_ids", None, None),

     #initial 

But I'm still getting the same error:

/content/merger
Using half precision

    ---------------------
         ITERATION 1
    ---------------------

new alpha = 0.045

Traceback (most recent call last):
  File "SD_rebasin_merge.py", line 55, in <module>
    theta_0 = {key: (1 - (new_alpha)) * theta_0[key] + (new_alpha) * value for key, value in theta_1.items() if "model" in key and key in theta_1}
  File "SD_rebasin_merge.py", line 55, in <dictcomp>
    theta_0 = {key: (1 - (new_alpha)) * theta_0[key] + (new_alpha) * value for key, value in theta_1.items() if "model" in key and key in theta_1}
KeyError: 'cond_stage_model.transformer.text_model.embeddings.position_ids'
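(Editor's note: the skip list in weight_matching.py only affects the permutation step; the traceback comes from the interpolation on line 55 of SD_rebasin_merge.py, whose guard checks `key in theta_1`, which is always true while iterating `theta_1.items()`, yet it indexes `theta_0[key]`. A minimal sketch with hypothetical toy dicts, showing that guarding on `theta_0` instead avoids the crash:)

```python
# Toy stand-ins (plain floats instead of tensors) for the two state dicts.
# theta_0 is missing the embeddings key; theta_1 has it.
theta_0 = {"model.w": 1.0}
theta_1 = {"model.w": 3.0,
           "cond_stage_model.transformer.text_model.embeddings.position_ids": 0.0}
new_alpha = 0.045

# The original comprehension: guard `key in theta_1` never filters anything,
# so theta_0[key] raises KeyError for the key theta_0 lacks.
try:
    merged = {k: (1 - new_alpha) * theta_0[k] + new_alpha * v
              for k, v in theta_1.items() if "model" in k and k in theta_1}
except KeyError as e:
    print("KeyError:", e)

# Guarding on theta_0 instead drops keys absent from theta_0 and avoids the crash.
merged = {k: (1 - new_alpha) * theta_0[k] + new_alpha * v
          for k, v in theta_1.items() if "model" in k and k in theta_0}
print(merged)
```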
R-N commented 1 year ago

I was worried my changes hadn't been applied, so I added a print statement, and it shows up in the output. So the change did apply.

  print("HELLLOOOO")
  return permutation_spec_from_axes_to_perm({
     #Skipped Layers 
     **skip("betas", None, None),
     **skip("alphas_cumprod", None, None),
     **skip("alphas_cumprod_prev", None, None),
     **skip("sqrt_alphas_cumprod", None, None),
     **skip("sqrt_one_minus_alphas_cumprod", None, None),
     **skip("log_one_minus_alphas_cumprods", None, None),
     **skip("sqrt_recip_alphas_cumprod", None, None),
     **skip("sqrt_recipm1_alphas_cumprod", None, None),
     **skip("posterior_variance", None, None),
     **skip("posterior_log_variance_clipped", None, None),
     **skip("posterior_mean_coef1", None, None),
     **skip("posterior_mean_coef2", None, None),
     **skip("log_one_minus_alphas_cumprod", None, None),
     **skip("model_ema.decay", None, None),
     **skip("model_ema.num_updates", None, None),
     **skip("cond_stage_model.transformer.text_model.embeddings.position_ids", None, None),

     #initial 

Error is still the same:

/content/merger
HELLLOOOO
Using half precision

    ---------------------
         ITERATION 1
    ---------------------

new alpha = 0.045

Traceback (most recent call last):
  File "SD_rebasin_merge.py", line 55, in <module>
    theta_0 = {key: (1 - (new_alpha)) * theta_0[key] + (new_alpha) * value for key, value in theta_1.items() if "model" in key and key in theta_1}
  File "SD_rebasin_merge.py", line 55, in <dictcomp>
    theta_0 = {key: (1 - (new_alpha)) * theta_0[key] + (new_alpha) * value for key, value in theta_1.items() if "model" in key and key in theta_1}
KeyError: 'cond_stage_model.transformer.text_model.embeddings.position_ids'
R-N commented 1 year ago

Weird. I checked your original weight_matching.py and that key is already skipped there.

zwishenzug commented 1 year ago

One of your models might not have the CLIP model embedded. Some versions of protogen suffer from this; other models may too.

In the main script near the start, where it loads the models:

model_a = torch.load(args.model_a, map_location=device)
model_b = torch.load(args.model_b, map_location=device)

You could try some code like this right after those two lines (untested) to copy across the CLIP model from the other model if one of them is missing it:

for key in model_a['state_dict'].keys():
    if 'cond_stage_model.' in key and key not in model_b['state_dict']:
        model_b['state_dict'][key] = model_a['state_dict'][key].clone().detach()

for key in model_b['state_dict'].keys():
    if 'cond_stage_model.' in key and key not in model_a['state_dict']:
        model_a['state_dict'][key] = model_b['state_dict'][key].clone().detach()
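(Editor's note: a quick toy check of the copy logic above, using hypothetical minimal checkpoint dicts with plain floats standing in for tensors, so `.clone().detach()` is dropped:)

```python
# Hypothetical minimal checkpoints: model_b lacks the CLIP (cond_stage_model.*) keys.
model_a = {"state_dict": {
    "model.diffusion.w": 1.0,
    "cond_stage_model.transformer.text_model.embeddings.position_ids": 7.0,
}}
model_b = {"state_dict": {"model.diffusion.w": 2.0}}

# Same copy-across logic as the suggestion above, for the a -> b direction.
for key in list(model_a["state_dict"]):
    if "cond_stage_model." in key and key not in model_b["state_dict"]:
        model_b["state_dict"][key] = model_a["state_dict"][key]

print(model_b["state_dict"])
```

After the loop, both state dicts share the same CLIP keys, so the interpolation in SD_rebasin_merge.py no longer hits a missing key.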
ogkalu2 commented 1 year ago

@R-N Did @zwishenzug's suggestion help?