tianweiy / DMD2


Convert 4-step unet into a checkpoint file #20

Open coastbluet opened 3 weeks ago

coastbluet commented 3 weeks ago

Sorry if this does not make sense; I am writing this with very little understanding of Stable Diffusion. Would it be possible for you to compile your UNet for SDXL base into a checkpoint file (ckpt or safetensors)? As I understand it, this would allow me to use it within the Krita Diffusion plugin.

I've tested the 4-step UNet ComfyUI workflow and got great performance and results. Many thanks.

tianweiy commented 3 weeks ago

I uploaded a safetensors file here: https://huggingface.co/tianweiy/DMD2/blob/main/dmd2_sdxl_4step_unet_fp16.safetensors

coastbluet commented 3 weeks ago

Many thanks for doing this so quickly. However, I'm having difficulty getting it to work. I tried with both a local installation and on Google Colab, and I get the same error with either, using the default ComfyUI workflow:

Error occurred when executing CheckpointLoaderSimple:

'model.diffusion_model.input_blocks.0.0.weight'

File "/content/drive/MyDrive/ComfyUI/execution.py", line 151, in recursive_execute
    output_data, output_ui = get_output_data(obj, input_data_all)
File "/content/drive/MyDrive/ComfyUI/execution.py", line 81, in get_output_data
    return_values = map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True)
File "/content/drive/MyDrive/ComfyUI/execution.py", line 74, in map_node_over_list
    results.append(getattr(obj, func)(**slice_dict(input_data_all, i)))
File "/content/drive/MyDrive/ComfyUI/nodes.py", line 516, in load_checkpoint
    out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
File "/content/drive/MyDrive/ComfyUI/comfy/sd.py", line 476, in load_checkpoint_guess_config
    model_config = model_detection.model_config_from_unet(sd, "model.diffusion_model.")
File "/content/drive/MyDrive/ComfyUI/comfy/model_detection.py", line 232, in model_config_from_unet
    unet_config = detect_unet_config(state_dict, unet_key_prefix)
File "/content/drive/MyDrive/ComfyUI/comfy/model_detection.py", line 113, in detect_unet_config
    model_channels = state_dict['{}input_blocks.0.0.weight'.format(key_prefix)].shape[0]

Any insight would be appreciated, even just to confirm that it works for you or others and that the problem is my setup. Thank you very much.

[Attached image: Screenshot 2024-06-12 104319]
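The traceback above ends in a failed lookup of `model.diffusion_model.input_blocks.0.0.weight`, so one quick check is whether the uploaded file contains that key at all. The following is a minimal sketch, not part of the original thread, that lists tensor names by reading only the JSON header of a `.safetensors` file with the standard library. The function names `list_safetensors_keys` and `has_expected_key` are hypothetical helpers introduced here for illustration.

```python
import json
import struct

# Key that ComfyUI's detect_unet_config looks up, taken from the traceback.
EXPECTED_KEY = "model.diffusion_model.input_blocks.0.0.weight"


def list_safetensors_keys(path):
    """Return the tensor names in a .safetensors file without loading weights."""
    with open(path, "rb") as f:
        # The file starts with an 8-byte little-endian header length,
        # followed by that many bytes of JSON describing each tensor.
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len).decode("utf-8"))
    # "__metadata__" is an optional free-form entry, not a tensor.
    return sorted(k for k in header if k != "__metadata__")


def has_expected_key(path, key=EXPECTED_KEY):
    """True if the file contains the key the checkpoint loader asks for."""
    return key in list_safetensors_keys(path)
```

Calling `list_safetensors_keys("dmd2_sdxl_4step_unet_fp16.safetensors")[:10]` on a local copy of the file would show which naming scheme the file uses. If `has_expected_key` returns `False`, the loader error is explained by key naming rather than by the local setup.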

tianweiy commented 3 weeks ago

Oh, I don't have ComfyUI. I just used a function to re-save the original checkpoint in safetensors format. I guess it might use different parameter naming. I will check and re-upload tomorrow morning.
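To illustrate the naming guess above: the ComfyUI traceback looks for keys under the `model.diffusion_model.` prefix with block names such as `input_blocks`, whereas a UNet saved from the diffusers library typically uses names such as `conv_in` and `down_blocks`. The sketch below is a rough heuristic for telling these apart from a list of key names. The function name and the returned labels are hypothetical, and the assumption that the re-saved file uses diffusers-style names is not confirmed in this thread.

```python
# Prefix that ComfyUI passes to model_config_from_unet in the traceback.
LDM_PREFIX = "model.diffusion_model."


def guess_unet_key_style(keys):
    """Classify UNet parameter names by a simple prefix heuristic.

    Returns one of: "ldm-prefixed", "ldm-bare", "diffusers", "unknown".
    """
    keys = list(keys)
    if any(k.startswith(LDM_PREFIX + "input_blocks.") for k in keys):
        # Layout the CheckpointLoaderSimple code path expects.
        return "ldm-prefixed"
    if any(k.startswith("input_blocks.") for k in keys):
        # Same block names, but missing the checkpoint prefix.
        return "ldm-bare"
    if any(k.startswith(("down_blocks.", "conv_in.")) for k in keys):
        # diffusers-style names; these need renaming, not just a prefix.
        return "diffusers"
    return "unknown"
```

A result of `"ldm-bare"` would suggest that adding the prefix is enough, while `"diffusers"` would mean a full key-renaming conversion is needed before a single-file loader can read the weights.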

coastbluet commented 3 weeks ago

Thanks very much. Does it need to be packaged with the VAE and CLIP from the SDXL base?

Just to explain why I am looking for this: I did try your LoRA version, which works well in Krita Diffusion and is only slightly slower than using the UNet. But to me the images it produced were slightly different and didn't have the quality I originally liked from your UNet, though that is probably very subjective. Thank you for your time on this.

tianweiy commented 3 weeks ago

I see. So I think those UIs require specific key:value pairs in addition to the safetensors format (from the error message, it is probably not the VAE thing). I am busy with a few other things at the moment, but I will write a conversion script as soon as I can.

BTW, which UI (or UIs) exactly do you want to use this file with? Could you give me a link to the repo so that I can test?

coastbluet commented 3 weeks ago

That would be great, thank you. The UI I would like to use is https://github.com/Acly/krita-ai-diffusion

However, I don't think it has any special requirements beyond those of ComfyUI: it is a plugin for the program Krita, and it uses ComfyUI (it installs it for you). All it requires is that the checkpoint can be loaded as a single file with the simple 'Load Checkpoint' node, as opposed to two files (the SDXL base plus your .bin file) loaded using the Load Checkpoint node together with UNETLoader. (Two screenshots attached.)
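The single-file requirement described above could, in principle, be met by swapping the UNet weights into a copy of the SDXL base checkpoint so that the VAE and text encoders travel in the same file. The following is a minimal sketch under stated assumptions, not the author's conversion script: it treats state dicts as plain mappings, assumes the UNet keys already use the same block names as the base checkpoint (with or without the `model.diffusion_model.` prefix from the traceback), and uses a hypothetical function name.

```python
# Prefix under which the UNet lives in the base checkpoint, per the traceback.
UNET_PREFIX = "model.diffusion_model."


def merge_unet_into_checkpoint(base_sd, unet_sd, prefix=UNET_PREFIX):
    """Return a copy of base_sd with its UNet entries replaced by unet_sd.

    Keys in unet_sd may be given with or without the prefix. All other
    entries in base_sd (VAE, text encoders) are kept unchanged.
    """
    new_unet = {
        (k if k.startswith(prefix) else prefix + k): v
        for k, v in unet_sd.items()
    }
    base_unet_keys = {k for k in base_sd if k.startswith(prefix)}
    new_unet_keys = set(new_unet)

    # Refuse to merge if the two UNets do not cover the same parameters,
    # which would indicate a naming mismatch that needs a real conversion.
    missing = base_unet_keys - new_unet_keys
    unexpected = new_unet_keys - base_unet_keys
    if missing or unexpected:
        raise KeyError(
            f"{len(missing)} base UNet keys missing, "
            f"{len(unexpected)} unexpected keys in replacement"
        )

    merged = dict(base_sd)  # shallow copy; the input mapping is not modified
    merged.update(new_unet)
    return merged
```

With real weights, the two input dicts could be loaded with `safetensors.torch.load_file` and the result written with `safetensors.torch.save_file`. Whether the output then loads in the 'Load Checkpoint' node depends on the key naming matching, which this thread has not yet confirmed.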