Birch-san / sdxl-play

BSD 3-Clause "New" or "Revised" License

Quality comparison for vanilla SDXL, SD interposed, and SD interposed+consistency decoded quality #4

Closed: williamberman closed this issue 11 months ago

williamberman commented 11 months ago

Hello! We saw your tweet https://twitter.com/Birchlabs/status/1721709378691010884 and are looking into whether it makes sense to add the interposer model to diffusers. My quality results are inconclusive: as far as I can tell, either there's no clear winner or the vanilla SDXL decoder gives the best decoding quality. Do you have any good examples (and I just got unlucky), or am I maybe using the interposer wrong? Any thoughts appreciated!

Columns, left to right: Vanilla SDXL | Interposed + vanilla SD VAE | Interposed + consistency SD VAE

With 20 denoising steps: [images for the prompts "horse", "a very detailed dragon" and "a fairy tale castle", one row per prompt, three decodings per row (sdxl_image_*, sd_image_*, sd_con_image_*)]

With 50 denoising steps: [images for the same three prompts and the same three decodings]

And here's the script I used:

import os

import torch
import torch.nn as nn
from PIL import Image
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download

# inference only: disable autograd globally
torch.set_grad_enabled(False)

# Interposer network (converts latents between the SD 1.x and SDXL latent spaces), from city96's
# https://github.com/city96/SD-Latent-Interposer/blob/bf5cec6eb46fd2d714842e60a4e7dd5a18eb174f/interposer.py
# Apache-licensed
class Interposer(nn.Module):
    """
        Basic NN layout, ported from:
        https://github.com/city96/SD-Latent-Interposer/blob/main/interposer.py
    """
    version = 3.1 # network revision
    def __init__(self):
        super().__init__()
        self.chan = 4   # latent channels (4 for both SD 1.x and SDXL)
        self.hid = 128  # hidden width

        # head: a one-conv shortcut and a three-conv branch, summed and then ReLU'd in forward()
        self.head_join  = nn.ReLU()
        self.head_short = nn.Conv2d(self.chan, self.hid, kernel_size=3, stride=1, padding=1)
        self.head_long  = nn.Sequential(
            nn.Conv2d(self.chan, self.hid, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(self.hid,  self.hid, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(self.hid,  self.hid, kernel_size=3, stride=1, padding=1),
        )
        # core: three residual blocks at the hidden width
        self.core = nn.Sequential(
            Block(self.hid),
            Block(self.hid),
            Block(self.hid),
        )
        # tail: project back down to the 4 latent channels
        self.tail = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(self.hid, self.chan, kernel_size=3, stride=1, padding=1)
        )

    def forward(self, x):
        y = self.head_join(
            self.head_long(x)+
            self.head_short(x)
        )
        z = self.core(y)
        return self.tail(z)

# residual block: three 3x3 convs whose output is added back onto the input, then ReLU
class Block(nn.Module):
    def __init__(self, size):
        super().__init__()
        self.join = nn.ReLU()
        self.long = nn.Sequential(
            nn.Conv2d(size, size, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(size, size, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.1),
            nn.Conv2d(size, size, kernel_size=3, stride=1, padding=1),
        )
    def forward(self, x):
        y = self.long(x)
        z = self.join(y + x)
        return z

# ComfyUI node wrapper carried over from the original file; the script below doesn't use it
class LatentInterposer:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "samples": ("LATENT", ),
                "latent_src": (["v1", "xl"],),
                "latent_dst": (["v1", "xl"],),
            }
        }

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "convert"
    CATEGORY = "latent"

    def convert(self, samples, latent_src, latent_dst):
        if latent_src == latent_dst:
            return (samples,)
        model = Interposer()
        model.eval()
        filename = f"{latent_src}-to-{latent_dst}_interposer-v{model.version}.safetensors"
        local = os.path.join(
            os.path.join(os.path.dirname(os.path.realpath(__file__)),"models"),
            filename
        )

        if os.path.isfile(local):
            print("LatentInterposer: Using local model")
            weights = local
        else:
            print("LatentInterposer: Using HF Hub model")
            weights = str(hf_hub_download(
                repo_id="city96/SD-Latent-Interposer",
                filename=filename)
            )

        model.load_state_dict(load_file(weights))
        lt = samples["samples"]
        lt = model(lt)
        del model
        return ({"samples": lt},)

NODE_CLASS_MAPPINGS = {
    "LatentInterposer": LatentInterposer,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "LatentInterposer": "Latent Interposer"
}

# load the xl-to-v1 weights (network v3.1) from the HF Hub, then move to GPU in fp16
interposer = Interposer()
interposer.eval()
interposer_weights = str(hf_hub_download(
  repo_id='city96/SD-Latent-Interposer',
  filename='xl-to-v1_interposer-v3.1.safetensors')
)
interposer.load_state_dict(load_file(interposer_weights))
interposer.to('cuda', torch.float16)

from diffusers import AutoencoderKL, ConsistencyDecoderVAE, DiffusionPipeline

# OpenAI's consistency decoder for SD 1.x latents
consistency_decoder_vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16)
consistency_decoder_vae.to('cuda')
# SDXL VAE finetuned to run in fp16 without overflowing
sdxl_vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)

# SDXL has no safety checker, so safety_checker=None is only passed to the SD 1.5 pipeline
sdxl_pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant='fp16', vae=sdxl_vae)
sdxl_pipe.to('cuda')

sd_pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16", safety_checker=None)
sd_pipe.to('cuda')

prompts = ["horse", "a very detailed dragon", "a fairy tale castle"]

def to_pil(image):
    # map a CHW tensor in [-1, 1] to an HWC uint8 PIL image
    image = ((image + 1) * 127.5).clamp(0, 255).to(dtype=torch.uint8)
    return Image.fromarray(image.cpu().numpy().transpose(1, 2, 0))

for prompt in prompts:
    # sample SDXL latents without decoding them (num_inference_steps=20 for the 20-step run)
    out = sdxl_pipe(prompt, output_type="latent", num_inference_steps=50, generator=torch.Generator('cpu').manual_seed(0))
    # undo the scaling factor to get raw SDXL VAE latents
    latents = out.images / sdxl_pipe.vae.config.scaling_factor

    # 1. vanilla SDXL VAE
    to_pil(sdxl_pipe.vae.decode(latents).sample[0]).save(f'sdxl_image_{prompt}.png')

    # convert SDXL latents to SD 1.x latents once, then reuse them for both SD decoders
    sd_latents = interposer(latents)

    # 2. interposed + vanilla SD 1.5 VAE
    to_pil(sd_pipe.vae.decode(sd_latents).sample[0]).save(f'sd_image_{prompt}.png')

    # 3. interposed + consistency decoder
    to_pil(consistency_decoder_vae.decode(sd_latents).sample[0]).save(f'sd_con_image_{prompt}.png')
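For a sense of scale: the Interposer defined in the script is a small network. A hand count of its parameters, assuming the layer sizes in the class above (4 latent channels, 128 hidden, all 3x3 convs with bias); `conv2d_params` is a hypothetical helper, not part of city96's code.

```python
def conv2d_params(c_in, c_out, k=3, bias=True):
    """Parameter count of an nn.Conv2d layer: k*k weights per (in, out) pair, plus optional bias."""
    return c_in * c_out * k * k + (c_out if bias else 0)

CHAN, HID = 4, 128  # self.chan and self.hid in the Interposer class

head_short = conv2d_params(CHAN, HID)                                  # one-conv shortcut
head_long = conv2d_params(CHAN, HID) + 2 * conv2d_params(HID, HID)     # three-conv branch
core = 3 * 3 * conv2d_params(HID, HID)                                 # 3 Blocks x 3 convs each
tail = conv2d_params(HID, CHAN)                                        # back to 4 channels
total = head_short + head_long + core + tail

print(total)  # 1637508, i.e. roughly 1.6M parameters
```

With the script's imports in scope, `sum(p.numel() for p in Interposer().parameters())` should give the same figure.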
Birch-san commented 11 months ago

yeah, these results are unsurprising. I've only made one image with the interposer myself, so can't say what to expect in general.

I'd say your interposed image fared better than mine did (mine lost a lot of dynamic range). you can see the desaturation that occurs in the interposer README:

[image from the interposer README]

I myself have experienced similar problems making tiny latent->RGB converters. for some reason it's hard to teach them to preserve saturation (I struggled to make my small FFNs learn to reproduce deep reds). maybe they need a different loss, or maybe I just needed something with a wider receptive field than a Linear layer.
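To put "receptive field" in numbers: a per-pixel Linear layer (equivalently a 1x1 conv) sees exactly one latent pixel, while a chain of n stride-1 convolutions with kernel size k sees 1 + n(k - 1) pixels per side. A quick sketch counting the longest conv path through the Interposer from the first post; `receptive_field` is a hypothetical helper, and the 8x factor assumes the usual SD 1.x/SDXL VAE downsampling.

```python
def receptive_field(num_convs, kernel_size=3):
    """Receptive field (latent pixels per side) of a chain of stride-1, undilated convolutions."""
    return 1 + num_convs * (kernel_size - 1)

# Longest path through the Interposer:
# head_long (3 convs) + core (3 Blocks x 3 convs) + tail (1 conv) = 13 convs.
# Skip connections add shorter paths but don't widen the maximum.
interposer_rf = receptive_field(3 + 3 * 3 + 1)
per_pixel_linear_rf = receptive_field(1, kernel_size=1)

VAE_DOWNSCALE = 8  # assumed: each latent pixel covers an 8x8 image patch

print(interposer_rf, per_pixel_linear_rf)  # 27 1
print(interposer_rf * VAE_DOWNSCALE)       # 216 image pixels per side
```

The image-space figure ignores the VAE decoder's own receptive field; it only says how much latent context each output latent can draw on.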
https://twitter.com/Birchlabs/status/1640824768415842304

and yes, I think we're seeing that the diffusion (consistency) decoder reduces the dynamic range of the interposed latents further still. it happened on yours, just like it did with mine.
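To go beyond eyeballing, the desaturation and dynamic-range loss could be given a number per image. The following is a stdlib-only sketch that assumes 8-bit RGB pixel tuples (for example from PIL's `Image.getdata()`). Mean HSV saturation and a percentile spread of Rec. 601 luma are my choice of proxies, not metrics used in this thread, and the function names are hypothetical.

```python
import colorsys
from statistics import mean


def mean_saturation(pixels):
    """Average HSV saturation (0..1) over an iterable of 8-bit (r, g, b) tuples."""
    return mean(colorsys.rgb_to_hsv(r / 255, g / 255, b / 255)[1] for r, g, b in pixels)


def luma_percentile_range(pixels, lo=0.01, hi=0.99):
    """Spread between the lo and hi percentiles of Rec. 601 luma (0..255).

    A crude dynamic-range proxy: a washed-out image has a narrower spread.
    """
    lumas = sorted(0.299 * r + 0.587 * g + 0.114 * b for r, g, b in pixels)
    n = len(lumas)
    return lumas[int(hi * (n - 1))] - lumas[int(lo * (n - 1))]
```

Usage with the files saved by the script: `px = list(Image.open('sd_con_image_horse.png').convert('RGB').getdata())`, then compare `mean_saturation(px)` and `luma_percentile_range(px)` across the `sdxl_image_*`, `sd_image_*` and `sd_con_image_*` versions of the same prompt.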

so: I don't assume any mistakes were made. these are the results I'd expect.

city96 commented 11 months ago

@Birch-san

Just to chime in: the interposer was never meant to be 100% accurate. It was mostly meant to be used at high-ish (0.4+) denoise as a stop-gap, so one model could "composite" the image for another model (in this case, taking advantage of SDXL's better prompt comprehension).
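At a denoise strength below 1, the second model re-noises the interposed latents and only runs the tail end of its schedule, which is presumably why errors like lost saturation can be partly painted over at 0.4+. Below is a small sketch of the step arithmetic. It restates (as I understand it) the `init_timestep`/`t_start` logic in diffusers' img2img `get_timesteps`. `img2img_steps` is a hypothetical helper, and ComfyUI's `denoise` setting may map to steps differently.

```python
def img2img_steps(num_inference_steps, strength):
    """Number of denoising steps an img2img pass actually runs at a given strength.

    Mirrors diffusers' img2img arithmetic: skip the first t_start steps of the
    schedule and start denoising from that noise level.
    """
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)
    return num_inference_steps - t_start


print(img2img_steps(50, 0.4))  # 20: only the last 40% of a 50-step schedule is re-run
```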

Re "you can see the desaturation that occurs in the interposer README":

That's just the TAESD preview vs. the NovelAI VAE (the latter of which is known to produce dull colors like that). That aside, the color accuracy of the interposer in general is terrible, since there's no visual loss during training (I have neither the hardware nor the experience to pull it off).
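One way to read "visual loss" here is to decode both the predicted and the target SD latents to RGB and penalise their difference in pixel space, on top of the usual latent MSE. What follows is a minimal PyTorch sketch that assumes a frozen decoder callable. `interposer_loss`, `pixel_weight` and the toy stand-in modules are hypothetical. With the real SD VAE decoder this gets memory-hungry, which fits city96's hardware concern; a tiny decoder such as TAESD might be a cheaper stand-in, though that is untested here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def interposer_loss(interposer, decoder, src_latents, dst_latents, pixel_weight=0.1):
    """Latent MSE plus a pixel-space ("visual") MSE measured through a frozen decoder.

    `decoder` maps destination latents to RGB; its parameters are assumed to have
    requires_grad=False, so only the interposer receives gradients.
    """
    pred_latents = interposer(src_latents)
    latent_loss = F.mse_loss(pred_latents, dst_latents)

    # target image: decode the ground-truth latents, no gradient needed
    with torch.no_grad():
        target_rgb = decoder(dst_latents)
    # predicted image: gradients flow back through the frozen decoder into the interposer
    pred_rgb = decoder(pred_latents)
    pixel_loss = F.mse_loss(pred_rgb, target_rgb)

    return latent_loss + pixel_weight * pixel_loss


# toy stand-ins so the sketch runs without downloading models (hypothetical, not the real networks)
toy_interposer = nn.Conv2d(4, 4, kernel_size=3, padding=1)
toy_decoder = nn.Conv2d(4, 3, kernel_size=1).requires_grad_(False)

src = torch.randn(2, 4, 16, 16)
dst = torch.randn(2, 4, 16, 16)
loss = interposer_loss(toy_interposer, toy_decoder, src, dst)
loss.backward()
```

With the script's objects, the decoder could be something like `lambda z: sd_pipe.vae.decode(z).sample` after calling `sd_pipe.vae.requires_grad_(False)`, with autograd re-enabled for training.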

The dataset isn't very diverse either, since I can't VAE encode/decode the samples on the fly. I think it was Flickr2K + DIV2K, with each image cropped into 5 and then flipped, for about 44K images total.
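The "cropped into 5 then flipped" augmentation can be read as the common five-crop scheme (four corners plus centre, as in torchvision's `FiveCrop`) combined with a horizontal flip, giving 10 views per source image. A sketch of the crop geometry under that assumption; the function names and crop size are hypothetical, not taken from city96's training code.

```python
def five_crop_boxes(width, height, crop):
    """(left, top, right, bottom) boxes for the four corners plus the centre.

    The box format matches what PIL's Image.crop accepts.
    """
    if crop > width or crop > height:
        raise ValueError("crop size larger than image")
    cx, cy = (width - crop) // 2, (height - crop) // 2
    return [
        (0, 0, crop, crop),                              # top-left
        (width - crop, 0, width, crop),                  # top-right
        (0, height - crop, crop, height),                # bottom-left
        (width - crop, height - crop, width, height),    # bottom-right
        (cx, cy, cx + crop, cy + crop),                  # centre
    ]


def augmented_views(width, height, crop):
    """Each crop box paired with a horizontal-flip flag: 10 views per source image."""
    return [(box, flip) for box in five_crop_boxes(width, height, crop) for flip in (False, True)]
```

With PIL, each view would be `img.crop(box)`, passed through `ImageOps.mirror` when `flip` is true, and then VAE-encoded once ahead of training, since encoding on the fly wasn't an option.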

I remember testing a few different architectures, and they performed basically the same (which makes me think the problem is with the training code). I'm happy to hear any suggestions on what to do, since I don't have much experience with ML stuff.

[image: loss-eval]

@williamberman

As for adding it to diffusers: if there's interest, I can clean up the code to be much more general. I have a version that supports different scaling/channels/etc., allowing it to work with other latent spaces such as the Würstchen one. The quality isn't great there either, but I think it's a better base than the current code. I uploaded a snapshot on a separate branch for anyone curious.
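To illustrate what "supports different scaling/channels" could look like, here is a hypothetical generalisation of the Interposer from the first post. Input and output channel counts and a spatial scale factor become constructor arguments, and the latents are resampled onto the destination grid before the convs. This is a sketch under those assumptions, not the code on city96's branch. `GeneralInterposer`, `ResBlock` and the bilinear resampling choice are all mine.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ResBlock(nn.Module):
    """Three 3x3 convs with a skip connection, in the style of Block in the script."""

    def __init__(self, size):
        super().__init__()
        self.long = nn.Sequential(
            nn.Conv2d(size, size, 3, padding=1), nn.LeakyReLU(0.1),
            nn.Conv2d(size, size, 3, padding=1), nn.LeakyReLU(0.1),
            nn.Conv2d(size, size, 3, padding=1),
        )

    def forward(self, x):
        return F.relu(self.long(x) + x)


class GeneralInterposer(nn.Module):
    """Hypothetical latent-to-latent converter with configurable channels and spatial scale."""

    def __init__(self, ch_in, ch_out, scale=1.0, hidden=128, blocks=3):
        super().__init__()
        self.scale = scale
        self.head = nn.Conv2d(ch_in, hidden, 3, padding=1)
        self.core = nn.Sequential(*[ResBlock(hidden) for _ in range(blocks)])
        self.tail = nn.Conv2d(hidden, ch_out, 3, padding=1)

    def forward(self, x):
        if self.scale != 1.0:
            # resample onto the destination latent grid first (bilinear is an arbitrary choice)
            x = F.interpolate(x, scale_factor=self.scale, mode="bilinear", align_corners=False)
        return self.tail(self.core(F.relu(self.head(x))))
```

For example, `GeneralInterposer(4, 16, scale=0.5)` would map a 4-channel latent to a 16-channel latent at half the resolution. The channel counts here are illustrative and don't describe any specific model.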

williamberman commented 11 months ago

Wow, super helpful discussion! No worries. I think we might hold off on adding it for now, then, but we'll keep an eye on it for sure. We're just trying to be a little extra careful about what gets added to the core library these days.

If you want to add it under the community examples folder, though, happy to merge. Regardless, we'll keep an eye on progress :)