tianweiy / DMD2

(NeurIPS 2024 Oral 🔥) Improved Distribution Matching Distillation for Fast Image Synthesis

Unable to replicate results for sdv1.5 #42

Closed: omrastogi closed this issue 3 months ago

omrastogi commented 3 months ago

I am unable to replicate the quality and speed of DMD SDv1.5. I need help reproducing outputs that correspond to the scores reported in the paper.

[Image: Table 4 of "One-step Diffusion with Distribution Matching Distillation"]

Settings:

GPU: Quadro RTX 8000
Python: 3.8.19
CUDA: 11.7
torch: 2.0.1

Checkpoint downloaded:

"sdv1.5/laion6.25_sd_baseline_8node_guidance1.75_lr1e-5_seed10_dfake10_from_scratch_fid9.28_checkpoint_model_039000"

Code

import time

import torch
from diffusers import DiffusionPipeline, UNet2DConditionModel, LCMScheduler

# Define model and checkpoint
base_model_id = "runwayml/stable-diffusion-v1-5"
ckpt_name = "sdv1.5/pytorch_model.bin"

# Load the distilled UNet weights into the base architecture
unet = UNet2DConditionModel.from_config(base_model_id, subfolder="unet").to("cuda", torch.float16)
unet.load_state_dict(torch.load(ckpt_name, map_location="cuda"))

# Create diffusion pipeline
pipe = DiffusionPipeline.from_pretrained(
    base_model_id,
    unet=unet,
    torch_dtype=torch.float16,
    variant="fp16"
).to("cuda")

# Configure scheduler
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)

# Set prompt
prompt = "an image of a beautiful girl sitting in a dandelion field"

# Note: casting to float32 is not needed to disable the safety checker
# (setting it to None below is sufficient) and roughly doubles inference time.
pipe = pipe.to(dtype=torch.float32)
pipe.safety_checker = None  # Disable the safety checker to allow NSFW content

# Measure inference time
start_time = time.time()
# LCMScheduler's default timesteps differ from the one used for training,
# so pass the training timestep explicitly (the DMD2 README uses timesteps=[399]
# for its one-step models; requires a recent diffusers version).
image = pipe(prompt=prompt, num_inference_steps=1, guidance_scale=0, timesteps=[399]).images[0]
end_time = time.time()

# Print time taken for image generation
print(f"Time taken for image generation: {end_time - start_time:.2f} seconds")

# Save generated image
image.save("generated_onestep.png")

Output in 0.7 seconds:

[Image: generated_onestep.png]

For comparison, LCM-LoRA SDv1.5 with 4 steps takes 0.6 seconds under the same settings.

[Image: generated_lcm]
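As an aside, wall-clock comparisons like this are sensitive to CUDA's asynchronous execution and to one-time warm-up costs (context creation, kernel selection). A minimal sketch of a fairer timing loop, assuming pipe and prompt are defined as above:

import time
import torch

def time_pipeline(pipe, prompt, steps, n_runs=5):
    """Average per-image latency over n_runs after a warm-up call."""
    # Warm-up: the first call pays one-time CUDA/initialization costs.
    pipe(prompt=prompt, num_inference_steps=steps, guidance_scale=0)
    torch.cuda.synchronize()  # flush pending GPU work before starting the clock
    start = time.time()
    for _ in range(n_runs):
        pipe(prompt=prompt, num_inference_steps=steps, guidance_scale=0)
    torch.cuda.synchronize()  # wait for all GPU work before stopping the clock
    return (time.time() - start) / n_runs

print(f"1-step latency: {time_pipeline(pipe, prompt, steps=1):.2f} s/image")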

tianweiy commented 3 months ago

Thank you for the interest. But I don't understand which results you can't reproduce: quality, speed, or something else? Could you specify?

Please note that this repo is not a code/model release for DMD1. The SDv1.5 model in DMD2 is trained with guidance scale 1.75, so it is not supposed to give good images. For higher image quality, you should try the SDXL model.
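For anyone landing here later, loading the DMD2 SDXL model follows the same pattern as the SDv1.5 snippet above. A sketch, with the checkpoint filename and timestep schedule assumed from the DMD2 Hugging Face release and README (verify both before use):

import torch
from diffusers import DiffusionPipeline, UNet2DConditionModel, LCMScheduler
from huggingface_hub import hf_hub_download

base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
repo_name = "tianweiy/DMD2"
ckpt_name = "dmd2_sdxl_4step_unet_fp16.bin"  # assumed filename; check the DMD2 README

# Load the distilled UNet into the SDXL architecture
unet = UNet2DConditionModel.from_config(base_model_id, subfolder="unet").to("cuda", torch.float16)
unet.load_state_dict(torch.load(hf_hub_download(repo_name, ckpt_name), map_location="cuda"))

pipe = DiffusionPipeline.from_pretrained(
    base_model_id, unet=unet, torch_dtype=torch.float16, variant="fp16"
).to("cuda")
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)

prompt = "a photo of a cat"
# Assumed 4-step schedule from the README; adjust if the repo specifies otherwise.
image = pipe(prompt=prompt, num_inference_steps=4, guidance_scale=0,
             timesteps=[999, 749, 499, 249]).images[0]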

omrastogi commented 3 months ago

Thank you for the clarification. I had assumed that the SDv1.5 model would match or surpass the results reported in DMD1, since DMD2 is an improvement over it.

I am looking for a one-step SDv1.5 model that can produce good-quality images.

tianweiy commented 3 months ago

I mean, it is an improvement, but we are not comparing in the exact same setting (the models differ in the guidance scale used during training). Unfortunately, the DMD2 repo doesn't have a one-step SDv1.5 model that supports high image quality. To get one, we would need to train a model with a higher guidance scale. We didn't try this because SDXL is just way better in that high-guidance regime.
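To make the guidance-scale point concrete: in DMD-style training, the "real" score being matched comes from the teacher evaluated with classifier-free guidance, so the chosen scale is baked into the distilled generator. A minimal, hypothetical sketch of where that scale enters (the function and argument names are illustrative, not the repo's API):

import torch

def guided_teacher_eps(teacher_unet, x_t, t, cond_emb, uncond_emb, w=1.75):
    """Classifier-free-guided teacher prediction used as the 'real' score.

    w is fixed at training time; a higher w trades diversity for per-image
    fidelity, which is why a model distilled at w=1.75 cannot simply be
    run at a higher scale afterwards.
    """
    eps_cond = teacher_unet(x_t, t, encoder_hidden_states=cond_emb).sample
    eps_uncond = teacher_unet(x_t, t, encoder_hidden_states=uncond_emb).sample
    return eps_uncond + w * (eps_cond - eps_uncond)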

omrastogi commented 3 months ago

Got it, thanks for making it clear.

Oh, btw this is amazing work @tianweiy