ogkalu2 / Merge-Stable-Diffusion-models-without-distortion

Adaptation of the merging method described in the paper - Git Re-Basin: Merging Models modulo Permutation Symmetries (https://arxiv.org/abs/2209.04836) for Stable Diffusion
MIT License
139 stars, 21 forks

KeyError #36

Open robot-never-die opened 1 year ago

robot-never-die commented 1 year ago

On line 55:

theta_0 = {key: (1 - (new_alpha)) * theta_0[key] + (new_alpha) * value for key, value in theta_1.items() if "model" in key and key in theta_1}

and on line 59:

if "model" in key and key not in theta_0: theta_0[key] = theta_1[key]

if "model" is to generic what key tree need to modified?

On line 55 I added a skip list for myself to exclude model_ema: skips = ['model_ema.decay', 'model_ema.num_updates'], plus "and key not in skips" in the filter.
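
For illustration, a minimal sketch of line 55 with that skip list applied (the theta dicts and new_alpha below are toy stand-ins, not the ones from the merge script):

# toy stand-ins so the snippet runs on its own; in the script these come from
# the loaded checkpoints and the alpha argument
import torch
theta_0 = {'model.diffusion_model.w': torch.zeros(2), 'model_ema.decay': torch.tensor(0.9999)}
theta_1 = {'model.diffusion_model.w': torch.ones(2), 'model_ema.decay': torch.tensor(0.9999)}
new_alpha = 0.5

skips = ['model_ema.decay', 'model_ema.num_updates']
theta_0 = {key: (1 - new_alpha) * theta_0[key] + new_alpha * value
           for key, value in theta_1.items()
           if "model" in key and key in theta_1 and key not in skips}
# model_ema.decay is skipped; only model.diffusion_model.w gets interpolated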

robot-never-die commented 1 year ago

Just loaded the models in the Model Toolkit extension. One of the models is identified as SD-v2 and its CLIP is missing the final layer, so I'm just not going to use it. I still want to know whether line 55 needs to modify all keys with 'model' in them.

robot-never-die commented 1 year ago

It looks like model.diffusion_model, first_stage_model and cond_stage_model.transformer.text_model are the only prefixes modified in weight_matching.py:

mod_keys = ['model.diffusion_model',
            'first_stage_model',
            'cond_stage_model.transformer.text_model']

def have_key(k):
    # True if the key belongs to one of the sub-models that weight_matching touches
    return any(k.startswith(mk) for mk in mod_keys)

Replace if "model" with if have_key(key)

robot-never-die commented 1 year ago

I have another question, about the weight_matching function. The fp16 and fp32 paths are different: fp16 loops through special_layers while fp32 loops through torch.randperm. Why the difference?
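
For context on the torch.randperm part only: the original Git Re-Basin weight matching is a coordinate descent that, on each sweep, visits the permutation groups in a fresh random order and solves a linear assignment problem for each one, stopping once a full sweep brings no improvement. A toy sketch of that loop structure (not the repo's weight_matching, and ignoring the fp16/special_layers path):

import torch
from scipy.optimize import linear_sum_assignment

def toy_weight_matching(params_a, params_b, max_iter=10):
    # params_a / params_b: {perm_name: (n_units, features) float weight matrix}
    perms = {p: torch.arange(w.shape[0]) for p, w in params_b.items()}
    perm_names = list(perms.keys())
    for _ in range(max_iter):
        progress = False
        for p_ix in torch.randperm(len(perm_names)):   # random visiting order per sweep
            p = perm_names[p_ix]
            # similarity between A's units and B's (currently permuted) units
            cost = params_a[p] @ params_b[p].T
            _, col = linear_sum_assignment(cost.numpy(), maximize=True)
            new_perm = torch.as_tensor(col)
            if not torch.equal(new_perm, torch.arange(len(new_perm))):
                params_b[p] = params_b[p][new_perm]     # apply the improvement
                perms[p] = perms[p][new_perm]           # compose with what we had
                progress = True
        if not progress:
            break
    return perms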

robot-never-die commented 1 year ago

Don't the thetas need to be copies? Aren't they being weight-matched against the model's original weights?

weight_matching(permutation_spec, flatten_params(model_a), theta_0, usefp16=args.usefp16)

Here model_a and theta_0 point at the same underlying weights:

model_a = torch.load(args.model_a, map_location=device)
model_b = torch.load(args.model_b, map_location=device)
# these just alias the loaded state dicts
theta_0 = model_a["state_dict"]
theta_1 = model_b["state_dict"]
# deep copies would keep model_a / model_b untouched while theta_0 / theta_1 are merged
theta_0 = copy.deepcopy(model_a["state_dict"])
theta_1 = copy.deepcopy(model_b["state_dict"])