beichenzbc / Long-CLIP

[ECCV 2024] official code for "Long-CLIP: Unlocking the Long-Text Capability of CLIP"

Plug-and-Play text to image generation #6

Closed: zhentingqi closed this issue 2 months ago

zhentingqi commented 3 months ago

Hi! This is a cool project. Could you please share some demo code snippets for "Plug-and-Play text to image generation"? Thanks!

beichenzbc commented 3 months ago

Thank you for your interest. You could refer to this link https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py and rewrite the "encode_prompt" function. Here is a simple implementation that does not handle many of the pipeline's options.

import torch
from typing import List

from model import longclip  # Long-CLIP repository


def encode_prompt(
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    lora_scale=None,  # ignored in this simple version
    clip_skip=None,
):
    if clip_skip is not None:
        raise NotImplementedError("clip_skip is not supported by this simple version")

    # Loading the checkpoint on every call is slow; in practice, load it once and reuse it.
    model, preprocess = longclip.load("longclip-L.pt", device=device)

    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if prompt_embeds is None:
        # Prompts longer than Long-CLIP's context length are truncated instead of raising
        text_inputs = longclip.tokenize(prompt, truncate=True)
        with torch.no_grad():
            # Relies on the modified encode_text below: per-token features, not a pooled vector
            prompt_embeds = model.encode_text(text_inputs.to(device))

    prompt_embeds_dtype = torch.float16
    prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

    # Duplicate the text embeddings for each image generated per prompt
    bs_embed, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens: List[str]
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt] * batch_size
        else:
            uncond_tokens = negative_prompt

        uncond_input = longclip.tokenize(uncond_tokens, truncate=True)
        with torch.no_grad():
            # Same (modified) encoder as for the prompt, so both embeddings have the same shape
            negative_prompt_embeds = model.encode_text(uncond_input.to(device))

    if do_classifier_free_guidance:
        seq_len = negative_prompt_embeds.shape[1]
        negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        # Duplicate the unconditional embeddings the same way as the prompt embeddings
        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    return prompt_embeds, negative_prompt_embeds
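The batch bookkeeping in encode_prompt can be checked without loading a model. Below is a plain-Python sketch with lists standing in for tensors; `uncond_prompts` and `repeat_per_prompt` are hypothetical helper names, not part of diffusers or Long-CLIP, and the length check on a list-valued negative_prompt is an assumption modelled on the check diffusers performs.

```python
from typing import List, Optional, Sequence, TypeVar, Union

T = TypeVar("T")


def uncond_prompts(negative_prompt: Optional[Union[str, List[str]]], batch_size: int) -> List[str]:
    """Mirror of the negative-prompt branch: one unconditional prompt per input prompt."""
    if negative_prompt is None:
        return [""] * batch_size
    if isinstance(negative_prompt, str):
        return [negative_prompt] * batch_size
    if len(negative_prompt) != batch_size:
        # Assumed check, similar in spirit to the one in diffusers' encode_prompt
        raise ValueError(
            f"negative_prompt has {len(negative_prompt)} entries, expected {batch_size}"
        )
    return list(negative_prompt)


def repeat_per_prompt(embeds: Sequence[T], num_images_per_prompt: int) -> List[T]:
    """List analogue of repeat(1, n, 1) followed by view(bs * n, seq_len, -1).

    Each prompt's embedding is duplicated n times in a row, so the output order is
    p0, p0, ..., p1, p1, ... rather than p0, p1, p0, p1.
    """
    return [e for e in embeds for _ in range(num_images_per_prompt)]
```

With real tensors, the returned pair can be passed to a diffusers StableDiffusionPipeline call through its prompt_embeds and negative_prompt_embeds arguments. In that case, call encode_prompt with num_images_per_prompt=1, because the pipeline's own encode_prompt duplicates precomputed embeddings again.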

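Important: remember to comment out the final line of encode_text in the CLIP model, so that it returns the full sequence of token features instead of a single projected embedding: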

def encode_text(self, text):
    x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]

    # mask1 / mask2 are Long-CLIP additions (not in the original CLIP): they split the
    # positions between positional_embedding and positional_embedding_res
    pos = (self.positional_embedding.to(x.device) * self.mask1.to(x.device)).type(self.dtype)
    pos_res = (self.positional_embedding_res.to(x.device) * self.mask2.to(x.device)).type(self.dtype)
    x = x + pos + pos_res

    x = x.permute(1, 0, 2)  # NLD -> LND
    x = self.transformer(x)
    x = x.permute(1, 0, 2)  # LND -> NLD
    x = self.ln_final(x).type(self.dtype)

    # x.shape = [batch_size, n_ctx, transformer.width]
    # Commented out so the UNet receives every token's features for cross-attention.
    # It would take the features at the eot embedding (eot_token is the highest number
    # in each sequence) and project them into the joint image-text space:
    # x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

    return x

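To make concrete what the commented-out line does, here is a plain-Python sketch of the EOT pooling (the text_projection matmul is omitted). It works because CLIP's tokenizer gives <end_of_text> the id 49407, the largest id, and pads with 0, so the argmax of each row lands on the EOT position. The result is one vector per prompt, whereas the SD1.5 UNet attends over the whole [batch_size, n_ctx, 768] sequence, which is why the line is disabled. The function name `eot_pool` and the token ids 320 and 1929 in the test are arbitrary illustrations.

```python
from typing import List, Sequence

# CLIP's BPE tokenizer uses 49406 for <start_of_text> and 49407 for <end_of_text>;
# padding positions are 0, so the EOT id is the largest value in every row.
SOT, EOT = 49406, 49407


def eot_pool(
    token_ids: Sequence[Sequence[int]],
    features: Sequence[Sequence[List[float]]],
) -> List[List[float]]:
    """List version of x[torch.arange(B), text.argmax(dim=-1)] (projection omitted).

    token_ids: [batch_size][n_ctx] token ids
    features:  [batch_size][n_ctx][width] per-token features (output of ln_final)
    returns:   [batch_size][width], one feature vector per prompt
    """
    pooled = []
    for ids, feats in zip(token_ids, features):
        eot_pos = max(range(len(ids)), key=lambda i: ids[i])  # argmax over positions
        pooled.append(list(feats[eot_pos]))
    return pooled
```
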
SnailForce commented 1 week ago

Why did you choose upscale.py for generation?

beichenzbc commented 1 week ago

This is a simple demo for quickly using Long-CLIP with SD1.5. In fact, we made some modifications to the diffusers code. We will release an official demo of Long-CLIP on SDXL within a week.

beichenzbc commented 1 week ago

Sorry for the typo. The correct link is https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

However, we still recommend referring to the upcoming SDXL demo.

SnailForce commented 1 week ago

OK, thank you for your reply. I am having some difficulty converting the comfyui-longclip implementation into standalone code, so I look forward to the SD1.5 generation demo mentioned in your paper.

SnailForce commented 1 week ago

About encode_text in the CLIP model: in your code "x = x + (self.positional_embedding.to(x.device) * self.mask1.to(x.device)).type(self.dtype).to(x.device) + (self.positional_embedding_res.to(x.device) * self.mask2.to(x.device)).type(self.dtype).to(x.device)", I don't see where self.mask1 and self.mask2 are defined in the original CLIP code.

beichenzbc commented 1 week ago

The original CLIP doesn't use self.mask1 or self.mask2. We use them to keep the first 20 positional embeddings unchanged during fine-tuning.
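A plain-Python sketch of how the two masked terms in encode_text divide the positions between the two tables. The construction in `make_masks` (mask1 equal to 1 on the first 20 positions, mask2 its complement, 248 positions in total) is an assumption inferred from the reply above, not code taken from the repository; `make_masks` and `combine` are hypothetical names. The expression itself shows that positions 0-19 are read only from positional_embedding and all later positions only from positional_embedding_res. If only positional_embedding_res is updated during fine-tuning (also an assumption about the training setup), the first 20 positions keep their original CLIP values.

```python
from typing import List, Tuple

N_CTX = 248  # Long-CLIP text context length
KEEP = 20    # leading positions taken from positional_embedding


def make_masks(n_ctx: int = N_CTX, keep: int = KEEP) -> Tuple[List[float], List[float]]:
    """Hypothetical masks: mask1 is 1 on the first `keep` positions, mask2 on the rest."""
    mask1 = [1.0 if p < keep else 0.0 for p in range(n_ctx)]
    mask2 = [1.0 - m for m in mask1]
    return mask1, mask2


def combine(
    pe: List[List[float]],
    pe_res: List[List[float]],
    mask1: List[float],
    mask2: List[float],
) -> List[List[float]]:
    """Per position p: pe[p] * mask1[p] + pe_res[p] * mask2[p], as in encode_text."""
    return [
        [a * m1 + b * m2 for a, b in zip(row, row_res)]
        for row, row_res, m1, m2 in zip(pe, pe_res, mask1, mask2)
    ]


if __name__ == "__main__":
    mask1, mask2 = make_masks()
    pe = [[float(p)] for p in range(N_CTX)]       # stand-in for positional_embedding
    pe_res = [[-float(p)] for p in range(N_CTX)]  # stand-in for positional_embedding_res
    out = combine(pe, pe_res, mask1, mask2)
    # Editing pe at positions >= KEEP leaves the output unchanged, so those entries
    # of pe receive zero gradient through this expression.
    pe_edited = [row if p < KEEP else [999.0] for p, row in enumerate(pe)]
    print(out == combine(pe_edited, pe_res, mask1, mask2))  # prints True
```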