google / style-aligned

Official code for "Style Aligned Image Generation via Shared Attention"
Apache License 2.0

Crash Occurs When Loading LoRA Weights with load_lora_weights #16

Open pedropaf opened 11 months ago

pedropaf commented 11 months ago

I am experiencing a crash after running 50 inference steps when I load LoRA (Low-Rank Adaptation) weights into the SDXL model with the load_lora_weights method. I have 16 GB of VRAM, and I never see more than 9 GB in use. This is the modification I've made:

import gradio as gr

# `pipeline` is the module-level SDXL pipeline (with the StyleAligned sa_handler
# registered), created elsewhere in the app.
def style_aligned_sdxl(initial_prompt1, initial_prompt2, initial_prompt3,
                       initial_prompt4, initial_prompt5, style_prompt, lora_checkpoint):
    try:
        # Load LoRA weights if provided (kohya-trained LoRA)
        if lora_checkpoint:
            pipeline.load_lora_weights(lora_checkpoint)

        # Combine the style prompt with each initial prompt
        initial_prompts = [initial_prompt1, initial_prompt2, initial_prompt3,
                           initial_prompt4, initial_prompt5]
        sets_of_prompts = [prompt + ". " + style_prompt for prompt in initial_prompts]
        # Generate one image per prompt using the pipeline
        images = pipeline(sets_of_prompts).images
        return images
    except Exception as e:
        # Surface the error in the Gradio UI
        raise gr.Error(f"Error in generating images: {e}") from e
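As written, the callback calls load_lora_weights every time it runs, so the same LoRA may be loaded repeatedly across Gradio requests. Whether that contributes to the crash is not established in this thread. The following is a minimal sketch, assuming you want to rule it out: it factors out the prompt construction and wraps the loader so a checkpoint is only loaded when the path changes. The names build_style_prompts and LoraOnce are hypothetical helpers, not part of diffusers or the style-aligned repo; load_fn and unload_fn stand in for pipeline.load_lora_weights and pipeline.unload_lora_weights.

```python
from typing import Callable, List, Optional


def build_style_prompts(initial_prompts: List[str], style_prompt: str) -> List[str]:
    """Append the shared style prompt to every content prompt (same format as the issue's code)."""
    return [prompt + ". " + style_prompt for prompt in initial_prompts]


class LoraOnce:
    """Hypothetical helper: only call the LoRA loader when the checkpoint path changes.

    `load_fn` stands in for something like `pipeline.load_lora_weights`, and
    `unload_fn` (optional) for `pipeline.unload_lora_weights`. The wrapper just
    remembers which path was loaded last and skips repeat loads.
    """

    def __init__(self, load_fn: Callable[[str], None],
                 unload_fn: Optional[Callable[[], None]] = None):
        self._load_fn = load_fn
        self._unload_fn = unload_fn
        self.loaded_path: Optional[str] = None

    def ensure(self, checkpoint: Optional[str]) -> bool:
        """Load `checkpoint` if it is set and differs from the current one.

        Returns True if a load actually happened, False otherwise.
        """
        if not checkpoint or checkpoint == self.loaded_path:
            return False
        # Drop the previously loaded LoRA before switching to a new one
        if self.loaded_path is not None and self._unload_fn is not None:
            self._unload_fn()
        self._load_fn(checkpoint)
        self.loaded_path = checkpoint
        return True
```

In the app, one would construct LoraOnce(pipeline.load_lora_weights, pipeline.unload_lora_weights) once at module level and call ensure(lora_checkpoint) inside the callback before building the prompts with build_style_prompts and calling pipeline(...).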

If I don't load the LoRA, I can generate images successfully. I can also run inference successfully if I load the LoRA without the sa_handler. Are you able to help with this?

Thanks!

pedropaf commented 11 months ago

I can see that DreamBooth LoRA is supported and explained in the paper, but it isn't shown in the code samples. Could you please share the code for using the pipeline together with a DreamBooth LoRA? 🙏🏻