google / style-aligned

Official code for "Style Aligned Image Generation via Shared Attention"
Apache License 2.0
1.19k stars · 86 forks

Alternative implementation in Refiners #25

Open · Laurent2916 opened 6 months ago

Laurent2916 commented 6 months ago

We are building Refiners, an open source, PyTorch-based micro-framework made to easily train and run adapters on top of foundational models. Just wanted to let you know that StyleAligned is now supported in Refiners! (congrats on the great work, by the way!!)
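For readers unfamiliar with the paper's core idea: "shared attention" modifies the self-attention layers so that every image in the batch also attends to the keys and values of a shared reference, which is what keeps styles aligned across the batch. Here is a minimal NumPy sketch of that mechanism (deliberately simplified: single head, toy dimensions, no AdaIN on queries/keys; all names are illustrative and not from Refiners or the official code):

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def shared_attention(q, k, v, k_ref, v_ref):
    """Toy self-attention where each image in the batch also
    attends to a shared reference image's keys/values."""
    # concatenate own and reference keys/values along the token axis
    k_full = np.concatenate([k, k_ref], axis=1)  # (B, T + T_ref, D)
    v_full = np.concatenate([v, v_ref], axis=1)
    d = q.shape[-1]
    attn = softmax(q @ k_full.transpose(0, 2, 1) / np.sqrt(d))
    return attn @ v_full  # (B, T, D)

rng = np.random.default_rng(0)
B, T, D = 4, 8, 16  # batch, tokens, feature dim
q, k, v = (rng.normal(size=(B, T, D)) for _ in range(3))
# the same reference keys/values are broadcast to every batch element
k_ref = np.repeat(rng.normal(size=(1, T, D)), B, axis=0)
v_ref = np.repeat(rng.normal(size=(1, T, D)), B, axis=0)
out = shared_attention(q, k, v, k_ref, v_ref)
print(out.shape)  # (4, 8, 16)
```

Because every batch element mixes in values from the same reference tokens, the outputs are pulled toward a common statistic, which is the intuition behind the style alignment.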

E.g. an equivalent to the style_aligned_sdxl.ipynb notebook (or the demo_stylealigned_sdxl.py demo) would look like this:

  1. Follow these install steps
  2. Run the code snippet below:
    
```python
from pathlib import Path

import torch
from PIL import Image

from refiners.fluxion.utils import manual_seed, no_grad
from refiners.foundationals.latent_diffusion import StableDiffusion_XL, StyleAlignedAdapter

# initialize Stable Diffusion XL model
sdxl = StableDiffusion_XL(
    device=torch.device("cuda"),
    dtype=torch.float16,
)

# load weights
sdxl_folder_weights = Path(...)
sdxl.lda.load_from_safetensors(sdxl_folder_weights / "sdxl-lda-fp16-fix.safetensors")
sdxl.unet.load_from_safetensors(sdxl_folder_weights / "sdxl-unet.safetensors")
sdxl.clip_text_encoder.load_from_safetensors(sdxl_folder_weights / "DoubleCLIPTextEncoder.safetensors")

# inject style aligned adapter
style_aligned_adapter = StyleAlignedAdapter(sdxl.unet)
style_aligned_adapter.inject()

set_of_prompts = [
    "a toy train. macro photo. 3d game asset",
    "a toy airplane. macro photo. 3d game asset",
    "a toy bicycle. macro photo. 3d game asset",
    "a toy car. macro photo. 3d game asset",
    "a toy boat. macro photo. 3d game asset",
]

with no_grad():
    # create (context) embeddings from prompts
    clip_text_embedding, pooled_text_embedding = sdxl.compute_clip_text_embedding(
        text=set_of_prompts,
        negative_text=[""] * len(set_of_prompts),
    )
    time_ids = sdxl.default_time_ids.repeat(len(set_of_prompts), 1)

    # initialize latents
    manual_seed(seed=2)
    x = torch.randn(
        (len(set_of_prompts), 4, 128, 128),
        device=sdxl.device,
        dtype=sdxl.dtype,
    )

    # denoise
    for step in sdxl.steps:
        x = sdxl(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            pooled_text_embedding=pooled_text_embedding,
            time_ids=time_ids,
        )

    # decode latents
    predicted_images = [sdxl.lda.decode_latents(latent.unsqueeze(0)) for latent in x]

# tile all images horizontally
merged_image = Image.new("RGB", (1024 * len(predicted_images), 1024))
for i in range(len(predicted_images)):
    merged_image.paste(predicted_images[i], (i * 1024, 0))
merged_image.save("style_aligned_example.png")
```



A few more things:
- We support both SDXL and SD1.5
- This new implementation builds upon Refiners’ [Adapter API](https://refine.rs/concepts/adapter/)
- Adapters can be composed, e.g. you may freely combine a [StyleAlignedAdapter](https://refine.rs/reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.style_aligned.StyleAlignedAdapter) with a [ControlLoraAdapter](https://refine.rs/reference/foundationals/latent_diffusion/#refiners.foundationals.latent_diffusion.stable_diffusion_xl.ControlLoraAdapter)
- We have written a blog post with additional details about this implementation: https://blog.finegrain.ai/posts/implementing-style-aligned/
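As a toy illustration of the adapter-composition point above (this is a generic wrapper pattern for intuition only, not Refiners' actual Adapter API; see the linked docs for the real interface):

```python
class ToyAdapter:
    """Illustrative adapter: wraps a callable target and applies an
    extra transform after it. Adapters compose by wrapping each other."""

    def __init__(self, target, post):
        self.target = target  # callable being adapted
        self.post = post      # extra behavior layered on top

    def __call__(self, x):
        return self.post(self.target(x))

base_model = lambda x: x + 1                      # stand-in for a model forward
styled = ToyAdapter(base_model, lambda y: y * 2)  # first adapter
controlled = ToyAdapter(styled, lambda y: y - 3)  # second adapter, stacked on the first

print(controlled(10))  # (10 + 1) * 2 - 3 = 19
```

Each wrapper only needs to know about the callable it adapts, so two independently written adapters can be stacked without either knowing about the other.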

Feedback welcome!