YuxinWenRick / hard-prompts-made-easy

MIT License
579 stars, 54 forks

For SD XL? #21

Closed: 0x1355 closed this issue 11 months ago

0x1355 commented 11 months ago

Hello @YuxinWenRick , your paper and repo really helped improve my workflow. Thank you!

Meanwhile, I am wondering if I can apply this approach to SD-XL. It uses two text encoders (ViT-bigG and ViT-L). I found both in the official open_clip repo, but I am not sure how to combine them the way the diffusers inference pipeline does.

Can you point me in the right direction? Thanks.

YuxinWenRick commented 11 months ago

Hi, thanks for your interest.

Yeah, it would be cool to apply PEZ to SDXL. I think the straightforward way is to optimize a separate prompt for each text encoder and feed each prompt to its corresponding text encoder. This might require some small modifications to the diffusers pipeline here: https://github.com/huggingface/diffusers/blob/b9feed87958c27074b0618cc543696c05f58e2c9/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py#L316. Instead of having one universal prompt, we can pass a list of prompts. I don't have the SDXL model weights for now, but I think they will be public this month, so as soon as I have them, I will play around with it.
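
For illustration, here is a minimal sketch of how two per-encoder prompts could be fed at inference time, assuming a recent diffusers release (such as the 0.20.x used later in this thread) whose stock StableDiffusionXLPipeline exposes a separate prompt_2 argument routed to the second text encoder; prompt_for_vit_l and prompt_for_vit_bigg are hypothetical placeholder names, not variables from the repo:

import torch
from diffusers import StableDiffusionXLPipeline

# Hypothetical placeholders for the two independently optimized hard prompts.
prompt_for_vit_l = "<prompt optimized against the ViT-L text encoder>"
prompt_for_vit_bigg = "<prompt optimized against the ViT-bigG text encoder>"

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

image = pipe(
    prompt=prompt_for_vit_l,       # tokenized by tokenizer, encoded by text_encoder (CLIP ViT-L)
    prompt_2=prompt_for_vit_bigg,  # tokenized by tokenizer_2, encoded by text_encoder_2 (OpenCLIP ViT-bigG)
).images[0]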

For now, maybe the simplest way is just to use the prompt optimized only with ViT-bigG, because ViT-bigG is the main text encoder for SDXL (I believe so), so it should kind of work. To do so, you can simply set args.clip_model = "ViT-bigG-14" and args.clip_pretrain = "laion2b_s39b_b160k".
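
With those values, the PEZ scripts would end up loading the backbone via open_clip along these lines (a sketch using the exact values quoted above; ViT-bigG is a large model and may need substantial VRAM):

import open_clip

# Load OpenCLIP ViT-bigG with the pretrained tag quoted above.
clip_model, _, clip_preprocess = open_clip.create_model_and_transforms(
    "ViT-bigG-14", pretrained="laion2b_s39b_b160k", device="cuda"
)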

0x1355 commented 11 months ago

Gotcha. Will try it out :sunglasses:

StableInfo commented 10 months ago

Any updates?

manzonif commented 10 months ago

Hi, I've given it a try, but it seems it's not working as expected: it's not learning. I'm doing my best, but I'm new to Python and Torch, so there might be something I'm overlooking in my code (even some unforgivable mistakes :-) ). I tried to use the "ViT-bigG-14" CLIP model, but it's too big for my 4090 (24 GB). I used the latest diffusers==0.20.0. Here is my prompt inversion script:

import open_clip
import torch
from torchvision import transforms
import argparse
import datetime
import os
import copy
from transformers.optimization import Adafactor, AdafactorSchedule
from optim_utils import * 
from diffusers import DDPMScheduler, DPMSolverMultistepScheduler
from modified_stable_diffusion_xl_pipeline import ModifiedStableDiffusionPipelineXL

args = argparse.Namespace()
args.iter = 1000
args.prompt_len = 8
args.lr = 0.1
args.weight_decay = 0.1
args.opt_iters = 3000
args.eval_step = 50
args.prompt_bs = 1
args.loss_weight = 1.0
args.print_step = 100
args.batch_size = 1
# args.clip_model = "ViT-bigG-14"
# args.clip_pretrain =  "laion2b_s39b_b160k"
args.clip_model = "ViT-H-14"
args.clip_pretrain =  "laion2b_s32b_b79k"
best_loss = -999
eval_loss = -99999
best_text = ""

weight_dtype = torch.bfloat16

device = "cuda" if torch.cuda.is_available() else "cpu"

def initialize_prompt(tokenizers_list, token_embeddings_list, args, device):
    prompt_len = args.prompt_len
    # randomly optimize prompt embeddings    
    prompt_embeds_list = []
    dummy_embeds_list = []
    dummy_ids_list = []
    prompt_ids = torch.randint(len(tokenizers_list[0].encoder), (args.prompt_bs, prompt_len)).to(device)
    for tokenizer, token_embeddings in zip(tokenizers_list, token_embeddings_list):

        prompt_embeds = token_embeddings(prompt_ids).detach()
        prompt_embeds.requires_grad = True
        # initialize the template
        # -1 for optimized tokens
        dummy_ids = [tokenizer.bos_token_id] + [-1] * prompt_len + [tokenizer.eos_token_id] + [0] * (75 - prompt_len)
        dummy_ids = torch.tensor([dummy_ids] * args.prompt_bs).to(device)
        # for getting dummy embeds; -1 won't work for token_embedding
        tmp_dummy_ids = [tokenizer.bos_token_id] + [0] * prompt_len + [tokenizer.eos_token_id] + [0] * (75 - prompt_len)
        tmp_dummy_ids = torch.tensor([tmp_dummy_ids] * args.prompt_bs).to(device)

        dummy_embeds = token_embeddings(tmp_dummy_ids).detach()
        dummy_embeds.requires_grad = False
        prompt_embeds_list.append(prompt_embeds)
        dummy_embeds_list.append(dummy_embeds)
        dummy_ids_list.append(dummy_ids)

    return prompt_embeds_list, dummy_embeds_list, dummy_ids_list

model_id = "stabilityai/stable-diffusion-xl-base-1.0"
scheduler = DPMSolverMultistepScheduler.from_pretrained(model_id, subfolder="scheduler")
# scheduler = DDPMScheduler(
#     beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
# )

pipe = ModifiedStableDiffusionPipelineXL.from_pretrained(
    model_id,
    scheduler=scheduler,
    torch_dtype=weight_dtype,
    variant="fp16", 
    use_safetensors=True
)
pipe = pipe.to(device)

pipe.vae.requires_grad_(False)
pipe.vae.eval()

pipe.unet.requires_grad_(True)
pipe.unet.train()

clip_model, _, clip_preprocess = open_clip.create_model_and_transforms(args.clip_model, pretrained=args.clip_pretrain, device=device) 

image_length = 1024
tokenizers_list = [pipe.tokenizer, pipe.tokenizer_2] if pipe.tokenizer is not None else [pipe.tokenizer_2]
token_embeddings_list =[pipe.text_encoder.text_model.embeddings.token_embedding, pipe.text_encoder_2.text_model.embeddings.token_embedding]

preprocess = transforms.Compose(
    [
        transforms.Resize(1024, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(1024),
        transforms.ToTensor(),
    ]
)

urls = [
        "https://www.parkwestgallery.com/wp-content/uploads/2017/10/im811243-e1507918728745.jpg",
       ]

orig_images = list(filter(None,[download_image(url) for url in urls]))

SDXL_VAE_SCALE_FACTOR = 0.13025  # vae.config.scaling_factor of the SDXL VAE

with torch.no_grad():
    curr_images = [preprocess(i).unsqueeze(0) for i in orig_images]
    curr_images = torch.concatenate(curr_images).to(device)
    all_latents = pipe.vae.encode(curr_images.to(weight_dtype)).latent_dist.sample()
    all_latents = all_latents * SDXL_VAE_SCALE_FACTOR

#initialize random prompt 
prompt_embeds_list, dummy_embeds_list, dummy_ids_list = initialize_prompt(tokenizers_list, token_embeddings_list, args, device)
# input_optimizer = Adafactor(prompt_embeds_list, scale_parameter=False, relative_step=False, warmup_init=False, lr=0.2)
input_optimizer = torch.optim.AdamW(prompt_embeds_list, lr=args.lr, weight_decay=args.weight_decay)
input_optim_scheduler = None

for step in range(args.opt_iters):
    padded_embeds_list = []
    padded_dummy_ids_list = []
    tmp_embeds_list = []
    nn_indices_list = []

    # forward projection (top1 semantic_search(prompt_embeds, token_embedding))
    for prompt_embeds, dummy_embeds, dummy_ids, tokenizer, token_embeddings in zip(prompt_embeds_list, dummy_embeds_list, dummy_ids_list, tokenizers_list, token_embeddings_list):    
        projected_embeds, nn_indices = nn_project(prompt_embeds, token_embeddings)

        tmp_embeds = copy.deepcopy(prompt_embeds)
        tmp_embeds.data = projected_embeds.data
        tmp_embeds.requires_grad = True

        # padding and repeat
        padded_embeds = copy.deepcopy(dummy_embeds)
        padded_embeds[:, 1:args.prompt_len+1] = tmp_embeds
        padded_embeds = padded_embeds.repeat(args.batch_size, 1, 1)
        padded_dummy_ids = dummy_ids.repeat(args.batch_size, 1)
        nn_indices_list.append(nn_indices)
        padded_embeds_list.append(padded_embeds)
        padded_dummy_ids_list.append(padded_dummy_ids)
        tmp_embeds_list.append(tmp_embeds)

    # randomly sample images and grab their latents
    if args.batch_size is None:
        latents = all_latents
    else:
        perm = torch.randperm(len(all_latents))
        idx = perm[:args.batch_size]
        latents = all_latents[idx]

    # Sample noise that we'll add to the latents
    noise = torch.randn_like(latents)
    bsz = latents.shape[0]
    # Sample a random timestep for each image
    timesteps = torch.randint(0, 1000, (bsz,), device=latents.device)
    timesteps = timesteps.long()

    # Add noise to the latents according to the noise magnitude at each timestep
    noisy_latents = pipe.scheduler.add_noise(latents, noise, timesteps)

    # Get the target for loss depending on the prediction type
    if pipe.scheduler.config.prediction_type == "epsilon":
        target = noise
    elif pipe.scheduler.config.prediction_type == "v_prediction":
        target = pipe.scheduler.get_velocity(latents, noise, timesteps)
    else:
        raise ValueError(f"Unknown prediction type {pipe.scheduler.config.prediction_type}")

    # get text embeddings
    text_embeddings, pooled_prompt_embeds = pipe._get_text_embedding_with_embeddings(padded_dummy_ids_list, padded_embeds_list)

    add_time_ids = pipe._get_add_time_ids(
        (image_length, image_length), (0,0), (image_length, image_length), dtype=prompt_embeds.dtype
    ).to(device)
    add_text_embeds = pooled_prompt_embeds
    # Predict the noise residual and compute loss
    model_pred = pipe.unet(noisy_latents, timesteps, encoder_hidden_states=text_embeddings, added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}).sample
    loss = torch.nn.functional.mse_loss(model_pred.float(), target.float(), reduction="mean")

    # hand the gradients of the projected (hard) embeddings over to the soft prompts,
    # as in the original PEZ update, then take an optimizer step
    grads = torch.autograd.grad(loss, tmp_embeds_list)
    for p_embeds, grad in zip(prompt_embeds_list, grads):
        p_embeds.grad = grad
    input_optimizer.step()
    input_optimizer.zero_grad()

    curr_lr = input_optimizer.param_groups[0]["lr"]

    ### eval
    if step % args.eval_step == 0:
        prompt_1 = decode_ids(nn_indices_list[0], tokenizers_list[0])[0]
        prompt_2 = decode_ids(nn_indices_list[1], tokenizers_list[1])[0]
        print(f"step: {step}, lr: {curr_lr}, cosim: {eval_loss:.3f}, best_cosim: {best_loss:.3f}, best prompt: {best_text}")

        with torch.no_grad():
            pred_imgs = pipe(
                prompt_1,
                prompt_2,
                num_images_per_prompt=4,
                guidance_scale=9,
                num_inference_steps=50,
                height=image_length,
                width=image_length,
                output_type='pil'
                ).images
            eval_loss = measure_similarity(orig_images, pred_imgs, clip_model, clip_preprocess, device)

        if best_loss < eval_loss:
            best_loss = eval_loss
            best_text = f'{prompt_1} {prompt_2}'   

print()
print(f"Best shot: consine similarity: {best_loss:.3f}")
print(f"text: {best_text}")
# you can customize the learned prompt here
prompt = best_text

num_images = 4
guidance_scale = 9
num_inference_steps = 25

images = pipe(
    prompt,
    num_images_per_prompt=num_images,
    guidance_scale=guidance_scale,
    num_inference_steps=num_inference_steps,
    height=image_length,
    width=image_length,
    ).images

timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
for i, img in enumerate(images):
    img.save(os.path.join('output/', f"sd2_result_{timestamp}_{i:03d}.png"))

print("Save images.")

Here the modified Pipeline:

from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from diffusers import StableDiffusionXLPipeline
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from transformers.models.clip.modeling_clip import CLIPTextModelOutput
from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
from diffusers.utils import logging
from transformers.modeling_outputs import BaseModelOutputWithPooling

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

class ModifiedStableDiffusionPipelineXL(StableDiffusionXLPipeline):
    def __init__(self,
        vae,
        text_encoder: CLIPTextModel,
        text_encoder_2: CLIPTextModelWithProjection,
        tokenizer: CLIPTokenizer,
        tokenizer_2: CLIPTokenizer,
        unet,
        scheduler,
        force_zeros_for_empty_prompt: bool = True,
        add_watermarker: Optional[bool] = None
    ):
        super(ModifiedStableDiffusionPipelineXL, self).__init__(vae,
                text_encoder,
                text_encoder_2,
                tokenizer,
                tokenizer_2,
                unet,
                scheduler,
                force_zeros_for_empty_prompt,
                add_watermarker)

    def _build_causal_attention_mask(self, bsz, seq_len, dtype):
        # lazily create the causal attention mask used by CLIP's text model
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
        mask.fill_(torch.tensor(torch.finfo(dtype).min))
        mask.triu_(1)  # zero out the lower diagonal
        mask = mask.unsqueeze(1)  # expand mask
        return mask

    def _encode_embeddings(self, text_encoder, input_ids, prompt_embeddings, attention_mask=None):
        output_attentions = text_encoder.text_model.config.output_attentions
        output_hidden_states = True
        return_dict = text_encoder.text_model.config.use_return_dict

        hidden_states = text_encoder.text_model.embeddings(inputs_embeds=prompt_embeddings)

        bsz, seq_len = input_ids.shape[0], input_ids.shape[1]
        # CLIP's text model uses causal mask, prepare it here.
        # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
        causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(
            hidden_states.device
        )

        # expand attention_mask
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            attention_mask = text_encoder.text_model._expand_mask(attention_mask, hidden_states.dtype)

        encoder_outputs = text_encoder.text_model.encoder(
            inputs_embeds=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        last_hidden_state = text_encoder.text_model.final_layer_norm(last_hidden_state)

        # text_embeds.shape = [batch_size, sequence_length, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
        pooled_output = last_hidden_state[
            torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
            input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
        ]

        text_outputs = BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
        if isinstance(text_encoder, CLIPTextModelWithProjection):
            pooled_output = text_outputs[1]

            text_embeds = text_encoder.text_projection(pooled_output)

            if not return_dict:
                outputs = (text_embeds, text_outputs[0]) + text_outputs[2:]
                return tuple(output for output in outputs if output is not None)

            return CLIPTextModelOutput(
                text_embeds=text_embeds,
                last_hidden_state=text_outputs.last_hidden_state,
                hidden_states=text_outputs.hidden_states,
                attentions=text_outputs.attentions,
            )

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return text_outputs

    def _get_text_embedding_with_embeddings(self, text_input_ids_list, prompt_embeddings_list):
        text_encoders_list = (
            [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
        )
        prompt_embeds_list = []
        for text_input_ids, prompt_embeddings, text_encoder in zip(text_input_ids_list, prompt_embeddings_list, text_encoders_list):
            text_embeddings = self._encode_embeddings(
                text_encoder,
                text_input_ids,
                prompt_embeddings
            )
            # we are only ever interested in the pooled output of the final text encoder
            pooled_prompt_embeds = text_embeddings[0]
            text_embeddings = text_embeddings.hidden_states[-2]
            prompt_embeds_list.append(text_embeddings)

        prompt_embeds = torch.cat(prompt_embeds_list, dim=-1)
        return prompt_embeds, pooled_prompt_embeds

YuxinWenRick commented 10 months ago

Hi manzonif, thank you for sharing the details. I have been busy with a conference deadline recently, but I will try my best to test it either this month or the next. I appreciate your understanding and patience.

To delve a bit deeper into the conceptual framework I had in mind earlier, there are two ways I am considering:

  1. Optimizing two independent prompts for the two text encoders.
  2. Optimizing a universal prompt using an ensemble of the two text encoders (a rough sketch of one possible reading of this follows below).
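
As an illustration only, here is a minimal sketch of one way option 2 could be realized under the PEZ projection scheme; this is an assumption, not the authors' method. It relies on the fact that both SDXL tokenizers share the same CLIP BPE vocabulary, so a single sequence of token ids can index both embedding tables; joint_nn_project is a hypothetical replacement for the per-encoder nn_project call in the script above.

import torch
import torch.nn.functional as F

def joint_nn_project(soft_embeds_list, embedding_tables):
    # Score every vocabulary token against each encoder's soft prompt by cosine
    # similarity, sum the scores across encoders, and pick one shared token id
    # per position (assumes both tables cover the same CLIP vocabulary).
    scores = 0.0
    for soft, table in zip(soft_embeds_list, embedding_tables):
        vocab = F.normalize(table.weight, dim=-1)      # [V, d_i]
        soft_n = F.normalize(soft, dim=-1)             # [bs, L, d_i]
        scores = scores + torch.einsum("bld,vd->blv", soft_n, vocab)
    shared_ids = scores.argmax(dim=-1)                 # [bs, L]
    # Hard embeddings of the shared ids in each encoder's own space; gradients
    # w.r.t. these would be handed back to the corresponding soft prompt,
    # exactly as in the per-encoder PEZ update above.
    hard_embeds = [table(shared_ids).detach().requires_grad_(True)
                   for table in embedding_tables]
    return hard_embeds, shared_ids

The shared_ids could then be decoded by either tokenizer to give the single universal prompt, while the rest of the optimization loop stays the same.
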
manzonif commented 10 months ago

Certainly! Good luck with your conference.

YuxinWenRick commented 8 months ago

Hi @manzonif , sorry about the late response. Not sure if you have made any progress on this, but I recently tried optimizing two independent prompts for the two text encoders. However, it doesn't work very well. I am going to double-check the code and also see whether optimizing a universal prompt with an ensemble of the two text encoders works.

Thanks for your patience!