d8ahazard / sd_dreambooth_extension


Fix sample generation during Lora Training #1336

Closed ignis-sec closed 12 months ago

ignis-sec commented 1 year ago

Disclaimer:

I'm not at all familiar with diffusers or with this codebase. There is probably a faster approach to this, but hey, it's only sampling and it works, so better than nothing.

Issue

Since late March, sample image generation during training has not worked when a LoRA network is being trained; it outputs random noise instead. More discussion is available in the issue tracker, in issue #1273.

The problem, in depth.

For sample generation, s_pipeline is saved to a temporary directory and a new DiffusionPipeline is instantiated from that checkpoint (I'm assuming it's just a quick way of releasing the references to the unet and text_encoder being trained).

However, the trained models are structurally different from a plain UNet2DConditionModel and CLIPTextModel because of the following LoRA injection code:

injectable_lora = get_target_module("injection", args.use_lora_extended)
target_module = get_target_module("module", args.use_lora_extended)

unet_lora_params, _ = injectable_lora(
    unet,
    r=args.lora_unet_rank,
    loras=lora_path,
    target_replace_module=target_module,
)
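
For context, the injection replaces the targeted Linear layers with a LoRA wrapper, so the resulting unet/text_encoder no longer share a state_dict layout with a stock UNet2DConditionModel or CLIPTextModel. A minimal illustration of the key mismatch, using a simplified stand-in for the real wrapper (not the extension's actual class):

```python
import torch.nn as nn

class LoraInjectedLinear(nn.Module):
    """Simplified stand-in for the wrapper installed by the LoRA injection."""
    def __init__(self, in_features, out_features, r=4):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.lora_down = nn.Linear(in_features, r, bias=False)
        self.lora_up = nn.Linear(r, out_features, bias=False)

    def forward(self, x):
        return self.linear(x) + self.lora_up(self.lora_down(x))

plain = nn.Linear(8, 8)
wrapped = LoraInjectedLinear(8, 8)
print(sorted(plain.state_dict()))    # ['bias', 'weight']
print(sorted(wrapped.state_dict()))  # ['linear.bias', 'linear.weight',
                                     #  'lora_down.weight', 'lora_up.weight']
```

A checkpoint saved from the wrapped modules therefore no longer contains the keys a vanilla model expects.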

When a DiffusionPipeline is created via the from_pretrained method with low_cpu_mem_usage=False and device_map=None, any required state-dict keys that are missing from the checkpoint are randomly initialized.

Because of this, inference will output random noise.
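
Here is a hedged sketch of the call that triggers the problem; the temporary checkpoint path and dtype are placeholders:

```python
# With low_cpu_mem_usage=False, diffusers first builds the model with randomly
# initialized weights and then loads the checkpoint on top, so any key the
# freshly built model expects but the saved checkpoint lacks stays random.
import torch
from diffusers import DiffusionPipeline

s_pipeline = DiffusionPipeline.from_pretrained(
    "/tmp/db_sample_checkpoint",   # placeholder for the temporary save directory
    low_cpu_mem_usage=False,
    device_map=None,
    torch_dtype=torch.float16,
)
# Sampling from this pipeline yields noise whenever the unet/text_encoder
# checkpoints were saved with LoRA-injected (mismatched) key names.
```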

The solution, in depth.

The solution in this case consists of the following steps (steps 1 and 2 are sketched in code below):

1. Before saving the existing DiffusionPipeline to disk, use the save_pipe function to save the LoRA weights to disk.
2. If a LoRA network is being trained, ignore the weight_dtype when constructing the DiffusionPipeline at first.^1
3. If a LoRA network is being trained:
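
A rough sketch of how steps 1 and 2 could fit into the sampling path; build_sample_pipeline, save_lora_weights, tmp_dir, train_lora and weight_dtype are placeholders rather than the extension's actual names, with save_lora_weights standing in for the save_pipe call:

```python
from diffusers import DiffusionPipeline

def build_sample_pipeline(s_pipeline, train_lora, tmp_dir, weight_dtype, save_lora_weights):
    if train_lora:
        # Step 1: persist the LoRA weights before the pipeline itself is saved,
        # so they can be re-applied to the reloaded pipeline afterwards.
        save_lora_weights(s_pipeline)

    # Existing behaviour: dump the pipeline to a temporary directory.
    s_pipeline.save_pretrained(tmp_dir)

    if train_lora:
        # Step 2: do not force weight_dtype at construction time while a LoRA
        # network is being trained.
        return DiffusionPipeline.from_pretrained(tmp_dir)

    return DiffusionPipeline.from_pretrained(tmp_dir, torch_dtype=weight_dtype)
```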

Other fixes included, and hardest "Where is Waldo" game ever

I've also taken the liberty of tracking down an extra comma that had broken the save_pipe function (which was not being used before; now it is). The extra comma was converting the lora txt filename into a tuple, which threw an exception when it was passed to torch.save/safetensors.torch.save. Additionally, I've modified the lora text encoder file name lookup in patch_pipe to use _text_lora_path(unet_path) instead of _text_lora_path_ui(unet_path). I couldn't see that symbol used anywhere else either, so I figured it would be nice to keep it on the same naming convention as save_pipe in this case. Let me know if I missed something and broke something else.
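
For illustration, a minimal reproduction of the trailing-comma bug (file and variable names here are made up):

```python
import torch

model_name = "my_lora"
lora_txt = model_name + "_txt.pt",          # <- stray trailing comma makes this a tuple

print(type(lora_txt))                       # <class 'tuple'>, not <class 'str'>

try:
    torch.save({"w": torch.zeros(1)}, lora_txt)
except Exception as err:
    # Fails because torch.save (and the safetensors save call) expects a
    # string/path or file-like object, not a tuple.
    print("save failed:", err)
```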

Checklist before requesting a review

ignis-sec commented 1 year ago

Actually, I just checked, and patch_pipe is used by the image builder, which handles sample generation outside of training. However, it checks maybe_unet_path.endswith(".text_encoder.pt"), which is the _text_lora_path naming, while it was using _text_lora_path_ui, which would be maybe_unet_path.endswith("_txt.pt").
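
For reference, the two naming conventions at play (the suffixes are taken from the endswith checks above; the helpers below are simplified stand-ins, not the extension's actual implementations):

```python
def _text_lora_path(path: str) -> str:
    # save_pipe/patch_pipe convention: "my_lora.pt" -> "my_lora.text_encoder.pt"
    return path[: -len(".pt")] + ".text_encoder.pt"

def _text_lora_path_ui(path: str) -> str:
    # UI convention: "my_lora.pt" -> "my_lora_txt.pt"
    return path[: -len(".pt")] + "_txt.pt"

maybe_unet_path = _text_lora_path_ui("my_lora.pt")   # "my_lora_txt.pt"
# The check looks for the other convention, so a name built with the
# _ui helper never matches and the text encoder LoRA weights get skipped:
print(maybe_unet_path.endswith(".text_encoder.pt"))  # False
```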

I'm a bit confused about whether I should revert that and instead modify save_pipe to use _text_lora_path_ui (so it can still work with patch_pipe/merge_loras_to_pipe).

In one of these cases, the text encoder weights from the LoRA will not be added to the diffusion network when generating samples with ImageBuilder, and I'm not exactly sure whether this fixes that as well or breaks it. Let me know and I'll modify.

ignis-sec commented 11 months ago

Bumping as the issue is not yet fixed