Panchovix / stable-diffusion-webui-reForge

GNU Affero General Public License v3.0

[Feature Request]: Implementing lora-ctl #36

Open altoiddealer opened 1 month ago

altoiddealer commented 1 month ago

Is there an existing issue for this?

What would your feature do?

lllyasviel/stable-diffusion-webui-forge/issues/253 lllyasviel/stable-diffusion-webui-forge/issues/68

This is likely not a simple feature addition...

There is a very cool extension called sd-webui-loractl, which you can read about in the linked issues from Forge main. It allows LoRA weights to be calculated on each step, with a prompt syntax to ramp the weights up or down.

It's quite amazing, and really the only feature I still use exclusively in A1111.

lllyasviel had marked this Issue with the High Priority flag before going unresponsive for Forge development.

The author of the loractl extension had analyzed the situation and provided key details on what was missing / what would be needed to make it work. They also seemed very responsive and open to answering any questions, maybe even pulling some weight to help get it working in Forge.

See the author's comments and an additional comment in this thread.

If you're looking for something ambitious to add to this project... here you go! :)

Proposed workflow

  1. Go to ....
  2. Press ....
  3. ...

Additional information

No response

Panchovix commented 1 month ago

Hi there, a lot of people have been asking for this, so it sounds very interesting.

But I'm not sure where to start. Sorry for the ping @cheald; not sure if you've heard about this repo before, but do you think the changes described in https://github.com/lllyasviel/stable-diffusion-webui-forge/issues/68#issuecomment-1945491969 would apply here as well? (Probably on the dev_upstream branch.)

Or what about a branch of the extension for Forge/reForge? I could help if needed.

cheald commented 1 month ago

Okay, so model weights are a combination of the base model + deltas computed from a LoRA (or several). There is some procedure which takes a LoRA, loads its state dict, figures out which keys map to which model keys, performs any matrix multiplications necessary to derive the full weight delta matrices, and then finally adds a fraction of those deltas (via the strength parameter) to the underlying weights, with the final sum being used for inference.
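
In code, that procedure boils down to something like this for a single linear layer (a minimal sketch with illustrative names and the usual alpha/rank LoRA scaling, not reForge's actual API):

import torch

def apply_lora_delta(base_weight, lora_up, lora_down, alpha, rank, strength):
    # Derive the full weight delta from the low-rank factors, scale it the
    # usual LoRA way (alpha / rank), then blend it into the base weight by
    # the user-facing strength multiplier.
    delta = lora_up @ lora_down
    return base_weight + strength * (alpha / rank) * delta

# Toy usage: a 320x320 layer with a rank-16 LoRA applied at strength 0.8
w = apply_lora_delta(torch.randn(320, 320), torch.randn(320, 16),
                     torch.randn(16, 320), alpha=16, rank=16, strength=0.8)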

The concept is that on each step, new weights are computed. This is...well, not high performance in the A1111 version, because it basically just invalidates the A1111 lora cache so that new weights are recomputed with each step. Last I looked, one of the things that Forge did was precompute the final weights before inference started, which made the A1111 loractl implementation a non-starter (since it worked via cache invalidation during the inference loop).

A more clever implementation could do something like:

  1. Load the lora and retain the unmodified weights
  2. For each step, accept a base weight, a lora weight, a previous step weight modifier (or None, if the first step), and a next step weight modifier
  3. Subtract (lora_weights * previous_weight) from the base weights (if previous_weight is not None), and then add (lora_weights * next_weight). You can skip this step if previous_weight == next_weight.
  4. Return the final weights as the weights to use for that layer.

def apply_step_lora(base_weights, lora_weights, previous_weight, next_weight):
    # treat previous_weight as 0.0 on the first step; callers can skip the call when previous_weight == next_weight
    return base_weights - (previous_weight * lora_weights) + (next_weight * lora_weights)

Once you have that, then all you need is some way to say "linearly interpolate from this weight to that weight over these steps", and it's trivial to compute the lora weight at any given step.
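
A minimal sketch of that interpolation (the keyframe format and names here are just illustrative assumptions):

def lora_weight_at_step(step, keyframes):
    # keyframes: sorted (step, weight) pairs, e.g. [(0, 0.0), (10, 1.0)].
    # Clamp outside the range, linearly interpolate inside it.
    if step <= keyframes[0][0]:
        return keyframes[0][1]
    for (s0, w0), (s1, w1) in zip(keyframes, keyframes[1:]):
        if s0 <= step <= s1:
            if s1 == s0:
                return w1
            t = (step - s0) / (s1 - s0)
            return w0 + t * (w1 - w0)
    return keyframes[-1][1]

# e.g. ramp from 0.0 to 1.0 over the first 10 steps, then hold
weights = [lora_weight_at_step(s, [(0, 0.0), (10, 1.0)]) for s in range(20)]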

This would run massively faster than the simple cache-bust option (which results in the base weights being copied back from cache and then loras reapplied on top of them), but would require that the lora weights be loaded and kept available during inference, and that there be a mechanism within the inference loop that would have the opportunity to modify the model weights used before each step. Additionally, text encoder conditionings would need to be recomputed if any text encoder weights are modified by this procedure.

I haven't looked into the lora implementation in this repo, but if I were building a loractl-like feature natively, something like the above is where I'd start.

Panchovix commented 1 month ago

Many thanks! Warning: long post ahead. LoRAs on reForge are loaded like this:

In ldm_patched/modules/sd.py, we define load_lora_for_models, which is:

def load_lora_for_models(model, clip, lora, strength_model, strength_clip, filename='default'):
    model_flag = type(model.model).__name__ if model is not None else 'default'
    unet_keys = ldm_patched.modules.lora.model_lora_keys_unet(model.model) if model is not None else {}
    clip_keys = ldm_patched.modules.lora.model_lora_keys_clip(clip.cond_stage_model) if clip is not None else {}
    lora_unmatch = lora
    lora_unet, lora_unmatch = ldm_patched.modules.lora.load_lora(lora_unmatch, unet_keys)
    lora_clip, lora_unmatch = ldm_patched.modules.lora.load_lora(lora_unmatch, clip_keys)

    if lora_unmatch:
        print(f'[LORA] Loading {filename} for {model_flag} with unmatched keys {list(lora_unmatch.keys())}')

    new_model = model.clone() if model is not None else None
    new_clip = clip.clone() if clip is not None else None

    if new_model is not None and lora_unet:
        loaded_keys = new_model.add_patches(lora_unet, strength_model)
        skipped_keys = [item for item in lora_unet if item not in loaded_keys]
        print(f'[LORA] Loaded {filename} for {model_flag}-UNet with {len(loaded_keys)} keys at weight {strength_model} (skipped {len(skipped_keys)} keys)')
        model = new_model

    if new_clip is not None and lora_clip:
        loaded_keys = new_clip.add_patches(lora_clip, strength_clip)
        skipped_keys = [item for item in lora_clip if item not in loaded_keys]
        print(f'[LORA] Loaded {filename} for {model_flag}-CLIP with {len(loaded_keys)} keys at weight {strength_clip} (skipped {len(skipped_keys)} keys)')
        clip = new_clip

    return model, clip

Then, in modules/networks.py, we apply the LoRA logic:

def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
    global lora_state_dict_cache

    current_sd = sd_models.model_data.get_sd_model()
    if current_sd is None:
        return

    loaded_networks.clear()

    networks_on_disk = [available_networks.get(name, None) if name.lower() in forbidden_network_aliases else available_network_aliases.get(name, None) for name in names]
    if any(x is None for x in networks_on_disk):
        list_available_networks()
        networks_on_disk = [available_networks.get(name, None) if name.lower() in forbidden_network_aliases else available_network_aliases.get(name, None) for name in names]

    failed_to_load_networks = []

    for i, (network_on_disk, name) in enumerate(zip(networks_on_disk, names)):
        try:
            net = load_network(name, network_on_disk)
        except Exception as e:
            failed_to_load_networks.append(name)
            logging.info(f"Couldn't find network with name {name}")
            if network_on_disk is not None:
                errors.display(e, f"loading network {network_on_disk.filename}")
            continue
        net.mentioned_name = name
        network_on_disk.read_hash()
        loaded_networks.append(net)

    if failed_to_load_networks:
        sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks))

        lora_not_found_message = f'Lora not found: {", ".join(failed_to_load_networks)}'
        if shared.opts.lora_not_found_warning_console:
            print(f'\n{lora_not_found_message}\n')
        if shared.opts.lora_not_found_gradio_warning:
            gr.Warning(lora_not_found_message)

    compiled_lora_targets = []
    for a, b, c in zip(networks_on_disk, unet_multipliers, te_multipliers):
        if a is not None:
            compiled_lora_targets.append([a.filename, b, c])

    compiled_lora_targets_hash = str(compiled_lora_targets)

    if current_sd.current_lora_hash == compiled_lora_targets_hash:
        return

    current_sd.current_lora_hash = compiled_lora_targets_hash
    current_sd.forge_objects.unet = current_sd.forge_objects_original.unet
    current_sd.forge_objects.clip = current_sd.forge_objects_original.clip

    for filename, strength_model, strength_clip in compiled_lora_targets:
        lora_sd = load_lora_state_dict(filename)
        current_sd.forge_objects.unet, current_sd.forge_objects.clip = load_lora_for_models(
            current_sd.forge_objects.unet, current_sd.forge_objects.clip, lora_sd, strength_model, strength_clip,
            filename=filename)

    current_sd.forge_objects_after_applying_lora = current_sd.forge_objects.shallow_copy()
    return

The functions that sd.py uses are model_lora_keys_unet, model_lora_keys_clip, and load_lora.

model_lora_keys_unet in ldm_patched.modules.lora is:

def model_lora_keys_unet(model, key_map={}):
    sd = model.state_dict()
    sdk = sd.keys()

    for k in sdk:
        if k.startswith("diffusion_model.") and k.endswith(".weight"):
            key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
            key_map["lora_unet_{}".format(key_lora)] = k
            key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config

    diffusers_keys = ldm_patched.modules.utils.unet_to_diffusers(model.model_config.unet_config)
    for k in diffusers_keys:
        if k.endswith(".weight"):
            unet_key = "diffusion_model.{}".format(diffusers_keys[k])
            key_lora = k[:-len(".weight")].replace(".", "_")
            key_map["lora_unet_{}".format(key_lora)] = unet_key

            diffusers_lora_prefix = ["", "unet."]
            for p in diffusers_lora_prefix:
                diffusers_lora_key = "{}{}".format(p, k[:-len(".weight")].replace(".to_", ".processor.to_"))
                if diffusers_lora_key.endswith(".to_out.0"):
                    diffusers_lora_key = diffusers_lora_key[:-2]
                key_map[diffusers_lora_key] = unet_key

    if isinstance(model, ldm_patched.modules.model_base.SD3): #Diffusers lora SD3
        diffusers_keys = ldm_patched.modules.utils.mmdit_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
        for k in diffusers_keys:
            if k.endswith(".weight"):
                to = diffusers_keys[k]
                key_lora = "transformer.{}".format(k[:-len(".weight")]) #regular diffusers sd3 lora format
                key_map[key_lora] = to

                key_lora = "base_model.model.{}".format(k[:-len(".weight")]) #format for flash-sd3 lora and others?
                key_map[key_lora] = to

                key_lora = "lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_")) #OneTrainer lora
                key_map[key_lora] = to

    return key_map

model_lora_keys_clip in the same file is:

def model_lora_keys_clip(model, key_map={}):
    sdk = model.state_dict().keys()

    text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
    clip_l_present = False
    for b in range(32): #TODO: clean up
        for c in LORA_CLIP_MAP:
            k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
            if k in sdk:
                lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
                key_map[lora_key] = k
                lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])
                key_map[lora_key] = k
                lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora
                key_map[lora_key] = k

            k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
            if k in sdk:
                lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
                key_map[lora_key] = k
                lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
                key_map[lora_key] = k
                clip_l_present = True
                lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora
                key_map[lora_key] = k

            k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
            if k in sdk:
                if clip_l_present:
                    lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
                    key_map[lora_key] = k
                    lora_key = "text_encoder_2.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora
                    key_map[lora_key] = k
                else:
                    lora_key = "lora_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #TODO: test if this is correct for SDXL-Refiner
                    key_map[lora_key] = k
                    lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora
                    key_map[lora_key] = k
                    lora_key = "lora_prior_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #cascade lora: TODO put lora key prefix in the model config
                    key_map[lora_key] = k

    for k in sdk: #OneTrainer SD3 lora
        if k.startswith("t5xxl.transformer.") and k.endswith(".weight"):
            l_key = k[len("t5xxl.transformer."):-len(".weight")]
            lora_key = "lora_te3_{}".format(l_key.replace(".", "_"))
            key_map[lora_key] = k

    k = "clip_g.transformer.text_projection.weight"
    if k in sdk:
        key_map["lora_prior_te_text_projection"] = k #cascade lora?
        # key_map["text_encoder.text_projection"] = k #TODO: check if other lora have the text_projection too
        key_map["lora_te2_text_projection"] = k #OneTrainer SD3 lora

    k = "clip_l.transformer.text_projection.weight"
    if k in sdk:
        key_map["lora_te1_text_projection"] = k #OneTrainer SD3 lora, not necessary but omits warning

    return key_map

load_lora in the same file is:

def load_lora(lora, to_load):
    patch_dict = {}
    loaded_keys = set()
    for x in to_load:
        alpha_name = "{}.alpha".format(x)
        alpha = None
        if alpha_name in lora.keys():
            alpha = lora[alpha_name].item()
            loaded_keys.add(alpha_name)

        dora_scale_name = "{}.dora_scale".format(x)
        dora_scale = None
        if dora_scale_name in lora.keys():
            dora_scale = lora[dora_scale_name]
            loaded_keys.add(dora_scale_name)

        regular_lora = "{}.lora_up.weight".format(x)
        diffusers_lora = "{}_lora.up.weight".format(x)
        diffusers2_lora = "{}.lora_B.weight".format(x)
        diffusers3_lora = "{}.lora.up.weight".format(x)
        transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
        A_name = None

        if regular_lora in lora.keys():
            A_name = regular_lora
            B_name = "{}.lora_down.weight".format(x)
            mid_name = "{}.lora_mid.weight".format(x)
        elif diffusers_lora in lora.keys():
            A_name = diffusers_lora
            B_name = "{}_lora.down.weight".format(x)
            mid_name = None
        elif diffusers2_lora in lora.keys():
            A_name = diffusers2_lora
            B_name = "{}.lora_A.weight".format(x)
            mid_name = None
        elif diffusers3_lora in lora.keys():
            A_name = diffusers3_lora
            B_name = "{}.lora.down.weight".format(x)
            mid_name = None
        elif transformers_lora in lora.keys():
            A_name = transformers_lora
            B_name ="{}.lora_linear_layer.down.weight".format(x)
            mid_name = None

        if A_name is not None:
            mid = None
            if mid_name is not None and mid_name in lora.keys():
                mid = lora[mid_name]
                loaded_keys.add(mid_name)
            patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale))
            loaded_keys.add(A_name)
            loaded_keys.add(B_name)

        ######## loha
        hada_w1_a_name = "{}.hada_w1_a".format(x)
        hada_w1_b_name = "{}.hada_w1_b".format(x)
        hada_w2_a_name = "{}.hada_w2_a".format(x)
        hada_w2_b_name = "{}.hada_w2_b".format(x)
        hada_t1_name = "{}.hada_t1".format(x)
        hada_t2_name = "{}.hada_t2".format(x)
        if hada_w1_a_name in lora.keys():
            hada_t1 = None
            hada_t2 = None
            if hada_t1_name in lora.keys():
                hada_t1 = lora[hada_t1_name]
                hada_t2 = lora[hada_t2_name]
                loaded_keys.add(hada_t1_name)
                loaded_keys.add(hada_t2_name)

            patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale))
            loaded_keys.add(hada_w1_a_name)
            loaded_keys.add(hada_w1_b_name)
            loaded_keys.add(hada_w2_a_name)
            loaded_keys.add(hada_w2_b_name)

        ######## lokr
        lokr_w1_name = "{}.lokr_w1".format(x)
        lokr_w2_name = "{}.lokr_w2".format(x)
        lokr_w1_a_name = "{}.lokr_w1_a".format(x)
        lokr_w1_b_name = "{}.lokr_w1_b".format(x)
        lokr_t2_name = "{}.lokr_t2".format(x)
        lokr_w2_a_name = "{}.lokr_w2_a".format(x)
        lokr_w2_b_name = "{}.lokr_w2_b".format(x)

        lokr_w1 = None
        if lokr_w1_name in lora.keys():
            lokr_w1 = lora[lokr_w1_name]
            loaded_keys.add(lokr_w1_name)

        lokr_w2 = None
        if lokr_w2_name in lora.keys():
            lokr_w2 = lora[lokr_w2_name]
            loaded_keys.add(lokr_w2_name)

        lokr_w1_a = None
        if lokr_w1_a_name in lora.keys():
            lokr_w1_a = lora[lokr_w1_a_name]
            loaded_keys.add(lokr_w1_a_name)

        lokr_w1_b = None
        if lokr_w1_b_name in lora.keys():
            lokr_w1_b = lora[lokr_w1_b_name]
            loaded_keys.add(lokr_w1_b_name)

        lokr_w2_a = None
        if lokr_w2_a_name in lora.keys():
            lokr_w2_a = lora[lokr_w2_a_name]
            loaded_keys.add(lokr_w2_a_name)

        lokr_w2_b = None
        if lokr_w2_b_name in lora.keys():
            lokr_w2_b = lora[lokr_w2_b_name]
            loaded_keys.add(lokr_w2_b_name)

        lokr_t2 = None
        if lokr_t2_name in lora.keys():
            lokr_t2 = lora[lokr_t2_name]
            loaded_keys.add(lokr_t2_name)

        if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
            patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale))

        #glora
        a1_name = "{}.a1.weight".format(x)
        a2_name = "{}.a2.weight".format(x)
        b1_name = "{}.b1.weight".format(x)
        b2_name = "{}.b2.weight".format(x)
        if a1_name in lora:
            patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale))
            loaded_keys.add(a1_name)
            loaded_keys.add(a2_name)
            loaded_keys.add(b1_name)
            loaded_keys.add(b2_name)

        w_norm_name = "{}.w_norm".format(x)
        b_norm_name = "{}.b_norm".format(x)
        w_norm = lora.get(w_norm_name, None)
        b_norm = lora.get(b_norm_name, None)

        if w_norm is not None:
            loaded_keys.add(w_norm_name)
            patch_dict[to_load[x]] = ("diff", (w_norm,))
            if b_norm is not None:
                loaded_keys.add(b_norm_name)
                patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (b_norm,))

        diff_name = "{}.diff".format(x)
        diff_weight = lora.get(diff_name, None)
        if diff_weight is not None:
            patch_dict[to_load[x]] = ("diff", (diff_weight,))
            loaded_keys.add(diff_name)

        diff_bias_name = "{}.diff_b".format(x)
        diff_bias = lora.get(diff_bias_name, None)
        if diff_bias is not None:
            patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,))
            loaded_keys.add(diff_bias_name)

    remaining_dict = {x: y for x, y in lora.items() if x not in loaded_keys}
    return patch_dict, remaining_dict
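
For reference, a standard lora_up/lora_down pair ends up in patch_dict in the ("lora", (A, B, alpha, mid, dora_scale)) form used above. An illustrative example (made-up tensors, with a key name following the mapping in model_lora_keys_unet):

import torch

# A state dict containing "<key>.lora_up.weight", "<key>.lora_down.weight" and
# "<key>.alpha" yields one patch_dict entry, keyed by the mapped model weight:
up_tensor, down_tensor = torch.randn(320, 16), torch.randn(16, 320)
patch_dict = {
    "diffusion_model.input_blocks.1.1.proj_in.weight":
        ("lora", (up_tensor, down_tensor, 16.0, None, None)),  # (A, B, alpha, mid, dora_scale)
}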

In networks.py we have:

def load_lora_state_dict(filename):
    return load_torch_file(filename, safe_load=True)

load_torch_file is in ldm_patched.modules.utils:

def load_torch_file(ckpt, safe_load=False, device=None):
    if device is None:
        device = torch.device("cpu")
    if ckpt.lower().endswith(".safetensors"):
        sd = safetensors.torch.load_file(ckpt, device=device.type)
    else:
        if safe_load:
            if not 'weights_only' in torch.load.__code__.co_varnames:
                logging.warning("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
                safe_load = False
        if safe_load:
            pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
        else:
            pl_sd = torch.load(ckpt, map_location=device, pickle_module=ldm_patched.modules.checkpoint_pickle)
        if "global_step" in pl_sd:
            logging.debug(f"Global Step: {pl_sd['global_step']}")
        if "state_dict" in pl_sd:
            sd = pl_sd["state_dict"]
        else:
            sd = pl_sd
    return sd

That's the complete process of how we load LoRAs. This is for the dev_upstream branch, which has the upstream Comfy backend changes. The main branch is a bit different and more akin to original Forge (which is basically the Comfy backend from 7 months ago).

With all this information, do you think we can reach something?

cheald commented 1 month ago

Okay, so the process right now is:

  1. Determine the list of lora files on disk and the multipliers for each
  2. Iterate through each lora, load it from disk, and add a patch to the ModelPatcher for each update to be applied
  3. Return the final weights and then proceed with inference.

Instead, you'd need a way to remove patches from the ModelPatcher and replace them with reweighted copies on each step, then repatch the model and use that set of patched weights for the next inference step.

The model is patched during LoadedModel.model_load: https://github.com/Panchovix/stable-diffusion-webui-reForge/blob/cdb1d84da3c0027b0bb228907b6846a4a473f5ea/ldm_patched/modules/model_management.py#L328

You'll have to trace back where this is used, and add a mechanism to alter the patch weights and re-patch the model per step somewhere inside the inference loop before each unet/text encoder call.
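
For a first pass, one could reuse the reset pattern from load_networks above. A very rough sketch (lora_weight_at_step is the interpolation helper sketched earlier; the schedule format and the exact hook point inside the sampling loop are open assumptions):

def reapply_loras_for_step(current_sd, lora_sd, filename, schedule, step):
    # Slow-but-simple variant: roll back to the unpatched objects and re-apply
    # the LoRA at this step's strength. The faster approach only adds/subtracts
    # the strength delta in place, as described above. The rebuilt unet/clip
    # would still need to be re-patched/loaded before the next unet or text
    # encoder call (see the model_load path linked above).
    strength = lora_weight_at_step(step, schedule)
    current_sd.forge_objects.unet = current_sd.forge_objects_original.unet
    current_sd.forge_objects.clip = current_sd.forge_objects_original.clip
    current_sd.forge_objects.unet, current_sd.forge_objects.clip = load_lora_for_models(
        current_sd.forge_objects.unet, current_sd.forge_objects.clip,
        lora_sd, strength, strength, filename=filename)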

This is likely not a trivial process (which is why I haven't implemented loractl in forge/comfyui!), but I encourage you to give it a go!

sashasubbbb commented 1 month ago

Slightly off-topic, but there is a similar ComfyUI node that allows for dynamic LoRA weight adjustment: https://github.com/asagi4/comfyui-prompt-control Perhaps it would be easier to implement that? Just a suggestion.

Panchovix commented 1 month ago

Slightly off-topic, but there is a similar ComfyUI node that allows for dynamic LoRA weight adjustment: https://github.com/asagi4/comfyui-prompt-control Perhaps it would be easier to implement that? Just a suggestion.

Pretty interesting. Something like how it's implemented there should in theory work (the logic, I mean), but that extension does it for Comfy. So something like that, but with some kind of UI for reForge, would maybe work?
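
If the prompt-syntax route is taken, parsing a loractl-style schedule into keyframes for the interpolation sketch above could look roughly like this (the exact loractl grammar may differ; this parser and its rules are illustrative assumptions):

def parse_weight_schedule(spec, total_steps):
    # Parse e.g. "0.0@0,1.0@0.5" into sorted (step, weight) keyframes.
    # A bare number ("0.8") is treated as a constant weight, matching the
    # existing <lora:name:0.8> behaviour; positions <= 1.0 are read as a
    # fraction of total steps. This approximates loractl, not a reimplementation.
    keyframes = []
    for part in spec.split(","):
        if "@" in part:
            weight, pos = part.split("@")
            pos = float(pos)
            step = pos * total_steps if pos <= 1.0 else pos
            keyframes.append((int(round(step)), float(weight)))
        else:
            keyframes.append((0, float(part)))
    return sorted(keyframes)

# e.g. <lora:my_lora:0.0@0,1.0@0.5> with 20 steps -> [(0, 0.0), (10, 1.0)]
print(parse_weight_schedule("0.0@0,1.0@0.5", total_steps=20))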