kohya-ss / sd-scripts


Saving SDXL checkpoints in same format as original (diffusers) fails; finetune checkpoints for SDXL unreadable #981

Open orcinus opened 9 months ago

orcinus commented 9 months ago

Saving finetune checkpoints for SDXL with the "same as source model" format option enabled in the GUI (i.e. save as diffusers) fails with the following exception:


AssertionError: key _orig_mod.time_embed.0.weight not found in conversion map

Additionally, saving finetune checkpoints for SDXL in the safetensors format produces checkpoints that ComfyUI cannot read, either as a checkpoint or as a unet. They do load in Automatic1111, but produce garbled output.

In short - finetuning SDXL is currently broken for me: training itself works fine, but the produced .safetensors is unreadable/unusable, and saving as diffusers fails outright (my guess is the key remapping is broken).
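For context: the key in the assertion carries an `_orig_mod.` prefix, while the conversion map apparently expects plain UNet key names. A generic sketch of how that kind of lookup fails - this is not the actual sd-scripts code, and the single map entry is made up purely for illustration:

```python
# Generic illustration only - not the actual sd-scripts conversion code.
# A fixed map translates every expected UNet key; a key that arrives with an
# unexpected prefix (here "_orig_mod.") has no entry, so the assert fires.
conversion_map = {
    "time_embed.0.weight": "time_embedding.linear_1.weight",  # made-up single entry
}

def convert_key(key: str) -> str:
    assert key in conversion_map, f"key {key} not found in conversion map"
    return conversion_map[key]

convert_key("time_embed.0.weight")            # fine
convert_key("_orig_mod.time_embed.0.weight")  # AssertionError, like the one above
```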

Any ideas on what might be wrong?

orcinus commented 9 months ago

Attempting to load either a .safetensors or .ckpt in Comfy as a checkpoint produces this:

File "/Users/ante/Documents/dev/ComfyUI/comfy/sd.py", line 439, in load_checkpoint_guess_config
    model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.", unet_dtype)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ante/Documents/dev/ComfyUI/comfy/model_detection.py", line 157, in model_config_from_unet
    unet_config = detect_unet_config(state_dict, unet_key_prefix, dtype)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ante/Documents/dev/ComfyUI/comfy/model_detection.py", line 50, in detect_unet_config
    model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0]
                     ~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
KeyError: 'model.diffusion_model.input_blocks.0.0.weight'

Attempting to load a .ckpt or .safetensors in ComfyUI as a unet produces:

  File "/Users/ante/Documents/dev/ComfyUI/nodes.py", line 774, in load_unet
    model = comfy.sd.load_unet(unet_path)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ante/Documents/dev/ComfyUI/comfy/sd.py", line 513, in load_unet
    model = load_unet_state_dict(sd)
            ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ante/Documents/dev/ComfyUI/comfy/sd.py", line 490, in load_unet_state_dict
    model_config = model_detection.model_config_from_diffusers_unet(sd, unet_dtype)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ante/Documents/dev/ComfyUI/comfy/model_detection.py", line 305, in model_config_from_diffusers_unet
    unet_config = unet_config_from_diffusers_unet(state_dict, dtype)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/ante/Documents/dev/ComfyUI/comfy/model_detection.py", line 224, in unet_config_from_diffusers_unet
    match["model_channels"] = state_dict["conv_in.weight"].shape[0]
                              ~~~~~~~~~~^^^^^^^^^^^^^^^^^^
KeyError: 'conv_in.weight'

Original models work just fine in Comfy, and LoRAs work fine too. Anything finetuned, though, is unusable.
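For anyone trying to reproduce this, a quick way to see which keys the saved file actually contains, to compare against the "model.diffusion_model." / "conv_in." names ComfyUI is looking for (assumes the safetensors package; the file name is a placeholder):

```python
from safetensors import safe_open

# Dump the first few key names from the saved checkpoint; if they don't match
# the prefixes ComfyUI expects, the KeyErrors above are what you get.
with safe_open("sdxl_finetune.safetensors", framework="pt") as f:  # placeholder path
    for key in sorted(f.keys())[:20]:
        print(key)
```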

orcinus commented 9 months ago

After further testing, it seems finetune checkpoints are completely broken.

If I load the checkpoint with A1111, it loads, but produces only diffusion noise at inference.

If I load base SDXL first, then switch to the checkpoint, it loads and produces almost the same result as base (very slight differences), and the result does not change with training iterations (no matter which point in training the checkpoint is from, the result is always the same). I'm guessing that's all due to A1111's new diff / model patching on load.

If I load a non-SDXL model (e.g. 1.5) first, then switch to the checkpoint, I get NaN crashes (with or without --no-half-vae and --no-half).

orcinus commented 9 months ago

Tested with final model save too - same thing. So it's not something inherent to incremental checkpoint saves.

orcinus commented 9 months ago

Figured it out - it's caused by torch dynamo. Disabling dynamo creates usable, working checkpoints.

As to why/how... that's a bit harder to figure out...
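One likely piece of the puzzle: the `_orig_mod.` prefix from the original AssertionError is exactly what torch.compile / dynamo adds. Compiling wraps the module in an OptimizedModule whose state_dict keys all gain that prefix, so any key-conversion map written for the plain module no longer matches. Minimal sketch, assuming PyTorch 2.x:

```python
import torch

# torch.compile wraps the module in an OptimizedModule; the wrapper's
# state_dict keys all carry an "_orig_mod." prefix.
model = torch.nn.Linear(4, 4)
compiled = torch.compile(model)

print(list(model.state_dict().keys()))     # ['weight', 'bias']
print(list(compiled.state_dict().keys()))  # ['_orig_mod.weight', '_orig_mod.bias']
```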

Bottom line: it would probably be a good idea to put a note in the README advising people to disable dynamo when using these scripts.
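If someone is stuck with an already-saved broken checkpoint, stripping the prefix might recover it, but that is untested speculation on my part - re-saving with dynamo disabled is the safer route. Rough sketch (file names are placeholders; assumes the safetensors package and Python 3.9+ for removeprefix):

```python
from safetensors.torch import load_file, save_file

# Untested repair sketch: load the broken checkpoint, strip the "_orig_mod."
# prefix from every key, and write the result back out. Only helps if the
# prefix is the sole thing wrong with the saved keys.
sd = load_file("broken_sdxl_finetune.safetensors")   # placeholder path
fixed = {k.removeprefix("_orig_mod."): v for k, v in sd.items()}
save_file(fixed, "fixed_sdxl_finetune.safetensors")  # placeholder path
```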