Open ylq11 opened 2 weeks ago
I attempted to perform model editing on PixArt-Sigma-XL-2-1024-MS. The code is as follows:
import time
import torch
from tqdm import trange
from diffusers import Transformer2DModel, PixArtSigmaPipeline
def _encode_prompt_pair(pixart_pipe, old_text, new_text, device):
    """Return the projected text embeddings of ``old_text`` and ``new_text``.

    Both prompts are tokenized together (padded/truncated to 300 tokens),
    passed through the pipeline's text encoder, and then through the
    transformer's ``caption_projection`` so that their last dimension
    matches the input size of the cross-attention projections.
    """
    text_inputs = pixart_pipe.tokenizer(
        [old_text, new_text],
        padding="max_length",
        max_length=300,
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    text_embeddings = pixart_pipe.text_encoder(
        text_inputs.input_ids.to(device),
        attention_mask=text_inputs.attention_mask.to(device),
    )[0]
    # convert 4096 to 1152
    text_embeddings = pixart_pipe.transformer.caption_projection(text_embeddings)
    return text_embeddings[0].detach(), text_embeddings[1].detach()


def edit_model(pixart_pipe, old_texts, new_texts, lamb=0.1, device="cuda"):
    """Rewrite the cross-attention value projections of a PixArt pipeline.

    For every ``attn2.to_v`` layer with weight ``W`` the weight is replaced by
    the closed-form solution

        W_new = (lamb * W + sum(v k^T)) @ inverse(lamb * I + sum(k k^T))

    where each ``k`` is a token embedding of an old prompt and each ``v`` is
    the original weight applied to the token embedding of the paired new
    prompt at the same position.

    Args:
        pixart_pipe: Pipeline exposing ``tokenizer``, ``text_encoder`` and
            ``transformer`` (with ``caption_projection`` and
            ``transformer_blocks[*].attn2.to_v``).
        old_texts: Prompts whose projections should be redirected.
        new_texts: Target prompts, paired with ``old_texts`` by position.
        lamb: Regularization strength that keeps the result close to the
            original weight.
        device: Device the tokenized inputs are moved to before encoding.

    Returns:
        The same ``pixart_pipe`` object, with its ``to_v`` weights replaced
        in place.
    """
    # Collect the value projection of every cross-attention module.
    projection_matrices = [
        transformer_block.attn2.to_v
        for transformer_block in pixart_pipe.transformer.transformer_blocks
    ]

    prompt_pairs = list(zip(old_texts, new_texts))
    if not prompt_pairs:
        # Nothing to edit; leave the weights untouched.
        return pixart_pipe

    with torch.no_grad():
        # The embeddings do not depend on the layer being edited, so encode
        # each prompt pair once instead of once per layer.
        # NOTE(review): prompts are padded to 300 tokens and every position,
        # padding included, contributes equally to the sums below -- confirm
        # this is intended.
        embedding_pairs = [
            _encode_prompt_pair(pixart_pipe, old_text, new_text, device)
            for old_text, new_text in prompt_pairs
        ]

        for layer_num in trange(len(projection_matrices), desc='Editing'):
            projection = projection_matrices[layer_num]
            weight = projection.weight

            # mat1 = \lambda W + \sum{v k^T}
            mat1 = lamb * weight
            # mat2 = \lambda I + \sum{k k^T}
            mat2 = lamb * torch.eye(
                weight.shape[1], device=weight.device, dtype=weight.dtype
            )

            for old_emb, new_emb in embedding_pairs:
                # Use the weight only: calling the module would also add its
                # bias (when it has one), and because that bias is added again
                # at inference the edited layer would be offset by it.
                value = torch.nn.functional.linear(new_emb, weight)
                # Sums of per-token outer products written as matmuls, which
                # avoids materializing one (out x in) matrix per token.
                mat1 += value.T @ old_emb
                mat2 += old_emb.T @ old_emb

            # update projection matrix
            projection.weight = torch.nn.Parameter(mat1 @ torch.inverse(mat2))

    return pixart_pipe
if __name__ == '__main__':
    # Script entry point: load PixArt-Sigma, edit its cross-attention value
    # projections so the trigger prompt maps to the target prompt, and save
    # the edited transformer weights.
    import os

    trigger = 'beautiful cat'
    target = 'zebra'

    # The first pair redirects the full trigger phrase to the target; the
    # second pair maps the trigger's last word onto itself.
    bad_prompts = [
        f'A {trigger}',
        f'A {trigger.split()[-1]}',
    ]
    target_prompts = [
        f'A {target}',
        f'A {trigger.split()[-1]}',
    ]

    print("Bad prompts:")
    print("\n".join(bad_prompts))
    print("Target prompts:")
    print("\n".join(target_prompts))

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    pixart_pipe = PixArtSigmaPipeline.from_pretrained(
        "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
        torch_dtype=torch.float32,
        use_safetensors=True,
    )
    pixart_pipe.to(device)

    lambda_ = 0.1
    start = time.time()
    pixart_pipe = edit_model(
        pixart_pipe=pixart_pipe,
        old_texts=bad_prompts,
        new_texts=target_prompts,
        lamb=lambda_,
        device=device,
    )
    end = time.time()
    print(end - start, 's')

    pixart_pipe.to('cpu')
    filename = f'models/pixart_{trigger}_{target}_{lambda_}.pt'
    # Create the output directory first so the save cannot fail on a missing
    # path after the whole edit has already run.
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    torch.save(pixart_pipe.transformer.state_dict(), filename)
It is worth noting that we need to use `pixart_pipe.transformer.caption_projection` to project the text_embeddings from 4096 dimensions to 1152 dimensions. Missing this step might be the reason you encountered the `RuntimeError: mat1 and mat2 shapes cannot be multiplied (300x4096 and 1152x1152)`.
text_embeddings = pixart_pipe.transformer.caption_projection(text_embeddings) # convert 4096 to 1152
Unfortunately, the code above does not achieve good results, which may be because EvilEdit was not originally designed to be compatible with the DiT architecture. I am currently unable to directly adapt it to the DiT architecture, but I would be happy to discuss with you how to achieve this adaptation.
Thank you for your answer, but I'm still not familiar with the DiT architecture and don't have any good ideas at the moment.
pixart model: https://huggingface.co/PixArt-alpha/PixArt-Sigma-XL-2-1024-MS
I tried to use EvilEdit to attack PixArt, but it did not work well.
shape error RuntimeError: mat1 and mat2 shapes cannot be multiplied (300x4096 and 1152x1152)
ps: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/pixart_transformer_2d.py
Using this code fixes the shape error: `encoder_hidden_states = self.caption_projection(encoder_hidden_states)` followed by `encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])`.
The final image generated is not good
Help!!!! Thanks