replicate / cog-sdxl

Stable Diffusion XL training and inference as a cog model
https://replicate.com/stability-ai/sdxl
Apache License 2.0

Modifying the code of a blog post #37

Open oshita-n opened 10 months ago

oshita-n commented 10 months ago

The LoRA model does not load when I run the code from the following blog post.

blog: https://replicate.com/blog/fine-tune-sdxl

After I switched the LoRA loading to the load_attn_procs method, as shown below, the model loaded correctly.

import torch
from diffusers import DiffusionPipeline

# Requires a local clone: git clone https://github.com/replicate/cog-sdxl cog_sdxl
from cog_sdxl.dataset_and_utils import TokenEmbeddingsHandler

# Load the SDXL base pipeline in half precision and move it to the GPU.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
).to("cuda")

# Load the fine-tuned LoRA weights into the UNet attention processors.
pipe.unet.load_attn_procs("/content/lora.safetensors")  # should take < 2 seconds

# SDXL has two text encoders, each with its own tokenizer.
text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
tokenizers = [pipe.tokenizer, pipe.tokenizer_2]

# Load the learned token embeddings into both text encoders.
embhandler = TokenEmbeddingsHandler(text_encoders, tokenizers)
embhandler.load_embeddings("/content/embeddings.pti")
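
For anyone reusing the snippet above, here is a minimal sketch that wraps the same steps in a function and checks that both files exist before the pipeline is loaded. The helper names check_artifacts and load_finetuned_pipeline are hypothetical and not part of cog-sdxl or diffusers. The heavy imports are deferred so the path check can run without a GPU. Only the calls already used above are assumed to exist.

```python
from pathlib import Path


def check_artifacts(lora_path, embeddings_path):
    """Return both paths as Path objects, or raise if either file is missing.

    Hypothetical helper, not part of cog-sdxl or diffusers.
    """
    paths = [Path(lora_path), Path(embeddings_path)]
    missing = [str(p) for p in paths if not p.is_file()]
    if missing:
        raise FileNotFoundError(
            "missing fine-tuning artifacts: " + ", ".join(missing)
        )
    return paths[0], paths[1]


def load_finetuned_pipeline(lora_path, embeddings_path, device="cuda"):
    """Load SDXL, then the LoRA weights and token embeddings.

    Hypothetical wrapper around the steps from the snippet above;
    it needs a GPU plus torch, diffusers, and a cog_sdxl clone.
    """
    lora_file, emb_file = check_artifacts(lora_path, embeddings_path)

    # Heavy imports are deferred so check_artifacts works without them.
    import torch
    from diffusers import DiffusionPipeline
    from cog_sdxl.dataset_and_utils import TokenEmbeddingsHandler

    pipe = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=torch.float16,
        variant="fp16",
    ).to(device)

    # Same LoRA loading call as in the snippet above.
    pipe.unet.load_attn_procs(str(lora_file))

    # Load the learned token embeddings into both SDXL text encoders.
    handler = TokenEmbeddingsHandler(
        [pipe.text_encoder, pipe.text_encoder_2],
        [pipe.tokenizer, pipe.tokenizer_2],
    )
    handler.load_embeddings(str(emb_file))
    return pipe
```

Calling load_finetuned_pipeline("/content/lora.safetensors", "/content/embeddings.pti") should then return a pipeline ready for prompts. The prompt has to contain the special tokens that were added during training, and their exact names depend on the training run.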

Thank you for the great service and for making SDXL fine-tuning so easy.