Closed · 0x1355 closed this 11 months ago
Hi, thanks for your interest.
Yeah, it would be cool to apply PEZ to SDXL. I think the straightforward way is to optimize a separate prompt for each text encoder and feed each prompt to its corresponding encoder. This might require some small modifications to the diffusers pipeline here: https://github.com/huggingface/diffusers/blob/b9feed87958c27074b0618cc543696c05f58e2c9/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L316. Instead of having one universal prompt, we can pass a list of prompts. I don't have the SDXL model weights for now, but I think they will be public this month, so as soon as I have the weights, I will play around with it.
For now, maybe the simplest way is just to use the prompt optimized only with ViT-bigG, since ViT-bigG is the main text encoder for SDXL (I believe so), so it should kind of work. To do so, you can simply change args.clip_model = "ViT-bigG-14" and args.clip_pretrain = "laion2b_s39b_b160k".
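Once the weights are out, feeding the two separately optimized prompts should look roughly like this. This is just an untested sketch, not part of the PEZ code; it assumes a recent diffusers StableDiffusionXLPipeline, where prompt is routed to the ViT-L text encoder and prompt_2 to the ViT-bigG one, and the two prompt strings are placeholders:

import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16"
).to("cuda")

# placeholders for the two prompts recovered by PEZ, one per text encoder
prompt_vit_l = "<prompt optimized against ViT-L>"
prompt_vit_bigg = "<prompt optimized against ViT-bigG>"

# prompt goes to the ViT-L encoder, prompt_2 to the ViT-bigG encoder
image = pipe(prompt=prompt_vit_l, prompt_2=prompt_vit_bigg).images[0]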
Gotcha. Will try it out :sunglasses:
Any updates?
Hi, I've given it a try, but it seems it's not working as expected: it's not learning. I'm doing my best, but I'm new to Python and Torch, so there might be something I'm overlooking in my code (maybe even some unforgivable mistakes :-) ). I tried to use the "ViT-bigG-14" CLIP model, but it's too big for my 4090 24GB. I used the latest diffusers==0.20.0. Here is my prompt inversion script:
import open_clip
import torch
from torchvision import transforms
import argparse
import datetime
import os
import copy
from transformers.optimization import Adafactor, AdafactorSchedule
from optim_utils import *
from diffusers import DDPMScheduler, DPMSolverMultistepScheduler
from modified_stable_diffusion_xl_pipeline import ModifiedStableDiffusionPipelineXL
args = argparse.Namespace()
args.iter = 1000
args.prompt_len = 8
args.lr = 0.1
args.weight_decay = 0.1
args.opt_iters = 3000
args.eval_step = 50
args.prompt_bs = 1
args.loss_weight = 1.0
args.print_step = 100
args.batch_size = 1
# args.clip_model = "ViT-bigG-14"
# args.clip_pretrain = "laion2b_s39b_b160k"
args.clip_model = "ViT-H-14"
args.clip_pretrain = "laion2b_s32b_b79k"
best_loss = -999
eval_loss = -99999
best_text = ""
weight_dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"
def initialize_prompt(tokenizers_list, token_embeddings_list, args, device):
    prompt_len = args.prompt_len
    # randomly optimize prompt embeddings
    prompt_embeds_list = []
    dummy_embeds_list = []
    dummy_ids_list = []
    prompt_ids = torch.randint(len(tokenizers_list[0].encoder), (args.prompt_bs, prompt_len)).to(device)
    for tokenizer, token_embeddings in zip(tokenizers_list, token_embeddings_list):
        prompt_embeds = token_embeddings(prompt_ids).detach()
        prompt_embeds.requires_grad = True
        # initialize the template
        # -1 for optimized tokens
        dummy_ids = [tokenizer.bos_token_id] + [-1] * prompt_len + [tokenizer.eos_token_id] + [0] * (75 - prompt_len)
        dummy_ids = torch.tensor([dummy_ids] * args.prompt_bs).to(device)
        # for getting dummy embeds; -1 won't work for token_embedding
        tmp_dummy_ids = [tokenizer.bos_token_id] + [0] * prompt_len + [tokenizer.eos_token_id] + [0] * (75 - prompt_len)
        tmp_dummy_ids = torch.tensor([tmp_dummy_ids] * args.prompt_bs).to(device)
        dummy_embeds = token_embeddings(tmp_dummy_ids).detach()
        dummy_embeds.requires_grad = False
        prompt_embeds_list.append(prompt_embeds)
        dummy_embeds_list.append(dummy_embeds)
        dummy_ids_list.append(dummy_ids)
    return prompt_embeds_list, dummy_embeds_list, dummy_ids_list
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
scheduler = DPMSolverMultistepScheduler.from_pretrained(model_id, subfolder="scheduler")
# scheduler = DDPMScheduler(
# beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
# )
pipe = ModifiedStableDiffusionPipelineXL.from_pretrained(
    model_id,
    scheduler=scheduler,
    torch_dtype=weight_dtype,
    variant="fp16",
    use_safetensors=True
)
pipe = pipe.to(device)
pipe.vae.requires_grad_(False)
pipe.vae.eval()
pipe.unet.requires_grad_(True)
pipe.unet.train()
clip_model, _, clip_preprocess = open_clip.create_model_and_transforms(args.clip_model, pretrained=args.clip_pretrain, device=device)
image_length = 1024
tokenizers_list = [pipe.tokenizer, pipe.tokenizer_2] if pipe.tokenizer is not None else [pipe.tokenizer_2]
token_embeddings_list =[pipe.text_encoder.text_model.embeddings.token_embedding, pipe.text_encoder_2.text_model.embeddings.token_embedding]
preprocess = transforms.Compose(
    [
        transforms.Resize(1024, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(1024),
        transforms.ToTensor(),
    ]
)
urls = [
    "https://www.parkwestgallery.com/wp-content/uploads/2017/10/im811243-e1507918728745.jpg",
]
orig_images = list(filter(None,[download_image(url) for url in urls]))
SDXL_VAE_SCALE_FACTOR = 0.13025
with torch.no_grad():
    curr_images = [preprocess(i).unsqueeze(0) for i in orig_images]
    curr_images = torch.concatenate(curr_images).to(device)
    all_latents = pipe.vae.encode(curr_images.to(weight_dtype)).latent_dist.sample()
    all_latents = all_latents * SDXL_VAE_SCALE_FACTOR
#initialize random prompt
prompt_embeds_list, dummy_embeds_list, dummy_ids_list = initialize_prompt(tokenizers_list, token_embeddings_list, args, device)
# input_optimizer = Adafactor(prompt_embeds_list, scale_parameter=False, relative_step=False, warmup_init=False, lr=0.2)
input_optimizer = torch.optim.AdamW(prompt_embeds_list, lr=args.lr, weight_decay=args.weight_decay)
input_optim_scheduler = None
for step in range(args.opt_iters):
    padded_embeds_list = []
    padded_dummy_ids_list = []
    tmp_embeds_list = []
    nn_indices_list = []
    # forward projection (top1 semantic_search(prompt_embeds, token_embedding))
    for prompt_embeds, dummy_embeds, dummy_ids, tokenizer, token_embeddings in zip(prompt_embeds_list, dummy_embeds_list, dummy_ids_list, tokenizers_list, token_embeddings_list):
        projected_embeds, nn_indices = nn_project(prompt_embeds, token_embeddings)
        tmp_embeds = copy.deepcopy(prompt_embeds)
        tmp_embeds.data = projected_embeds.data
        tmp_embeds.requires_grad = True
        # padding and repeat
        padded_embeds = copy.deepcopy(dummy_embeds)
        padded_embeds[:, 1:args.prompt_len+1] = tmp_embeds
        padded_embeds = padded_embeds.repeat(args.batch_size, 1, 1)
        padded_dummy_ids = dummy_ids.repeat(args.batch_size, 1)
        nn_indices_list.append(nn_indices)
        padded_embeds_list.append(padded_embeds)
        padded_dummy_ids_list.append(padded_dummy_ids)
        tmp_embeds_list.append(tmp_embeds)
    # randomly sample images and get features
    if args.batch_size is None:
        latents = all_latents
    else:
        perm = torch.randperm(len(all_latents))
        idx = perm[:args.batch_size]
        latents = all_latents[idx]
    # Sample noise that we'll add to the latents
    noise = torch.randn_like(latents)
    bsz = latents.shape[0]
    # Sample a random timestep for each image
    timesteps = torch.randint(0, 1000, (bsz,), device=latents.device)
    timesteps = timesteps.long()
    # Add noise to the latents according to the noise magnitude at each timestep
    noisy_latents = pipe.scheduler.add_noise(latents, noise, timesteps)
    # Get the target for loss depending on the prediction type
    if pipe.scheduler.config.prediction_type == "epsilon":
        target = noise
    elif pipe.scheduler.config.prediction_type == "v_prediction":
        target = pipe.scheduler.get_velocity(latents, noise, timesteps)
    else:
        raise ValueError(f"Unknown prediction type {pipe.scheduler.config.prediction_type}")
    # get text embeddings
    text_embeddings, pooled_prompt_embeds = pipe._get_text_embedding_with_embeddings(padded_dummy_ids_list, padded_embeds_list)
    add_time_ids = pipe._get_add_time_ids(
        (image_length, image_length), (0, 0), (image_length, image_length), dtype=text_embeddings.dtype
    ).to(device)
    add_text_embeds = pooled_prompt_embeds
    # Predict the noise residual and compute loss
    model_pred = pipe.unet(noisy_latents, timesteps, encoder_hidden_states=text_embeddings, added_cond_kwargs={"text_embeds": add_text_embeds, "time_ids": add_time_ids}).sample
    loss = torch.nn.functional.mse_loss(model_pred.float(), target.float(), reduction="mean")
    # copy the gradients of the projected embeddings back onto the continuous prompt
    # embeddings that the optimizer updates (overwriting prompt_embeds_list with the
    # gradients would break the optimization, which is likely why nothing was learning)
    grads = torch.autograd.grad(loss, tmp_embeds_list)
    for prompt_embeds, grad in zip(prompt_embeds_list, grads):
        prompt_embeds.grad = grad
    input_optimizer.step()
    input_optimizer.zero_grad()
    curr_lr = input_optimizer.param_groups[0]["lr"]
    ### eval
    if step % args.eval_step == 0:
        prompt_1 = decode_ids(nn_indices_list[0], tokenizers_list[0])[0]
        prompt_2 = decode_ids(nn_indices_list[1], tokenizers_list[1])[0]
        print(f"step: {step}, lr: {curr_lr}, cosim: {eval_loss:.3f}, best_cosim: {best_loss:.3f}, best prompt: {best_text}")
        with torch.no_grad():
            pred_imgs = pipe(
                prompt_1,
                prompt_2,
                num_images_per_prompt=4,
                guidance_scale=9,
                num_inference_steps=50,
                height=image_length,
                width=image_length,
                output_type='pil'
            ).images
            eval_loss = measure_similarity(orig_images, pred_imgs, clip_model, clip_preprocess, device)
        if best_loss < eval_loss:
            best_loss = eval_loss
            best_text = f'{prompt_1} {prompt_2}'
print()
print(f"Best shot: consine similarity: {best_loss:.3f}")
print(f"text: {best_text}")
# you can customize the learned prompt here
prompt = best_text
num_images = 4
guidance_scale = 9
num_inference_steps = 25
images = pipe(
    prompt,
    num_images_per_prompt=num_images,
    guidance_scale=guidance_scale,
    num_inference_steps=num_inference_steps,
    height=image_length,
    width=image_length,
).images
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
for i, img in enumerate(images):
    img.save(os.path.join('output/', f"sd2_result_{timestamp}_{i:03d}.png"))
print("Save images.")
Here is the modified pipeline:
from typing import Callable, List, Optional, Union
import torch
from diffusers import StableDiffusionXLPipeline
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from transformers.models.clip.modeling_clip import CLIPTextModelOutput
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
from diffusers.utils import logging
from transformers.modeling_outputs import BaseModelOutputWithPooling
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class ModifiedStableDiffusionPipelineXL(StableDiffusionXLPipeline):
    def __init__(self,
                 vae,
                 text_encoder: CLIPTextModel,
                 text_encoder_2: CLIPTextModelWithProjection,
                 tokenizer: CLIPTokenizer,
                 tokenizer_2: CLIPTokenizer,
                 unet,
                 scheduler,
                 force_zeros_for_empty_prompt: bool = True,
                 add_watermarker: Optional[bool] = None
                 ):
        super(ModifiedStableDiffusionPipelineXL, self).__init__(vae,
                                                                text_encoder,
                                                                text_encoder_2,
                                                                tokenizer,
                                                                tokenizer_2,
                                                                unet,
                                                                scheduler,
                                                                force_zeros_for_empty_prompt,
                                                                add_watermarker)
    def _build_causal_attention_mask(self, bsz, seq_len, dtype):
        # lazily create causal attention mask, with full attention between the tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
        mask.fill_(torch.tensor(torch.finfo(dtype).min))
        mask.triu_(1)  # zero out the lower diagonal
        mask = mask.unsqueeze(1)  # expand mask
        return mask
    def _encode_embeddings(self, text_encoder, input_ids, prompt_embeddings, attention_mask=None):
        output_attentions = text_encoder.text_model.config.output_attentions
        output_hidden_states = True
        return_dict = text_encoder.text_model.config.use_return_dict
        hidden_states = text_encoder.text_model.embeddings(inputs_embeds=prompt_embeddings)
        bsz, seq_len = input_ids.shape[0], input_ids.shape[1]
        # CLIP's text model uses causal mask, prepare it here.
        # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
        causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(
            hidden_states.device
        )
        # expand attention_mask
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            attention_mask = text_encoder.text_model._expand_mask(attention_mask, hidden_states.dtype)
        encoder_outputs = text_encoder.text_model.encoder(
            inputs_embeds=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        last_hidden_state = encoder_outputs[0]
        last_hidden_state = text_encoder.text_model.final_layer_norm(last_hidden_state)
        # text_embeds.shape = [batch_size, sequence_length, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
        pooled_output = last_hidden_state[
            torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
            input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
        ]
        text_outputs = BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
        if isinstance(text_encoder, CLIPTextModelWithProjection):
            pooled_output = text_outputs[1]
            text_embeds = text_encoder.text_projection(pooled_output)
            if not return_dict:
                outputs = (text_embeds, text_outputs[0]) + text_outputs[2:]
                return tuple(output for output in outputs if output is not None)
            return CLIPTextModelOutput(
                text_embeds=text_embeds,
                last_hidden_state=text_outputs.last_hidden_state,
                hidden_states=text_outputs.hidden_states,
                attentions=text_outputs.attentions,
            )
        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]
        return text_outputs
    def _get_text_embedding_with_embeddings(self, text_input_ids_list, prompt_embeddings_list):
        text_encoders_list = (
            [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
        )
        prompt_embeds_list = []
        for text_input_ids, prompt_embeddings, text_encoder in zip(text_input_ids_list, prompt_embeddings_list, text_encoders_list):
            text_embeddings = self._encode_embeddings(
                text_encoder,
                text_input_ids,
                prompt_embeddings
            )
            # We are only ALWAYS interested in the pooled output of the final text encoder
            pooled_prompt_embeds = text_embeddings[0]
            text_embeddings = text_embeddings.hidden_states[-2]
            prompt_embeds_list.append(text_embeddings)
        prompt_embeds = torch.cat(prompt_embeds_list, dim=-1)
        return prompt_embeds, pooled_prompt_embeds
Hi manzonif, thank you for sharing the details. I have been busy with a conference deadline recently, but I will try my best to test it either this month or the next. I appreciate your understanding and patience.
To delve a bit deeper into the conceptual framework I had in mind earlier, there are two ways I am considering: (1) optimize an independent prompt for each of the two text encoders and feed each prompt to its corresponding encoder, or (2) optimize a single universal prompt with an ensemble of the two text encoders.
Certainly! I wish you a good conference.
Hi @manzonif, sorry about the late response. Not sure if you have made any progress on this, but I recently tried to optimize two independent prompts for the two text encoders. However, it doesn't work very well. I am going to double-check the code and also see whether optimizing a universal prompt with an ensemble of the two text encoders works.
Thanks for your patience!
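For the ensemble variant, here is a rough, untested sketch of what I have in mind, reusing pipe, nn_project, copy, args, and device from the script above. Since both SDXL tokenizers share the same CLIP BPE vocabulary, a single continuous prompt can be projected to hard token ids once, and those same ids can then be embedded with each encoder's token table; in this simplified version the update signal for the continuous prompt only flows through the ViT-bigG branch:

# one continuous prompt, kept in the ViT-bigG embedding space
token_emb_l = pipe.text_encoder.text_model.embeddings.token_embedding    # ViT-L table
token_emb_g = pipe.text_encoder_2.text_model.embeddings.token_embedding  # ViT-bigG table
prompt_ids = torch.randint(len(pipe.tokenizer_2.encoder), (1, args.prompt_len), device=device)
prompt_embeds = token_emb_g(prompt_ids).detach()
prompt_embeds.requires_grad = True

# inside the optimization loop: project to hard token ids once, then feed the SAME
# ids to both encoders
projected_embeds_g, nn_indices = nn_project(prompt_embeds, token_emb_g)
embeds_g = copy.deepcopy(prompt_embeds)
embeds_g.data = projected_embeds_g.data      # straight-through copy, as in PEZ
embeds_l = token_emb_l(nn_indices).detach()  # ViT-L branch just re-embeds the shared ids
# ... pad both, run pipe._get_text_embedding_with_embeddings, compute the diffusion loss,
# and copy torch.autograd.grad(loss, [embeds_g])[0] into prompt_embeds.grad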
Hello @YuxinWenRick, your paper and repo really helped improve my workflow. Thank you!
Meanwhile, I am wondering if I can apply this approach to SDXL. It uses two text encoders (ViT-bigG and ViT-L), and I found both in the official open_clip repo, but I am not sure how to combine them the way the diffusers inference pipeline does.
Can you point me in the right direction? Thanks.
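For context, this is how I'm loading the two text towers from open_clip at the moment. Just a sketch of what I mean by "found both in the official open_clip repo"; the pretrained tags are the ones I believe correspond to SDXL's encoders (OpenAI ViT-L/14 and LAION ViT-bigG/14):

import open_clip

# the two CLIP models whose text towers (I believe) match SDXL's encoders
clip_l, _, preprocess_l = open_clip.create_model_and_transforms(
    "ViT-L-14", pretrained="openai"
)
clip_bigg, _, preprocess_bigg = open_clip.create_model_and_transforms(
    "ViT-bigG-14", pretrained="laion2b_s39b_b160k"
)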