LingxiaoYang2023 / DSG2024

Official PyTorch repository for “Guidance with Spherical Gaussian Constraint for Conditional Diffusion”

Use DSG in latent diffusion #7

Open 2019211753 opened 2 months ago

2019211753 commented 2 months ago

I tried to use CLIP text guidance instead of CFG in SD-Style, but the results are not satisfactory. Could you please help me find what's going wrong? Here is the code; I only changed three lines in the file 'SD_style/ldm/models/diffusion/ddim.py', at lines 232, 262, and 264:

c_in = torch.cat([unconditional_conditioning, unconditional_conditioning])  # remove CFG
residual = self.conditional_fn(D_x0_t, self.ref_img)
residual2 = self.conditional_class.get_residual(D_x0_t, 'A cat')
norm = torch.linalg.norm(residual) + torch.linalg.norm(residual2) # add clip text guidance
norm_grad = torch.autograd.grad(outputs=norm, inputs=x)[0] 

The reference image is jojo.jpeg, and the output is: [image]

2019211753 commented 2 months ago

Is it possible to guide Stable Diffusion using only CLIP guidance with a given text?

LingxiaoYang2023 commented 2 months ago

Thanks for your interest in our paper!

Although we have demonstrated the superiority of DSG on many tasks, we have not yet experimented with other loss functions such as the CLIP loss. Therefore, I can only offer some potential solutions based on the code you provided:

  1. When calculating $\nabla_{x_t} L(\hat{x}_0, y)$, the gradient has to be backpropagated through the U-Net. Therefore, you should not wrap "noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample" in "with torch.no_grad():".
  2. You might need to use a "time travel" strategy at intermediate timesteps (e.g., t in [300, 700]). You could refer to the paper and code of FreeDoM. This is because losses such as the CLIP and style losses are more challenging than linear inverse problems.
  3. There may be some minor issues with data processing, as I initially used unconditional generation and found the image quality to be poor.
  4. Stable Diffusion was trained with classifier-free guidance (the text condition is dropped with 10% probability, which is how the unconditional model is trained), so its unconditional generation may be considerably weaker than its conditional generation. Therefore, I recommend using a model trained purely with $L_{\text{simple}}$ for unconditional generation as the pre-trained model in your experiments.
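Point 1 can be checked on a toy problem. The sketch below uses a tiny linear layer as a hypothetical stand-in for the U-Net and compares $\nabla_{x_t} L(\hat{x}_0, y)$ computed with and without `torch.no_grad()` around the noise prediction; the cumulative alpha, the target, and the loss are illustrative assumptions, not values from DSG.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-ins: a tiny linear "noise predictor" in place of the U-Net,
# a fixed cumulative alpha for step t, and a squared-error loss toward a target y.
denoiser = torch.nn.Linear(4, 4)
alpha_bar_t = torch.tensor(0.5)
target = torch.zeros(4)


def loss_fn(x0_hat):
    # Any differentiable L(x0_hat, y); a CLIP or style loss would go here.
    return ((x0_hat - target) ** 2).sum()


def grad_wrt_xt(x_t, track_unet):
    x_t = x_t.clone().requires_grad_(True)
    if track_unet:
        eps = denoiser(x_t)
    else:
        with torch.no_grad():  # the pattern point 1 warns against
            eps = denoiser(x_t)
    # Estimate of x0 from x_t and the predicted noise (epsilon parameterization)
    x0_hat = (x_t - (1 - alpha_bar_t).sqrt() * eps) / alpha_bar_t.sqrt()
    return torch.autograd.grad(loss_fn(x0_hat), x_t)[0]


x_t = torch.randn(4)
g_full = grad_wrt_xt(x_t, track_unet=True)     # includes the denoiser's Jacobian
g_nograd = grad_wrt_xt(x_t, track_unet=False)  # treats eps as a constant
print(g_full)
print(g_nograd)
```

Under `torch.no_grad()` the gradient collapses to $\partial L / \partial \hat{x}_0$ divided by $\sqrt{\bar\alpha_t}$ and ignores how the predicted noise depends on $x_t$, so the two printed gradients differ.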

To address these issues, here is reference code (based on the code you provided) that fixes problems 1 and 3. It might be a good starting point, although the result is still not good enough because time travel is not yet included.

Prompt: "a photograph of a dog"
Result: [image]
Code:

from PIL import Image
import torch
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel
from typing import Optional
from diffusers import DDIMScheduler
from diffusers.utils.torch_utils import randn_tensor
import torchvision.transforms as transforms
import torch.nn.functional as F
import open_clip
import torchvision

torch_device = "cuda:0"

def downscale(img):
    transform = transforms.ToTensor()
    img = transform(img).to(torch_device)  # keep float32 to match the VAE weights
    img = img.unsqueeze(0)
    img = F.interpolate(
        img, (512, 512), mode="bilinear", align_corners=False
    )
    return img

def encode_image(img):
    img = img * 2.0 - 1.0
    posterior = vae.encode(img).latent_dist
    latent = posterior.sample() * vae.config.scaling_factor
    return latent

def decode_latent(latent):
    latent = 1 / vae.config.scaling_factor * latent
    image = vae.decode(latent).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    return image

def upscale(image):
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    image = (image * 255).round().astype("uint8")
    return image

class DDIMScheduler_with_more_output(DDIMScheduler):
    def full_output(
            self,
            model_output: torch.FloatTensor,
            timestep: int,
            sample: torch.FloatTensor,
            eta: float = 0.0,
            use_clipped_model_output: bool = False,
            generator=None,
            variance_noise: Optional[torch.FloatTensor] = None,
            return_dict: bool = True,
    ):
        # 1. get previous step value (=t-1)
        prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps

        # 2. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod

        beta_prod_t = 1 - alpha_prod_t

        # 3. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        if self.config.prediction_type == "epsilon":
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
            pred_epsilon = model_output
        elif self.config.prediction_type == "sample":
            pred_original_sample = model_output
            pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
        elif self.config.prediction_type == "v_prediction":
            pred_original_sample = (alpha_prod_t ** 0.5) * sample - (beta_prod_t ** 0.5) * model_output
            pred_epsilon = (alpha_prod_t ** 0.5) * model_output + (beta_prod_t ** 0.5) * sample
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction`"
            )

        # 4. Clip or threshold "predicted x_0"
        if self.config.thresholding:
            pred_original_sample = self._threshold_sample(pred_original_sample)
        elif self.config.clip_sample:
            pred_original_sample = pred_original_sample.clamp(
                -self.config.clip_sample_range, self.config.clip_sample_range
            )

        # 5. compute variance: "sigma_t(η)" -> see formula (16)
        # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
        variance = self._get_variance(timestep, prev_timestep)
        std_dev_t = eta * variance ** (0.5)

        if use_clipped_model_output:
            # the pred_epsilon is always re-derived from the clipped x_0 in Glide
            pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)

        # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t ** 2) ** (0.5) * pred_epsilon

        # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        prev_mean = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

        if eta > 0:
            if variance_noise is not None and generator is not None:
                raise ValueError(
                    "Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
                    " `variance_noise` stays `None`."
                )

            if variance_noise is None:
                variance_noise = randn_tensor(
                    model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
                )
            variance = std_dev_t * variance_noise

            prev_sample = prev_mean + variance
        else:
            prev_sample = prev_mean

        # return {'prev_mean':prev_mean, 'variance':std_dev_t,'prev_sample':prev_sample}

        return {'prev_mean': prev_mean, 'variance': std_dev_t, 'z0': pred_original_sample}

tfms = torchvision.transforms.Compose(
    [
        torchvision.transforms.RandomResizedCrop(224),  # random crop, resized to 224x224
        torchvision.transforms.RandomAffine(5),  # random slight affine warp (rotation within +/-5 degrees)
        torchvision.transforms.RandomHorizontalFlip(),  # random left-right mirror
    ]
)

vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="vae", use_safetensors=True)
tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(
    "stabilityai/stable-diffusion-2-1-base", subfolder="text_encoder", use_safetensors=True
)
unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-2-1-base", subfolder="unet", use_safetensors=True
)
scheduler = DDIMScheduler_with_more_output.from_pretrained("stabilityai/stable-diffusion-2-1-base",
                                                           subfolder="scheduler")
vae.to(torch_device)
text_encoder.to(torch_device)
unet.to(torch_device)

height = 512  # default height of Stable Diffusion
width = 512  # default width of Stable Diffusion
num_inference_steps = 200  # Number of denoising steps

uncond_input = tokenizer([""] * 1, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt")
uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
text_embeddings = torch.cat([uncond_embeddings])

latents = torch.randn(
    (1, unet.config.in_channels, height // 8, width // 8),
    device=torch_device,
)

from tqdm.auto import tqdm

prompt = "a photograph of a dog"
clip_model, _, preprocess = open_clip.create_model_and_transforms('ViT-g-14', pretrained='laion2b_s34b_b88k')

clip_model.to(torch_device)

def clip_loss(x, text_features):
    image = decode_latent(x)

    image_features = clip_model.encode_image(
        tfms(image)
    )
    input_normed = torch.nn.functional.normalize(image_features.unsqueeze(1), dim=2)
    embed_normed = torch.nn.functional.normalize(text_features.unsqueeze(0), dim=2)
    # spherical distance: 2 * arcsin(|u - v| / 2)^2, i.e. half the squared angle between embeddings
    dists = input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)
    return dists.mean()

guidance_scale = 0.1
eps = 1e-20

text = open_clip.tokenize([prompt]).to(torch_device)
with torch.no_grad(), torch.cuda.amp.autocast():
    text_features = clip_model.encode_text(text)

scheduler.set_timesteps(num_inference_steps)
latents = latents * scheduler.init_noise_sigma

for t in tqdm(scheduler.timesteps):
    latents.requires_grad_(True)
    latent_model_input = scheduler.scale_model_input(latents, timestep=t)

    noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

    # if not t <= 100:
    if True:
        output = scheduler.full_output(noise_pred, t, latents, eta=1)
        prev_mean = output['prev_mean']
        variance = output['variance']
        z0 = output['z0']
        loss = clip_loss(z0, text_features)
        loss = torch.linalg.norm(loss)
        grad = torch.autograd.grad(loss, latents)[0]
        grad_norm = torch.norm(grad)
        grad2 = grad / (grad_norm + eps)
        batch, ch, h, w = prev_mean.shape
        import math

        r = math.sqrt(ch * h * w) * variance
        d_star = -r * grad2
        noise = torch.randn(prev_mean.shape, device=prev_mean.device, dtype=prev_mean.dtype).to(prev_mean.device)
        d_sample = variance * noise
        mix_direction = d_sample + guidance_scale * (d_star - d_sample)
        mix_direction_norm = torch.norm(mix_direction)
        latents = prev_mean + mix_direction / (mix_direction_norm + eps) * r
    else:
        output = scheduler.full_output(noise_pred, t, latents, eta=1)
        prev_mean = output['prev_mean']
        variance = output['variance']
        noise = torch.randn(prev_mean.shape, device=prev_mean.device, dtype=prev_mean.dtype).to(prev_mean.device)
        latents = prev_mean + noise * variance
    latents = latents.detach()

# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
with torch.no_grad():
    image = vae.decode(latents).sample

image = (image / 2 + 0.5).clamp(0, 1).squeeze()
image = (image.permute(1, 2, 0) * 255).to(torch.uint8).cpu().numpy()
image = Image.fromarray(image)
image.save('test.png')
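The guided branch of the loop above (from `grad_norm` to the final `latents` update) is the spherical Gaussian constraint step: it keeps $x_{t-1}$ on the sphere of radius $r = \sqrt{n}\,\sigma_t$ around the posterior mean, where $n$ = ch·h·w is the latent dimension. The sketch below factors that update into a standalone function so the geometry can be checked on random tensors. `dsg_step` is a hypothetical name, not a function from the DSG repository, and it assumes batch size 1 with norms taken over the whole tensor, as in the loop.

```python
import math

import torch


def dsg_step(prev_mean, sigma, grad, guidance_rate, eps=1e-20, generator=None):
    """Spherical-Gaussian-constrained update, factored out of the loop above.

    prev_mean: posterior mean for x_{t-1}, shape (1, C, H, W) (batch size 1 assumed)
    sigma: posterior standard deviation sigma_t (float or 0-dim tensor)
    grad: gradient of the guidance loss w.r.t. x_t, same shape as prev_mean
    guidance_rate: weight g in [0, 1] (guidance_scale in the loop above)
    """
    _, ch, h, w = prev_mean.shape
    # Radius of the shell where the noise of N(mean, sigma^2 I) concentrates
    r = math.sqrt(ch * h * w) * sigma
    # Steepest-descent direction of the loss, scaled to length r
    d_star = -r * grad / (grad.norm() + eps)
    # Unguided stochastic sampling direction
    noise = torch.randn(prev_mean.shape, generator=generator,
                        dtype=prev_mean.dtype, device=prev_mean.device)
    d_sample = sigma * noise
    # Interpolate between the two directions, then project back onto the shell
    mix = d_sample + guidance_rate * (d_star - d_sample)
    return prev_mean + r * mix / (mix.norm() + eps)


gen = torch.Generator().manual_seed(0)
mean = torch.zeros(1, 4, 8, 8)
grad = torch.randn(1, 4, 8, 8, generator=gen)
x_prev = dsg_step(mean, 0.1, grad, guidance_rate=0.1, generator=gen)
print((x_prev - mean).norm())  # ~1.6, i.e. sqrt(4 * 8 * 8) * 0.1
```

With `guidance_rate=0` the step is ordinary sampling rescaled onto the shell; with `guidance_rate=1` it moves a distance $r$ straight along the negative gradient.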
2019211753 commented 2 months ago

Thank you very much for your assistance! You saved my life! @LingxiaoYang2023

2019211753 commented 2 months ago

I tried time travel, but the result still seems unsatisfactory:

for t in tqdm(scheduler.timesteps):
    latents.requires_grad_(True)
    latent_model_input = scheduler.scale_model_input(latents, timestep=t)

    noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
    if 400 <= t < 700:
        repeat = 2
        for i in range(repeat + 1):
            output = scheduler.full_output(noise_pred, t, latents, eta=1)
            prev_mean = output['prev_mean']
            variance = output['variance']
            z0 = output['z0']
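            # full_output() does not return 'beta_t', so compute the effective beta
            # of this DDIM stride from the cumulative alphas: 1 - a_t / a_{t_prev}
            prev_t = t - scheduler.config.num_train_timesteps // scheduler.num_inference_steps
            alpha_prod_prev = scheduler.alphas_cumprod[prev_t] if prev_t >= 0 else scheduler.final_alpha_cumprod
            beta_t = 1 - scheduler.alphas_cumprod[t] / alpha_prod_prev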
            loss = clip_loss(z0, text_features)
            loss = torch.linalg.norm(loss)
            grad = torch.autograd.grad(loss, latents)[0]
            grad_norm = torch.norm(grad)
            grad2 = grad / (grad_norm + eps)
            batch, ch, h, w = prev_mean.shape
            import math

            r = math.sqrt(ch * h * w) * variance
            d_star = -r * grad2
            noise = torch.randn(prev_mean.shape, device=prev_mean.device, dtype=prev_mean.dtype).to(prev_mean.device)
            d_sample = variance * noise
            mix_direction = d_sample + guidance_scale * (d_star - d_sample)
            mix_direction_norm = torch.norm(mix_direction)
            latents = prev_mean + mix_direction / (mix_direction_norm + eps) * r
            if i < repeat:
                noise2 = torch.randn(prev_mean.shape, device=prev_mean.device, dtype=prev_mean.dtype).to(
                    prev_mean.device)
                latents = (1 - beta_t) ** 0.5 * latents + beta_t ** 0.5 * noise2
    else:
        output = scheduler.full_output(noise_pred, t, latents, eta=1)
        prev_mean = output['prev_mean']
        variance = output['variance']
        noise = torch.randn(prev_mean.shape, device=prev_mean.device, dtype=prev_mean.dtype).to(prev_mean.device)
        latents = prev_mean + noise * variance
    latents = latents.detach()

Here is the result: [image]

I am wondering whether there is an issue with my code; maybe I should try an unconditional model. Additionally, according to the original DDIM paper, shouldn't the conversion from $x_{t-1}$ back to $x_t$ use $q(x_t \mid x_{t-1}, x_0)$ when sampling with DDIM? [image]

Thanks again in advance for your reply.
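For the question about $q(x_t \mid x_{t-1}, x_0)$: the loop above samples with `eta=1`, and in that case (as the DDIM paper notes) the DDIM sampler coincides with DDPM and the implied forward process is Markovian, so $q(x_t \mid x_{t-1}, x_0)$ reduces to $q(x_t \mid x_{t-1}) = \mathcal{N}\big(\sqrt{\bar\alpha_t/\bar\alpha_{t-1}}\,x_{t-1},\ (1-\bar\alpha_t/\bar\alpha_{t-1})\,I\big)$, where $t-1$ denotes the previous DDIM timestep (5 training steps earlier with 200 inference steps). Separately, in the loop above `noise_pred` is computed once per `t`, outside the repeat loop, so after re-noising the U-Net is not re-evaluated at the new $x_t$. The sketch below shows one time-travel step that follows both points. It is a minimal illustration under stated assumptions, not the DSG or FreeDoM implementation: the linear schedule, the conv "denoiser", the loss, and the helper names `guided_step`, `renoise`, and `time_travel_step` are all hypothetical stand-ins.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-ins: a linear beta schedule over 1000 steps, a DDIM stride
# of 5 (1000 / 200 inference steps, as in the thread), a toy conv "noise
# predictor" in place of the U-Net, and a toy guidance loss.
T, stride = 1000, 5
betas = torch.linspace(1e-4, 0.02, T)
alphas_cumprod = torch.cumprod(1 - betas, dim=0)
denoiser = torch.nn.Conv2d(4, 4, 3, padding=1)


def loss_fn(x0_hat):
    return (x0_hat ** 2).mean()


def guided_step(x_t, t, guidance_rate=0.1, eps=1e-20):
    """One eta=1 DDIM step from t to t - stride with the spherical update."""
    a_t, a_prev = alphas_cumprod[t], alphas_cumprod[t - stride]
    x_t = x_t.detach().requires_grad_(True)
    noise_pred = denoiser(x_t)  # re-evaluated on the current x_t at every repeat
    x0_hat = (x_t - (1 - a_t).sqrt() * noise_pred) / a_t.sqrt()
    grad = torch.autograd.grad(loss_fn(x0_hat), x_t)[0]
    # eta = 1: sigma_t^2 = (1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev)
    sigma = ((1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev)).sqrt()
    mean = a_prev.sqrt() * x0_hat + (1 - a_prev - sigma ** 2).sqrt() * noise_pred
    r = x_t[0].numel() ** 0.5 * sigma
    d_star = -r * grad / (grad.norm() + eps)
    d_sample = sigma * torch.randn_like(mean)
    mix = d_sample + guidance_rate * (d_star - d_sample)
    return (mean + r * mix / (mix.norm() + eps)).detach()


def renoise(x_prev, t):
    """Forward step from t - stride back to t: q(x_t | x_{t - stride})."""
    beta = 1 - alphas_cumprod[t] / alphas_cumprod[t - stride]  # effective beta of the stride
    return (1 - beta).sqrt() * x_prev + beta.sqrt() * torch.randn_like(x_prev)


def time_travel_step(x_t, t, repeat=2):
    for i in range(repeat + 1):
        x_prev = guided_step(x_t, t)
        if i < repeat:
            x_t = renoise(x_prev, t)  # go back to step t and redo the guided step
    return x_prev


x = torch.randn(1, 4, 8, 8)
x_prev = time_travel_step(x, t=500)
print(x_prev.shape)  # torch.Size([1, 4, 8, 8])
```

Using the stride's effective beta, rather than the single-step `betas[t]`, keeps $\bar\alpha_{t-1}(1-\beta) = \bar\alpha_t$, so the re-noised sample has the marginal the sampler expects at step $t$.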