2019211753 opened this issue 2 months ago
Is it possible to guide Stable Diffusion using only CLIP guidance with a given text?
Thanks for your interest in our paper!
Although we have demonstrated the superiority of DSG on many tasks, we have not yet experimented with additional loss functions such as a CLIP loss, so I can only offer some potential solutions for the code you provided.
Below is reference code (based on your code) that addresses problems 1 and 3. It might be a helpful starting point, although the result is still not good enough, since it does not yet use time travel.
Prompt: "a photograph of a dog" Result: Code:
from PIL import Image
import torch
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel
from typing import Optional
from diffusers import DDIMScheduler
from diffusers.utils.torch_utils import randn_tensor
import torchvision.transforms as transforms
import torch.nn.functional as F
import open_clip
import torchvision
torch_device = "cuda:0"
def downscale(img):
    transform = transforms.ToTensor()
    img = transform(img).half().to(torch_device)
    img = img.unsqueeze(0)
    img = F.interpolate(
        img, (512, 512), mode="bilinear", align_corners=False
    )
    return img
def encode_image(img):
    img = img * 2.0 - 1.0
    posterior = vae.encode(img).latent_dist
    latent = posterior.sample() * vae.config.scaling_factor
    return latent
def decode_latent(latent):
    latent = 1 / vae.config.scaling_factor * latent
    image = vae.decode(latent).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    return image
def upscale(image):
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    image = (image * 255).round().astype("uint8")
    return image
class DDIMScheduler_with_more_output(DDIMScheduler):
    def full_output(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        eta: float = 0.0,
        use_clipped_model_output: bool = False,
        generator=None,
        variance_noise: Optional[torch.FloatTensor] = None,
        return_dict: bool = True,
    ):
        # 1. get previous step value (=t-1)
        prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
        # 2. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        # 3. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        if self.config.prediction_type == "epsilon":
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
            pred_epsilon = model_output
        elif self.config.prediction_type == "sample":
            pred_original_sample = model_output
            pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
        elif self.config.prediction_type == "v_prediction":
            pred_original_sample = (alpha_prod_t ** 0.5) * sample - (beta_prod_t ** 0.5) * model_output
            pred_epsilon = (alpha_prod_t ** 0.5) * model_output + (beta_prod_t ** 0.5) * sample
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction`"
            )
        # 4. Clip or threshold "predicted x_0"
        if self.config.thresholding:
            pred_original_sample = self._threshold_sample(pred_original_sample)
        elif self.config.clip_sample:
            pred_original_sample = pred_original_sample.clamp(
                -self.config.clip_sample_range, self.config.clip_sample_range
            )
        # 5. compute variance: "sigma_t(η)" -> see formula (16)
        # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
        variance = self._get_variance(timestep, prev_timestep)
        std_dev_t = eta * variance ** (0.5)
        if use_clipped_model_output:
            # the pred_epsilon is always re-derived from the clipped x_0 in Glide
            pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
        # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t ** 2) ** (0.5) * pred_epsilon
        # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        prev_mean = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
        if eta > 0:
            if variance_noise is not None and generator is not None:
                raise ValueError(
                    "Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
                    " `variance_noise` stays `None`."
                )
            if variance_noise is None:
                variance_noise = randn_tensor(
                    model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
                )
            variance = std_dev_t * variance_noise
            prev_sample = prev_mean + variance
        else:
            prev_sample = prev_mean
        # also return beta_t so the time-travel variant further below can re-noise x_{t-1} back to x_t
        return {'prev_mean': prev_mean, 'variance': std_dev_t, 'z0': pred_original_sample, 'beta_t': self.betas[timestep]}
tfms = torchvision.transforms.Compose(
    [
        torchvision.transforms.RandomResizedCrop(224),  # random resized crop
        torchvision.transforms.RandomAffine(5),  # random affine distortion
        torchvision.transforms.RandomHorizontalFlip(),  # random horizontal flip
    ]
)
vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="vae", use_safetensors=True)
tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(
    "stabilityai/stable-diffusion-2-1-base", subfolder="text_encoder", use_safetensors=True
)
unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-2-1-base", subfolder="unet", use_safetensors=True
)
scheduler = DDIMScheduler_with_more_output.from_pretrained(
    "stabilityai/stable-diffusion-2-1-base", subfolder="scheduler"
)
vae.to(torch_device)
text_encoder.to(torch_device)
unet.to(torch_device)
height = 512 # default height of Stable Diffusion
width = 512 # default width of Stable Diffusion
num_inference_steps = 200 # Number of denoising steps
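# the UNet is conditioned only on the empty prompt; all text guidance comes from the CLIP loss defined below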
uncond_input = tokenizer([""] * 1, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt")
uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
text_embeddings = torch.cat([uncond_embeddings])
latents = torch.randn(
    (1, unet.config.in_channels, height // 8, width // 8),
    device=torch_device,
)
from tqdm.auto import tqdm
prompt = "a photograph of a dog"
clip_model, _, preprocess = open_clip.create_model_and_transforms('ViT-g-14', pretrained='laion2b_s34b_b88k')
clip_model.to(torch_device)
def clip_loss(x, text_features):
    # decode the predicted clean latent to an image and embed it with CLIP
    image = decode_latent(x)
    image_features = clip_model.encode_image(tfms(image))
    # squared great-circle (spherical) distance between the normalized image and text embeddings
    input_normed = torch.nn.functional.normalize(image_features.unsqueeze(1), dim=2)
    embed_normed = torch.nn.functional.normalize(text_features.unsqueeze(0), dim=2)
    dists = input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)
    return dists.mean()
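# guidance_scale interpolates between the unguided noise direction d_sample and the DSG guidance direction d_star in the loop below (0 = no guidance, 1 = full guidance direction)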
guidance_scale = 0.1
eps = 1e-20
text = open_clip.tokenize([prompt]).to(torch_device)
with torch.no_grad(), torch.cuda.amp.autocast():
    text_features = clip_model.encode_text(text)
scheduler.set_timesteps(num_inference_steps)
latents = latents * scheduler.init_noise_sigma
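# CLIP-guided sampling loop: each DDIM step replaces the random noise term with a mix of noise and the normalized CLIP-loss gradient, projected onto a sphere of radius r around the DDIM mean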
for t in tqdm(scheduler.timesteps):
    latents.requires_grad_(True)
    latent_model_input = scheduler.scale_model_input(latents, timestep=t)
    noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
    # if not t <= 100:
    if True:
        output = scheduler.full_output(noise_pred, t, latents, eta=1)
        prev_mean = output['prev_mean']
        variance = output['variance']
        z0 = output['z0']
        # CLIP loss on the predicted clean latent; its gradient w.r.t. the noisy latents gives the guidance direction
        loss = clip_loss(z0, text_features)
        loss = torch.linalg.norm(loss)
        grad = torch.autograd.grad(loss, latents)[0]
        grad_norm = torch.norm(grad)
        grad2 = grad / (grad_norm + eps)
        batch, ch, h, w = prev_mean.shape
        import math
        # DSG: samples from N(prev_mean, sigma_t^2 I) concentrate near a sphere of radius sigma_t * sqrt(n)
        r = math.sqrt(ch * h * w) * variance
        d_star = -r * grad2
        noise = torch.randn(prev_mean.shape, device=prev_mean.device, dtype=prev_mean.dtype).to(prev_mean.device)
        d_sample = variance * noise
        # interpolate between the unguided noise direction and the guidance direction, then project back onto the sphere
        mix_direction = d_sample + guidance_scale * (d_star - d_sample)
        mix_direction_norm = torch.norm(mix_direction)
        latents = prev_mean + mix_direction / (mix_direction_norm + eps) * r
    else:
        output = scheduler.full_output(noise_pred, t, latents, eta=1)
        prev_mean = output['prev_mean']
        variance = output['variance']
        noise = torch.randn(prev_mean.shape, device=prev_mean.device, dtype=prev_mean.dtype).to(prev_mean.device)
        latents = prev_mean + noise * variance
    latents = latents.detach()
# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
with torch.no_grad():
    image = vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1).squeeze()
image = (image.permute(1, 2, 0) * 255).to(torch.uint8).cpu().numpy()
image = Image.fromarray(image)
image.save('test.png')
Thank you very much for your assistance! You saved my life! @LingxiaoYang2023
I tried time travel, but the result still seems unsatisfactory:
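# sampling loop with time travel: for 400 <= t < 700, each guided update is repeated `repeat` extra times by re-noising x_{t-1} back toward x_t with one forward diffusion step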
for t in tqdm(scheduler.timesteps):
    latents.requires_grad_(True)
    latent_model_input = scheduler.scale_model_input(latents, timestep=t)
    noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
    if 400 <= t < 700:
        repeat = 2
        # NOTE: noise_pred is computed once per outer timestep and reused for every repeat
        for i in range(repeat + 1):
            output = scheduler.full_output(noise_pred, t, latents, eta=1)
            prev_mean = output['prev_mean']
            variance = output['variance']
            z0 = output['z0']
            beta_t = output['beta_t']  # beta_t is returned by full_output above
            loss = clip_loss(z0, text_features)
            loss = torch.linalg.norm(loss)
            grad = torch.autograd.grad(loss, latents)[0]
            grad_norm = torch.norm(grad)
            grad2 = grad / (grad_norm + eps)
            batch, ch, h, w = prev_mean.shape
            import math
            r = math.sqrt(ch * h * w) * variance
            d_star = -r * grad2
            noise = torch.randn(prev_mean.shape, device=prev_mean.device, dtype=prev_mean.dtype).to(prev_mean.device)
            d_sample = variance * noise
            mix_direction = d_sample + guidance_scale * (d_star - d_sample)
            mix_direction_norm = torch.norm(mix_direction)
            latents = prev_mean + mix_direction / (mix_direction_norm + eps) * r
            if i < repeat:
                # time travel: re-noise x_{t-1} back with one forward diffusion step and redo the guided update
                noise2 = torch.randn(prev_mean.shape, device=prev_mean.device, dtype=prev_mean.dtype).to(prev_mean.device)
                latents = (1 - beta_t) ** 0.5 * latents + beta_t ** 0.5 * noise2
    else:
        output = scheduler.full_output(noise_pred, t, latents, eta=1)
        prev_mean = output['prev_mean']
        variance = output['variance']
        noise = torch.randn(prev_mean.shape, device=prev_mean.device, dtype=prev_mean.dtype).to(prev_mean.device)
        latents = prev_mean + noise * variance
    latents = latents.detach()
Here is the result: I am wondering whether there is an issue with my code; maybe I should try an unconditional model. Also, according to the original DDIM paper, shouldn't the conversion from $x_{t-1}$ to $x_t$ use $q(x_t \mid x_{t-1}, x_0)$?
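For concreteness (this is just my understanding, so please correct me if I am wrong): the re-noising step in my code uses the DDPM forward kernel

$$q(x_t \mid x_{t-1}) = \mathcal{N}\big(x_t;\ \sqrt{1-\beta_t}\,x_{t-1},\ \beta_t \mathbf{I}\big),$$

while the DDIM paper's non-Markovian forward process gives the transition by Bayes' rule,

$$q_\sigma(x_t \mid x_{t-1}, x_0) = \frac{q_\sigma(x_{t-1} \mid x_t, x_0)\, q_\sigma(x_t \mid x_0)}{q_\sigma(x_{t-1} \mid x_0)},$$

which also depends on $x_0$.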
Thank you again in advance for your reply.
I tried to use CLIP text guidance instead of CFG in SD_style, but the result does not seem satisfactory. Could you please help me find what is going wrong? Here is the code; I only changed three lines in 'SD_style/ldm/models/diffusion/ddim.py', at lines 232, 262, and 264.
The reference image is jojo.jpeg and the output is: