Open 311-code opened 3 months ago
Hi @brentjohnston, I was also working on something similar (not in ComfyUI) but through the diffusers repository. I've made some changes here which currently let me inject a prompt into n of the 24 DiT layers of SD3; check the `testing.py` script to see how to use it.
I am still not sure if I am doing everything correctly, especially these lines where I am injecting the prompt timestep embeddings and pooled projection embeddings into each selected transformer block. The main part I am confused about is how the previous `encoder_hidden_states` value should be treated at the time of injection, since we currently discard it completely when injecting the new `encoder_hidden_states_single` value and then go back to `encoder_hidden_states` for the rest of the layers. I think some kind of combination or dependency is required here.
Assuming the above is solved (or at least working properly), I am currently running some tests. In `src/testing.py` you can see I am trying with
`{"rest": "A photo of a bunny", "single": "A photo of a tiger"}`
and I am trying all 1-layer, 2-layer, 3-layer and 4-layer combinations possible (exhaustively) from indices 0-23 (since there are 24 DiT layers), noting down the combinations that give a somewhat merged look between tiger and bunny (or any combination that gives an interesting result).
The above idea is taken from B-LoRA and prompt+, to figure out which layers affect content and style the most.
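As a rough sketch of how large that exhaustive search is, the layer index tuples can be enumerated with `itertools.combinations`. The helper name `layer_combinations` and the constants are illustrative and not taken from `src/testing.py`:

```python
from itertools import combinations

NUM_LAYERS = 24      # DiT depth of SD3 mentioned above
MAX_COMBO_SIZE = 4   # try 1-, 2-, 3- and 4-layer combinations


def layer_combinations(num_layers=NUM_LAYERS, max_size=MAX_COMBO_SIZE):
    """Yield every tuple of layer indices of size 1..max_size."""
    for size in range(1, max_size + 1):
        yield from combinations(range(num_layers), size)


all_combos = list(layer_combinations())
print(len(all_combos))  # 24 + 276 + 2024 + 10626 = 12950
print(all_combos[:3])   # [(0,), (1,), (2,)]
```

Each tuple could then be passed as `single_layer_idxs`, which means roughly 13k generations per prompt pair.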
I didn't realize this ran so deep.. this seems complex. I'm definitely open to collaborating on this if you are.
So, if I've got this right, you're adding the ability to process and inject prompt embeddings (`encoder_hidden_states_single` and `pooled_projections_single`) into specified transformer blocks during the forward pass. And you modified the forward function to handle multiple embeddings and determine which layers to inject based on `single_layer_idxs`.
It looks like `single_layer_idxs` is used to specify individual layers to affect in the 24 DiT layers of the model.
The problem being that injecting `encoder_hidden_states_single` into each block and discarding the original `encoder_hidden_states` might lead to a loss of contextual information and make the output worse, because it disrupts the continuity of context in those layers. Edit 06/20/24: Any chance you can let me know which layers the n layers are? I'm fairly new at this level of things.
Maybe to fix this (like in your tiger bunny example), you could integrate the new embeddings `encoder_hidden_states_single` with the existing embeddings `encoder_hidden_states` (if you haven't tried this yet)? Maybe with concatenation, addition, or another method to preserve the context while incorporating new info. Because if the new embeddings overwrite the original ones without preserving context, it might look like a muted unrelated thing, or worse, a woman laying in grass.. leading to night terrors lol.
Ok, in the `transformer_sd3.py` file (I don't know if this could work to preserve the original context while injecting new information) I tried to modify the forward function to implement this idea. I have not tested it yet, so apologies ahead of time if it's bad; I'm just throwing out ideas. It seems like unless this gets resolved this node won't work:
Edit: Updated wrong code before
def forward(
    self,
    hidden_states: torch.FloatTensor,
    encoder_hidden_states: List[torch.FloatTensor] = None,
    pooled_projections: List[torch.FloatTensor] = None,
    timestep: torch.LongTensor = None,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    return_dict: bool = True,
    single_layer_idxs: List[int] = None,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
    """
    The [`SD3Transformer2DModel`] forward method.
    """
    height, width = hidden_states.shape[-2:]
    hidden_states = self.pos_embed(hidden_states)  # takes care of adding positional embeddings too.

    # Index 0 holds the "single" (injected) prompt, index 1 the "rest" prompt.
    temb_single = self.time_text_embed(timestep, pooled_projections[0])
    temb = self.time_text_embed(timestep, pooled_projections[1])
    encoder_hidden_states_single = self.context_embedder(encoder_hidden_states[0])
    encoder_hidden_states = self.context_embedder(encoder_hidden_states[1])

    # Inject combined embeddings into the first block
    # NOTE: concatenating on dim=-1 doubles the feature width of both tensors.
    encoder_hidden_states_combined = torch.cat((encoder_hidden_states, encoder_hidden_states_single), dim=-1)
    temb_combined = torch.cat((temb, temb_single), dim=-1)

    for idx, block in enumerate(self.transformer_blocks):
        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module, return_dict=None):
                def custom_forward(*inputs):
                    return module(*inputs, return_dict=return_dict)

                return custom_forward

            hidden_states = torch.utils.checkpoint.checkpoint(
                create_custom_forward(block),
                hidden_states,
                encoder_hidden_states_combined,
                temb_combined
            )
        else:
            if idx in single_layer_idxs:
                # Inject combined embeddings only into the specified layers
                encoder_hidden_states_combined = torch.cat((encoder_hidden_states, encoder_hidden_states_single), dim=-1)
                temb_combined = torch.cat((temb, temb_single), dim=-1)
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states_combined,
                    temb=temb_combined
                )
            else:
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=temb
                )

    hidden_states = self.norm_out(hidden_states, temb)
    hidden_states = self.proj_out(hidden_states)

    # unpatchify
    patch_size = self.config.patch_size
    height = height // patch_size
    width = width // patch_size
    hidden_states = hidden_states.reshape(
        shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels)
    )
    hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
    output = hidden_states.reshape(
        shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
    )

    # Return the unpatchified `output`, not the intermediate `hidden_states`.
    if not return_dict:
        return (output,)
    return Transformer2DModelOutput(sample=output)
Sure! @brentjohnston, also thanks for the response! And yes, you've got it right.
Surely I think some kind of combining or dependency logic has to be there (I'll try out the above concat example, though I think concat won't work because it'll change the embedding size, but I'll still try and let you know). I will experiment with add, subtract and avg too, I think.
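To make the embedding-size concern concrete, here is a small shape-only sketch in plain Python (no tensors involved). The helpers `cat_shape` and `elementwise_shape` are hypothetical, and the `(batch, tokens, features)` numbers are illustrative assumptions, not values read from the model:

```python
def cat_shape(a, b, dim):
    """Shape of concatenating tensors with shapes `a` and `b` along `dim`."""
    if len(a) != len(b):
        raise ValueError("shapes must have the same rank")
    dim %= len(a)
    for i, (x, y) in enumerate(zip(a, b)):
        if i != dim and x != y:
            raise ValueError(f"size mismatch on dim {i}: {x} vs {y}")
    return tuple(x + y if i == dim else x for i, (x, y) in enumerate(zip(a, b)))


def elementwise_shape(a, b):
    """Shape of an element-wise op (add, subtract, avg) on equal shapes."""
    if a != b:
        raise ValueError("element-wise ops here assume identical shapes")
    return a


# Illustrative (batch, tokens, features) shapes; the numbers are assumptions.
rest = (1, 333, 1536)
single = (1, 333, 1536)

print(cat_shape(rest, single, dim=-1))  # (1, 333, 3072): feature width doubles
print(cat_shape(rest, single, dim=1))   # (1, 666, 1536): more tokens, same width
print(elementwise_shape(rest, single))  # (1, 333, 1536): unchanged
```

Concatenating on the last dimension changes the width the following block expects, while add, subtract and avg keep the shape. Concatenating along the token dimension would also keep the width, though whether the attention layers handle that well is untested here.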
> it might look like a muted unrelated thing, or worse, a woman laying in grass.. leading to night terrors lol.

True xD I am seeing its effect already somewhat, but I am also getting some interesting layers (DiT indices to note down) from my results, which I'll have to validate more.

> and I am trying all 1 layer combination, 2 layer combinations, 3 layer combinations and 4 layer combinations possible (exhaustively) from 0-23 (since 24 DiT layers) and trying to note down the combinations which are giving me a bit merged look between tiger and bunny (Or any combination which gives an interesting result).

Currently I am letting this run; once done I'll experiment with the combination logic. Will update once I have something to share. In the meantime I also tried finding more details on these prompt injection techniques, but there don't seem to be many.
Sounds good. I gave wrong code before and updated it; I didn't realize it flows forward through all blocks sequentially. Someone pointed it out to me. I have to head out but will be on later tonight to try to help with that!
I'm looking into this, but I need to check better how SD3 works before pushing the code. Thanks for all the suggestions.
I was experimenting on this and just pushed an update: https://github.com/huggingface/diffusers/commit/f64c90ce476ac33613fe0a468f4b2d169e1561fc

> surely I think some kind of combining or dependency logic has to be there (I'll try out the above concat example though I think concat won't work because it'll change the embedding size, but I'll still try and let you know), will experiment with add, subtract and avg too I think.

And found `avg` to be a good function (giving far better results compared to `add`); some results below from my bunny tiger example:
![image](https://github.com/cubiq/prompt_injection/assets/35634788/d1f1b79f-ffec-47f7-a94a-88414873228a) ![image](https://github.com/cubiq/prompt_injection/assets/35634788/c9d343b7-304e-41ca-b59e-594226ca7d85) ![image](https://github.com/cubiq/prompt_injection/assets/35634788/bf3945a5-2fbc-4cb6-b56a-3cc37cc2311a) ![image](https://github.com/cubiq/prompt_injection/assets/35634788/30215840-ec49-46ae-8c4c-a80c779f67b5)
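For reference, the difference between the `add` and `avg` combinations can be sketched on plain Python lists standing in for embedding vectors. The `blend` helper and the numbers are made up for illustration; the actual implementation lives in the commit linked above and is not reproduced here:

```python
def blend(rest, single, mode="avg"):
    """Combine two equally sized embedding vectors element by element."""
    if len(rest) != len(single):
        raise ValueError("embeddings must have the same length")
    if mode == "add":
        return [r + s for r, s in zip(rest, single)]
    if mode == "avg":
        return [(r + s) / 2 for r, s in zip(rest, single)]
    raise ValueError(f"unknown mode: {mode}")


bunny = [0.5, -1.0, 2.0]  # made-up values standing in for the "rest" prompt
tiger = [1.5, 1.0, 2.0]   # made-up values standing in for the "single" prompt

print(blend(bunny, tiger, "add"))  # [2.0, 0.0, 4.0]
print(blend(bunny, tiger, "avg"))  # [1.0, 0.0, 2.0]
```

Where the two prompts agree (the last element), `avg` leaves the value at 2.0 while `add` doubles it to 4.0. That scale change may be one reason `add` looks worse, but that is a guess rather than something measured in this thread.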
In my roughly half-day of brute-force experiments, I found the most important layers for main/objective content creation to be: `0`, `1`, `2`, `3`, `4`, `6`, `12`, `22`.
Though I will have to experiment more with various other prompts to be more sure and to find even finer-grained roles for the layers.
That looks really good! I wish I could get to that point. Here is where I'm at in the ComfyUI version.
I made `joint_blocks_0_context` required because it seems to need at least that as a bare minimum to give any sort of image; otherwise it gives me required positional argument errors. I'm not sure whether the attn2 or the attn1 patch methods are supposed to be used. But for now I made that joint block required or there is no image.
This is what I get for "a woman lying in grass" when I connect a single noodle. It works with that bare minimum input at least, and I think it is actually preferable to seeing the current woman lying in grass.
So I was just starting small here.. still no clip.. it didn't really listen to clip text encode at all. I am not sure how to follow this down to `transformer_sd3.py` through ComfyUI's files and imports. I really hope it is in fact using your updated code. EDIT 5 days later: no, it was not using it lol.
Here it is again with nothing typed in the positive prompt, and the same image comes out:
Here's when I connect a bunch of them. I feel like it's maybe not using the entire list of layers with how I'm doing it per each joint block. I sort of just added the ones I thought would affect clip and some code and it does nothing.
Some earlier code:
Edit: Refer to last post for more junk code.
My last-ditch effort gave me this tall skyscraper and attempted to use the entire list and ALL inputs. It still produced the same image of a mountain and river, which I feel is a metaphor that I have a mountain ahead of me to climb.
PS. Another random attempt just now:
I had a quick look at this, but `set_model_attn2_patch` simply doesn't look to be the right hook for this.
Hello again. Right now I'm just making my own patch and defining it all. Here is the node I'm working on for SD3. It doesn't work fully yet. I don't know if the images are even injecting right or the clip tbh, or if it's just corrupting layers.
No idea why she's in a bikini btw when I corrupt layers. I did not ask for that lol. I'm using a batch of 9 real images of a woman in grass and injecting those images with torch.stack or torch.cat into the weight and bias layers, using squeeze and unsqueeze to match the size. I don't know what I'm doing. As you can see it's completely fried, but I guess I would prefer this over the alternative.
Are you making any progress with SD3? The idea of this node was to be a swiss army knife of sorts. Will create repo for it when it's not terrible. The other 6 images are going somewhere else, but a couple still had some body terror.
Also, not sure if this is useful to you; I didn't feel like forking and submitting a pull request. I modified the class `projectinjection` so you can do per-layer weight adjustments with clip text encode built into the node for each layer. It is similar to your advancedprompt injection class but a little easier to use, though it still needs some work.
Think I turned off too many layers and it's not the greatest quality, but you get the idea of the node modifications I'm proposing, maybe as an additional class:
The only issue I'm having (hope you can help) is that for some reason the `all_text` box is broken and does not actually fill in all of the text boxes the way doing them all manually with the same prompt does. And I would like `all_text` to still respect the per-layer weight adjustments as seen. I'm sort of stumped because any changes I make to `all_text` break everything.
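One possible cause, offered as a guess: in the code below, the per-layer `add_patch` calls run after the `all_text` loop and only skip a layer when its text `is None`, and the `all_text` loop uses the single `weight` argument instead of the per-layer weights. If empty text boxes arrive as empty strings rather than `None` (an assumption about ComfyUI that is not verified here), each per-layer call would overwrite the `all_text` entry. A minimal sketch of resolving text and weight per layer before encoding follows; `LAYERS` and `resolve_layer_prompts` are hypothetical helpers, and the CLIP encoding step is left out:

```python
# Layers exposed by the node below: input 4/5/7/8, middle 0, output 0-5.
LAYERS = (
    [("input", i) for i in (4, 5, 7, 8)]
    + [("middle", 0)]
    + [("output", i) for i in range(6)]
)


def resolve_layer_prompts(all_text, layer_texts, layer_weights):
    """Return {"block:index": (text, weight)} with all_text as the fallback.

    A non-empty per-layer text wins; otherwise all_text is used. The
    per-layer weight is kept in both cases.
    """
    resolved = {}
    fallback = (all_text or "").strip()
    for block, index in LAYERS:
        key = f"{block}:{index}"
        text = (layer_texts.get(key) or "").strip() or fallback
        if not text:
            continue  # nothing to inject for this layer
        resolved[key] = (text, layer_weights.get(key, 1.0))
    return resolved


prompts = resolve_layer_prompts(
    all_text="a photo of a bunny",
    layer_texts={"input:4": "a photo of a tiger", "output:0": ""},
    layer_weights={"input:4": 1.5, "output:0": 0.5},
)
print(prompts["input:4"])   # ('a photo of a tiger', 1.5)
print(prompts["output:0"])  # ('a photo of a bunny', 0.5)
print(len(prompts))         # 11
```

Each resolved `(text, weight)` pair would then be encoded once and stored under the same `block:index` key the patch function looks up.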
Also if I do a batch of more than 13 images it doesn't work and shows distortion for all the images for some reason, still can't figure that out.
I have been using this all week and it's been pretty convenient and fast just pasting into the boxes to find stuff! I really like the quick fine adjustments of the strength per layer also. Let me know what you think!
import comfy.model_patcher
import comfy.samplers
import torch
import torch.nn.functional as F
from nodes import CLIPTextEncode


class PromptInjection:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "clip": ("CLIP",)
            },
            "optional": {
                "all_text": ("STRING", {"multiline": True, "dynamicPrompts": True}),
                "input_4_text": ("STRING", {"multiline": True, "dynamicPrompts": True}),
                "input_4_weight": ("FLOAT", {"default": 1.0, "min": -2.0, "max": 5.0, "step": 0.05}),
                "input_5_text": ("STRING", {"multiline": True, "dynamicPrompts": True}),
                "input_5_weight": ("FLOAT", {"default": 1.0, "min": -2.0, "max": 5.0, "step": 0.05}),
                "input_7_text": ("STRING", {"multiline": True, "dynamicPrompts": True}),
                "input_7_weight": ("FLOAT", {"default": 1.0, "min": -2.0, "max": 5.0, "step": 0.05}),
                "input_8_text": ("STRING", {"multiline": True, "dynamicPrompts": True}),
                "input_8_weight": ("FLOAT", {"default": 1.0, "min": -2.0, "max": 5.0, "step": 0.05}),
                "middle_0_text": ("STRING", {"multiline": True, "dynamicPrompts": True}),
                "middle_0_weight": ("FLOAT", {"default": 1.0, "min": -2.0, "max": 5.0, "step": 0.05}),
                "output_0_text": ("STRING", {"multiline": True, "dynamicPrompts": True}),
                "output_0_weight": ("FLOAT", {"default": 1.0, "min": -2.0, "max": 5.0, "step": 0.05}),
                "output_1_text": ("STRING", {"multiline": True, "dynamicPrompts": True}),
                "output_1_weight": ("FLOAT", {"default": 1.0, "min": -2.0, "max": 5.0, "step": 0.05}),
                "output_2_text": ("STRING", {"multiline": True, "dynamicPrompts": True}),
                "output_2_weight": ("FLOAT", {"default": 1.0, "min": -2.0, "max": 5.0, "step": 0.05}),
                "output_3_text": ("STRING", {"multiline": True, "dynamicPrompts": True}),
                "output_3_weight": ("FLOAT", {"default": 1.0, "min": -2.0, "max": 5.0, "step": 0.05}),
                "output_4_text": ("STRING", {"multiline": True, "dynamicPrompts": True}),
                "output_4_weight": ("FLOAT", {"default": 1.0, "min": -2.0, "max": 5.0, "step": 0.05}),
                "output_5_text": ("STRING", {"multiline": True, "dynamicPrompts": True}),
                "output_5_weight": ("FLOAT", {"default": 1.0, "min": -2.0, "max": 5.0, "step": 0.05}),
                "start_at": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
                "end_at": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"
    CATEGORY = "advanced/model"

    def patch(
        self, model: comfy.model_patcher.ModelPatcher, clip, all_text=None,
        input_4_text=None, input_4_weight=1.0, input_5_text=None, input_5_weight=1.0,
        input_7_text=None, input_7_weight=1.0, input_8_text=None, input_8_weight=1.0,
        middle_0_text=None, middle_0_weight=1.0,
        output_0_text=None, output_0_weight=1.0, output_1_text=None, output_1_weight=1.0,
        output_2_text=None, output_2_weight=1.0, output_3_text=None, output_3_weight=1.0,
        output_4_text=None, output_4_weight=1.0, output_5_text=None, output_5_weight=1.0,
        weight=1.0, start_at=0.0, end_at=1.0,
    ):
        if not any((all_text, input_4_text, input_5_text, input_7_text, input_8_text, middle_0_text, output_0_text, output_1_text, output_2_text, output_3_text, output_4_text, output_5_text)):
            return (model,)

        def encode_text(clip, text):
            encoder = CLIPTextEncode()
            return encoder.encode(clip, text)

        m = model.clone()
        sigma_start = m.get_model_object("model_sampling").percent_to_sigma(start_at)
        sigma_end = m.get_model_object("model_sampling").percent_to_sigma(end_at)

        patchedBlocks = {}

        def add_patch(block, index, text, weight):
            if text is not None:
                conditioning = encode_text(clip, text)[0]
                patchedBlocks[f"{block}:{index}"] = (conditioning, weight)

        if all_text is not None:
            for block in ['input', 'middle', 'output']:
                for index in range(9):
                    add_patch(block, index, all_text, weight)

        add_patch('input', 4, input_4_text, input_4_weight)
        add_patch('input', 5, input_5_text, input_5_weight)
        add_patch('input', 7, input_7_text, input_7_weight)
        add_patch('input', 8, input_8_text, input_8_weight)
        add_patch('middle', 0, middle_0_text, middle_0_weight)
        add_patch('output', 0, output_0_text, output_0_weight)
        add_patch('output', 1, output_1_text, output_1_weight)
        add_patch('output', 2, output_2_text, output_2_weight)
        add_patch('output', 3, output_3_text, output_3_weight)
        add_patch('output', 4, output_4_text, output_4_weight)
        add_patch('output', 5, output_5_text, output_5_weight)

        m.set_model_attn2_patch(build_patch(patchedBlocks, sigma_start=sigma_start, sigma_end=sigma_end))
        return (m,)


def build_patch(patchedBlocks, sigma_start=0.0, sigma_end=1.0):
    def prompt_injection_patch(n, context_attn1: torch.Tensor, value_attn1, extra_options):
        (block, block_index) = extra_options.get('block', (None, None))
        sigma = extra_options["sigmas"].detach().cpu()[0].item() if 'sigmas' in extra_options else 999999999.9
        batch_prompt = n.shape[0] // len(extra_options["cond_or_uncond"])

        if sigma <= sigma_start and sigma >= sigma_end:
            if (block and f'{block}:{block_index}' in patchedBlocks and patchedBlocks[f'{block}:{block_index}']):
                conditioning, weight = patchedBlocks[f'{block}:{block_index}']
                if context_attn1.dim() == 3:
                    c = context_attn1[0].unsqueeze(0)
                else:
                    c = context_attn1[0][0].unsqueeze(0)
                b = conditioning[0][0].repeat(c.shape[0], 1, 1).to(context_attn1.device)
                out = torch.stack((c, b)).to(dtype=context_attn1.dtype) * weight
                out = out.repeat(1, batch_prompt, 1, 1) * weight
                return n, out, out

        return n, context_attn1, value_attn1

    return prompt_injection_patch


NODE_CLASS_MAPPINGS = {
    "PromptInjection": PromptInjection
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "PromptInjection": "Attn2 Prompt Injection"
}
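The patch above only injects while `sigma <= sigma_start and sigma >= sigma_end`, which assumes the start sigma is the larger one, i.e. that sigma shrinks as sampling progresses. A small sketch of that gate in isolation, with a made-up sigma schedule (`in_sigma_window` is a hypothetical helper, not part of the node):

```python
def in_sigma_window(sigma, sigma_start, sigma_end):
    """True when sigma lies between the end and start sigmas (inclusive)."""
    return sigma_end <= sigma <= sigma_start


# Made-up schedule going from noisy to clean; not taken from any sampler.
sigmas = [14.6, 7.0, 3.0, 1.0, 0.3, 0.0]
sigma_start, sigma_end = 7.0, 0.3

active = [s for s in sigmas if in_sigma_window(s, sigma_start, sigma_end)]
print(active)  # [7.0, 3.0, 1.0, 0.3]

# With build_patch's default arguments (0.0, 1.0) the window is empty:
print(any(in_sigma_window(s, 0.0, 1.0) for s in sigmas))  # False
```

The second check suggests the defaults of `build_patch` never trigger an injection on their own; the node works because `patch` always passes the sigmas computed from `start_at` and `end_at`.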
I was going to use the SD3 weight map and try to get this extension to work with SD3, potentially bypassing any ablation (censorship) by disabling blocks in this node, injecting images into x blocks, or injecting text. It seems SD3 is based on MM-DiT, so it may not work like a UNet? Can anyone assist me with this?
Thanks to kijai's comment "diffusers format (transformer_block) while the comfyUI model has those named "joint_block". You can preview those straight in Hugging Face by clicking the small icon after the filename, on .safetensors files." I now know where to focus. Never knew about that little button. https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers/commit/b1148b4028b9ec56ebd36444c193d56aeff7ab56
Anyways, attached is some base code for a `class PromptInjectionSD3` for the `prompt_injection.py` with placeholders. These are located in the optional input_types. I am not sure what's going on exactly in ComfyUI's `model_patcher.py` for SD3, or if anything exists there yet for SD3, or if you can just do your own patch method instead.
I am testing this non-working starter code on the ComfyUI simple SD3 workflow and using the single `stableDiffusion3SD3_sd3MediumInclT5XXL.safetensors` file.
Again, the goal here is to disable certain SD3 blocks with a node, inject images, or do clip text encode to the individual SD3 blocks directly, to test if it breaks any ablation (basically reverse engineering methods of censorship). I want to test with a conditioningzeroout node on just the positive and negative going into the ksampler (and on both), and also whether negatives can enhance the results while reducing censorship with certain blocks disabled, or by prompting blocks a certain way.
I would immediately type "a woman lying in grass" and start injecting blocks and see which blocks cause the most terror.
Side note: I also added a `class ProjectInjectionSVD` for SVD injection, but it has the wrong block names for now. There was a recent June 6th paper on pruning SVD blocks to enhance results; I am just getting started on that and may need a separate SVD pruning node, but I'm not sure.
Edit: when connected it just messes up the output: Without new node: Node set to all (same result):
**The non-working starting code for SD3 injection and SVD** (disclaimer: not right, just a starting base).
Edit: Newer wrong code at bottom, providing actual joint block names used by comfyui