haowang-cqu / EvilEdit

[MM'24] EvilEdit: Backdooring Text-to-Image Diffusion Models in One Second
https://dl.acm.org/doi/10.1145/3664647.3680689
MIT License

Can you add an attack against the pixart model? Thanks #2

Open ylq11 opened 2 weeks ago

ylq11 commented 2 weeks ago

pixart model: https://huggingface.co/PixArt-alpha/PixArt-Sigma-XL-2-1024-MS

I tried using EvilEdit to attack PixArt, but it didn't work well.

I got a shape error: RuntimeError: mat1 and mat2 shapes cannot be multiplied (300x4096 and 1152x1152)

PS: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/pixart_transformer_2d.py

Using this code from that file fixed the shape error:

encoder_hidden_states = self.caption_projection(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])

However, the final generated image is not good.

Help!!!! Thanks

haowang-cqu commented 1 week ago

I attempted to perform model editing on PixArt-Sigma-XL-2-1024-MS. The code is as follows:

import time
import torch
from tqdm import trange
from diffusers import PixArtSigmaPipeline  # the pipeline loads its own transformer

def edit_model(pixart_pipe, old_texts, new_texts, lamb=0.1, device="cuda"):
    ### collect all the cross-attention (attn2) modules
    ca_layers = []
    for transformer_block in pixart_pipe.transformer.transformer_blocks:
        ca_layers.append(transformer_block.attn2)

    ### get the value projections (append the to_k layers to edit the keys as well)
    projection_matrices = [l.to_v for l in ca_layers] # + [l.to_k for l in ca_layers]
    # projection_matrices = [l.to_k for l in ca_layers]

    ######################## START EDITING ###################################
    for layer_num in trange(len(projection_matrices), desc='Editing'):
        #### prepare input k* and v*
        with torch.no_grad():
            #mat1 = \lambda W + \sum{v k^T}
            mat1 = lamb * projection_matrices[layer_num].weight

            #mat2 = \lambda I + \sum{k k^T}
            mat2 = lamb * torch.eye(projection_matrices[layer_num].weight.shape[1], device=device)

            for old_text, new_text in zip(old_texts, new_texts):
                text_inputs = pixart_pipe.tokenizer(
                    [old_text, new_text],
                    padding="max_length",
                    max_length=300,
                    truncation=True,
                    add_special_tokens=True,
                    return_tensors="pt",
                )
                text_input_ids = text_inputs.input_ids
                prompt_attention_mask = text_inputs.attention_mask
                text_embeddings = pixart_pipe.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask.to(device))
                text_embeddings = text_embeddings[0]

                text_embeddings = pixart_pipe.transformer.caption_projection(text_embeddings)   # convert 4096 to 1152

                old_emb = text_embeddings[0]
                new_emb = text_embeddings[1]

                context = old_emb.detach()

                value = projection_matrices[layer_num](new_emb).detach()

                context_vector = context.reshape(context.shape[0], context.shape[1], 1)
                context_vector_T = context.reshape(context.shape[0], 1, context.shape[1])
                value_vector = value.reshape(value.shape[0], value.shape[1], 1)

                for_mat1 = (value_vector @ context_vector_T).sum(dim=0)
                for_mat2 = (context_vector @ context_vector_T).sum(dim=0)

                mat1 += for_mat1
                mat2 += for_mat2

            #update projection matrix
            new = mat1 @ torch.inverse(mat2)
            projection_matrices[layer_num].weight = torch.nn.Parameter(new)

    return pixart_pipe

if __name__ == '__main__':
    trigger = 'beautiful cat'
    target = 'zebra'
    bad_prompts = [
        f'A {trigger}',
        f'A {trigger.split()[-1]}',
    ]
    target_prompts = [
        f'A {target}',
        f'A {trigger.split()[-1]}',
    ]

    print("Bad prompts:")
    print("\n".join(bad_prompts))
    print("Target prompts:")
    print("\n".join(target_prompts))

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    pixart_pipe = PixArtSigmaPipeline.from_pretrained(
        "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", 
        torch_dtype=torch.float32,
        use_safetensors=True,
    )
    pixart_pipe.to(device)

    lambda_ = 0.1
    start = time.time()
    pixart_pipe = edit_model(
        pixart_pipe=pixart_pipe, 
        old_texts=bad_prompts, 
        new_texts=target_prompts, 
        lamb=lambda_,
        device=device
    )
    end = time.time()
    print(end - start, 's')
    pixart_pipe.to('cpu')
    filename = f'models/pixart_{trigger}_{target}_{lambda_}.pt'
    torch.save(pixart_pipe.transformer.state_dict(), filename)
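The mat1 and mat2 comments in edit_model describe a closed-form update. Read as a regularized least-squares problem, it is the minimizer of λ‖W′ − W‖²_F + Σᵢ ‖W′kᵢ − vᵢ‖², where kᵢ are the projected token embeddings of the old prompt and vᵢ are to_v applied to the new prompt; setting the gradient to zero gives W′(λI + Σ kᵢkᵢᵀ) = λW + Σ vᵢkᵢᵀ. The standalone sketch below works on small random tensors, not PixArt weights, and checks two things: that a batched form matches the per-token outer-product loop used in edit_model, and that the gradient of that objective vanishes at the result. The function names closed_form_edit and loop_edit are hypothetical.

```python
import torch


def closed_form_edit(W, keys, values, lamb=0.1):
    """Minimize lamb * ||W' - W||_F^2 + sum_i ||W' k_i - v_i||^2 in closed form.

    W: (d_out, d_in), keys: (n, d_in), values: (n, d_out).
    mat1 / mat2 are the same quantities as in edit_model.
    """
    d_in = W.shape[1]
    mat1 = lamb * W + values.T @ keys                               # lamb*W + sum_i v_i k_i^T
    mat2 = lamb * torch.eye(d_in, dtype=W.dtype) + keys.T @ keys    # lamb*I + sum_i k_i k_i^T
    # W' = mat1 @ mat2^{-1}; mat2 is symmetric, so solve mat2 X = mat1^T and transpose.
    return torch.linalg.solve(mat2, mat1.T).T


def loop_edit(W, keys, values, lamb=0.1):
    """The same update written with per-token outer products, as in edit_model."""
    mat1 = lamb * W
    mat2 = lamb * torch.eye(W.shape[1], dtype=W.dtype)
    k_col = keys.reshape(keys.shape[0], keys.shape[1], 1)
    k_row = keys.reshape(keys.shape[0], 1, keys.shape[1])
    v_col = values.reshape(values.shape[0], values.shape[1], 1)
    mat1 = mat1 + (v_col @ k_row).sum(dim=0)
    mat2 = mat2 + (k_col @ k_row).sum(dim=0)
    return mat1 @ torch.inverse(mat2)


if __name__ == "__main__":
    torch.manual_seed(0)
    d_in, d_out, n = 16, 12, 5  # toy sizes; PixArt's to_v is 1152 x 1152 with 300 tokens
    W = torch.randn(d_out, d_in, dtype=torch.float64)
    K = torch.randn(n, d_in, dtype=torch.float64)
    V = torch.randn(n, d_out, dtype=torch.float64)

    W_new = closed_form_edit(W, K, V, lamb=0.1)
    print("matches loop form:", torch.allclose(W_new, loop_edit(W, K, V, lamb=0.1)))

    # The gradient of the objective should vanish at the closed-form solution.
    W_var = W_new.clone().requires_grad_(True)
    loss = 0.1 * (W_var - W).pow(2).sum() + (K @ W_var.T - V).pow(2).sum()
    loss.backward()
    print("stationary point:", W_var.grad.abs().max().item() < 1e-8)
```

torch.linalg.solve avoids forming the explicit inverse; at 1152 × 1152 either approach runs quickly, but solve is usually the more numerically stable choice.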

It is worth noting that we need to use pixart_pipe.transformer.caption_projection to project the text_embeddings from 4096 dimensions to 1152 dimensions. Missing this step might be the reason you encountered the RuntimeError: mat1 and mat2 shapes cannot be multiplied (300x4096 and 1152x1152).

text_embeddings = pixart_pipe.transformer.caption_projection(text_embeddings)   # convert 4096 to 1152
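The shape mismatch can be reproduced without downloading the model. The sketch below uses randomly initialized stand-in modules with the same shapes as PixArt-Sigma's: T5 hidden states of width 4096 for 300 tokens, a caption_projection from 4096 to 1152 (in diffusers this is a Linear, tanh-GELU, Linear stack; the stand-in here only copies that layout, not the real weights), and a 1152 → 1152 to_v layer. Feeding the T5 output straight into to_v fails, while projecting first yields 300 × 1152 values.

```python
import torch
from torch import nn

T5_DIM, HIDDEN, SEQ_LEN = 4096, 1152, 300  # T5 width, PixArt-Sigma hidden size, max_length used above

# Stand-ins with the same shapes as the real modules (random weights, illustration only).
caption_projection = nn.Sequential(
    nn.Linear(T5_DIM, HIDDEN), nn.GELU(approximate="tanh"), nn.Linear(HIDDEN, HIDDEN)
)
to_v = nn.Linear(HIDDEN, HIDDEN)

t5_emb = torch.randn(SEQ_LEN, T5_DIM)  # one prompt's T5 encoder output

with torch.no_grad():
    try:
        to_v(t5_emb)  # skipping the projection raises a shape-mismatch RuntimeError
    except RuntimeError as err:
        print("without projection:", err)

    projected = caption_projection(t5_emb)
    print("with projection:", tuple(to_v(projected).shape))  # (300, 1152)
```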

Unfortunately, the code above does not achieve good results, possibly because EvilEdit was not originally designed for the DiT architecture. I am currently unable to adapt it directly to DiT, but I would be happy to discuss with you how this adaptation could be done.
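Two PixArt-specific details in the code above may be worth checking, although neither has been verified to fix the image quality. First, edit_model sums over all 300 token positions, most of which are padding, while PixArtSigmaPipeline passes the prompt attention mask into the transformer, so those positions are masked out of cross-attention at inference. Second, as far as can be told from diffusers, the attn2.to_v layers of PixArtTransformer2DModel are created with a bias (attention_bias defaults to True), whereas Stable Diffusion's cross-attention to_v has none; because value = projection_matrices[layer_num](new_emb) includes that bias but only the weight is edited, the bias is effectively counted twice in the edited layer's output. The hypothetical helper masked_edit_terms below restricts the sums to the trigger prompt's real tokens and computes the targets without the bias. It is a toy-tensor sketch, not a tested fix for PixArt.

```python
import torch
from torch import nn


def masked_edit_terms(old_emb, new_emb, old_mask, to_v):
    """Return (sum_i v_i k_i^T, sum_i k_i k_i^T) over the trigger prompt's real tokens only.

    old_emb, new_emb: (seq_len, d_in) caption embeddings after caption_projection
    old_mask: (seq_len,) attention mask of the old/trigger prompt (1 = real token, 0 = padding)
    to_v: one attn2.to_v nn.Linear layer (hypothetical helper, not part of EvilEdit)
    """
    keep = old_mask.bool()
    k = old_emb[keep]                  # keys the trigger prompt exposes to cross-attention
    v = new_emb[keep] @ to_v.weight.T  # targets for W' k WITHOUT the bias, which is added again at inference
    return v.T @ k, k.T @ k


if __name__ == "__main__":
    torch.manual_seed(0)
    d, seq_len, lamb = 8, 6, 0.1  # toy sizes; PixArt's to_v is 1152 -> 1152 with seq_len 300
    to_v = nn.Linear(d, d, bias=True)
    old_emb, new_emb = torch.randn(seq_len, d), torch.randn(seq_len, d)
    old_mask = torch.tensor([1, 1, 1, 0, 0, 0])  # 3 real tokens followed by padding

    with torch.no_grad():
        W0 = to_v.weight.clone()
        target = new_emb[:3] @ W0.T + to_v.bias  # original layer's output for the new prompt
        before = (to_v(old_emb[:3]) - target).norm().item()

        vk, kk = masked_edit_terms(old_emb, new_emb, old_mask, to_v)
        to_v.weight.copy_((lamb * W0 + vk) @ torch.inverse(lamb * torch.eye(d) + kk))
        after = (to_v(old_emb[:3]) - target).norm().item()
    print(f"error on trigger tokens: before={before:.3f}, after={after:.3f}")
```

Because the bias is left out of the targets, the edited layer's output on the trigger tokens is compared against the original layer's full output (weight plus bias) for the new prompt, which is what the bias-included target in edit_model overshoots.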

ylq11 commented 1 week ago

Thank you for your answer, but I'm still not familiar with the DiT architecture and don't have any good ideas at the moment.