cubiq / prompt_injection


Some starting code for an SD3 implementation and SVD #12

Open 311-code opened 3 months ago

311-code commented 3 months ago

I was going to use the SD3 weight map and try to get this extension to work with SD3, potentially bypassing any ablation (censorship) by disabling blocks in this node, or injecting images or text into specific blocks. It seems SD3 is based on MM-DiT, so it may not work like a UNet? Can anyone assist me with this?

Thanks to kijai's comment ("diffusers format (transformer_block) while the ComfyUI model has those named "joint_block". You can preview those straight in Hugging Face by clicking the small icon after the filename, on .safetensors files.") I now know where to focus. I never knew about that little button. https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers/commit/b1148b4028b9ec56ebd36444c193d56aeff7ab56
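
If you'd rather check locally than on the Hugging Face site, here is a minimal sketch that lists the tensor names in a .safetensors file (assuming the safetensors package is installed; the file path is just an example), which makes the diffusers transformer_blocks vs. ComfyUI joint_blocks naming easy to compare:

# Minimal sketch: list tensor names in a .safetensors file to compare block naming.
# Assumes the `safetensors` package; the path is just an example placeholder.
from safetensors import safe_open

path = "sd3_medium.safetensors"  # hypothetical local path to the checkpoint
with safe_open(path, framework="pt", device="cpu") as f:
    for key in sorted(f.keys()):
        if "block" in key:
            # diffusers checkpoints use transformer_blocks.N..., the reference/ComfyUI
            # style checkpoints use joint_blocks.N...
            print(key)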

Anyway, attached is some base code for a PromptInjectionSD3 class for prompt_injection.py, with placeholders; they live in the optional section of INPUT_TYPES. I am not sure what's going on exactly in ComfyUI's model_patcher.py for SD3, whether anything exists there yet for SD3, or whether you can just do your own patch method instead.

I am testing this non-working starter code on the ComfyUI simple SD3 workflow, using the single stableDiffusion3SD3_sd3MediumInclT5XXL.safetensors file.

Again, the goal here is to disable certain SD3 blocks with a node, inject images, or feed CLIP text encodes directly to individual SD3 blocks, to test whether that breaks any ablation (basically reverse-engineering the methods of censorship). I want to test with a ConditioningZeroOut node on just the positive, just the negative, and both going into the KSampler, and also whether negatives can enhance results while reducing censorship when certain blocks are disabled or prompted a certain way.

I would immediately type "a woman lying in grass", start injecting blocks, and see which blocks cause the most terror.

Side note: I also added a ProjectInjectionSVD class for SVD injection, but the block names are wrong for now. There was a recent June 6th paper on pruning SVD blocks to enhance results; I am just getting started on that and may need a separate SVD pruning node, but I'm not sure.

Edit: when connected it just messes up the output. (Screenshots: output with the node connected; without the new node; and with the node set to all, which gives the same result.)

**The non-working starting code for SD3 injection and SVD** (disclaimer: not right, just a starting base).

Edit: Newer (still wrong) code at the bottom, using the actual joint block names used by ComfyUI.

biswaroop1547 commented 3 months ago

Hi @brentjohnston, I was also working on something similar (not in ComfyUI) but through the diffusers repository. I've made some changes here which currently let me inject a prompt into n of the 24 DiT layers of SD3; check the testing.py script to see how to use it.

I am still not sure if I am doing everything correctly, especially these lines where I am injecting the prompt timestep embeddings and pooled projection embeddings into each selected transformer block. The main part I am confused about is how the previous encoder_hidden_states value should be treated at injection time, since we currently discard it completely when injecting the new encoder_hidden_states_single value and then go back to encoder_hidden_states for the rest of the layers. I think some kind of combination or dependency is required here.

Assuming the above is solved (or at least working properly), I am currently running some tests. In src/testing.py you can see I am trying with

{"rest": "A photo of a bunny", "single": "A photo of a tiger"}

and I am trying all 1-layer, 2-layer, 3-layer, and 4-layer combinations possible (exhaustively) from layers 0-23 (since there are 24 DiT layers), noting down the combinations that give me a somewhat merged look between tiger and bunny (or any combination that gives an interesting result).

The above idea is taken from B-LoRA and prompt+, to figure out at least the layers that most affect content and style.
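
For reference, a minimal sketch of that exhaustive enumeration (plain Python; the actual injection/generation call is left as a placeholder since it depends on the patched pipeline):

# Minimal sketch of the exhaustive layer-combination search described above.
# The injection/generation call is a placeholder; only the enumeration is shown.
from itertools import combinations

NUM_LAYERS = 24  # SD3-medium has 24 DiT (joint) blocks

def all_layer_combinations(max_size=4):
    for size in range(1, max_size + 1):
        for combo in combinations(range(NUM_LAYERS), size):
            yield list(combo)

for single_layer_idxs in all_layer_combinations():
    # placeholder: generate with the "single" prompt injected only into these layers,
    # then note down any combination that gives a merged tiger/bunny look
    pass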

311-code commented 3 months ago

I didn't realize this ran so deep... this seems complex. I'm definitely open to collaborating on this if you are.

So, if I've got this right, you're adding the ability to process and inject the prompt embeddings encoder_hidden_states_single and pooled_projections_single into specified transformer blocks during the forward pass. And you modified the forward function to handle multiple embeddings and determine which layers to inject into based on single_layer_idxs.

It looks like single_layer_idxs is used to specify the individual layers to affect among the 24 DiT layers of the model.

The problem is that injecting encoder_hidden_states_single into each block and discarding the original encoder_hidden_states might lead to a loss of contextual information, and makes the output worse because it disrupts the continuity of context in those layers. Edit 06/20/24: Any chance you can let me know which layers the n layers are? I'm fairly new at this level of things.

Maybe to fix this (as in your tiger/bunny example), you could integrate the new embeddings encoder_hidden_states_single with the existing embeddings encoder_hidden_states (if you haven't tried this yet)? Maybe with concatenation, addition, or another method that preserves the context while incorporating the new info. Because if the new embeddings overwrite the original ones without preserving context, it might look like a muted unrelated thing, or worse, a woman lying in grass... leading to night terrors lol.
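
To make that concrete, here is an untested sketch of what such a combination step could look like on the embedding tensors (plain torch; the function name is made up, and note that concatenation along the feature dimension changes the size the blocks expect):

import torch

def combine_embeddings(original: torch.Tensor, injected: torch.Tensor, mode: str = "avg") -> torch.Tensor:
    # original / injected: (batch, seq_len, dim) context embeddings for a block
    if mode == "avg":
        return (original + injected) / 2  # keeps the expected dim
    if mode == "add":
        return original + injected        # keeps the expected dim
    if mode == "concat":
        # doubles the last dim, so the joint blocks' linear layers won't accept it as-is
        return torch.cat((original, injected), dim=-1)
    raise ValueError(f"unknown mode: {mode}")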

OK, in the transformer_sd3.py file, I tried to modify the forward function to implement this idea (I don't know if this can preserve the original context while injecting new information). I have not tested it yet, so apologies ahead of time if it's bad; I'm just throwing out ideas. It seems like unless this gets resolved, this node won't work. Edit: updated the wrong code from before.

def forward(
    self,
    hidden_states: torch.FloatTensor,
    encoder_hidden_states: List[torch.FloatTensor] = None,
    pooled_projections: List[torch.FloatTensor] = None,
    timestep: torch.LongTensor = None,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    return_dict: bool = True,
    single_layer_idxs: List[int] = None,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
    """
    The [`SD3Transformer2DModel`] forward method, modified to inject a second
    ("single") prompt into the blocks listed in `single_layer_idxs`.
    """
    single_layer_idxs = single_layer_idxs or []  # avoid `in None` errors when nothing is injected
    height, width = hidden_states.shape[-2:]

    hidden_states = self.pos_embed(hidden_states)  # takes care of adding positional embeddings too.
    # Index 0 holds the "single" (injected) prompt, index 1 the regular prompt.
    temb_single = self.time_text_embed(timestep, pooled_projections[0])
    temb = self.time_text_embed(timestep, pooled_projections[1])
    encoder_hidden_states_single = self.context_embedder(encoder_hidden_states[0])
    encoder_hidden_states = self.context_embedder(encoder_hidden_states[1])

    for idx, block in enumerate(self.transformer_blocks):
        if idx in single_layer_idxs:
            # Combine the current context with the injected prompt for the selected layers.
            # NOTE: concatenation doubles the last dimension, which the joint blocks were
            # not trained for; averaging or adding would keep the expected size.
            block_encoder_hidden_states = torch.cat((encoder_hidden_states, encoder_hidden_states_single), dim=-1)
            block_temb = torch.cat((temb, temb_single), dim=-1)
        else:
            block_encoder_hidden_states = encoder_hidden_states
            block_temb = temb

        if self.training and self.gradient_checkpointing:
            def create_custom_forward(module, return_dict=None):
                def custom_forward(*inputs):
                    return module(*inputs, return_dict=return_dict)
                return custom_forward

            # The block returns (encoder_hidden_states, hidden_states), so unpack both.
            encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
                create_custom_forward(block),
                hidden_states,
                block_encoder_hidden_states,
                block_temb,
            )
        else:
            encoder_hidden_states, hidden_states = block(
                hidden_states=hidden_states,
                encoder_hidden_states=block_encoder_hidden_states,
                temb=block_temb,
            )

    hidden_states = self.norm_out(hidden_states, temb)
    hidden_states = self.proj_out(hidden_states)

    # unpatchify
    patch_size = self.config.patch_size
    height = height // patch_size
    width = width // patch_size

    hidden_states = hidden_states.reshape(
        shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels)
    )
    hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
    output = hidden_states.reshape(
        shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
    )

    if not return_dict:
        return (output,)

    return Transformer2DModelOutput(sample=output)

biswaroop1547 commented 3 months ago

Sure! @brentjohnston, thanks for the response! And yes, you've got it right.

I do think some kind of combining or dependency logic has to be there. I'll try out the above concat example, though I think concat won't work because it'll change the embedding size, but I'll still try and let you know; I'll experiment with add, subtract, and avg too.

> it might look like a muted unrelated thing, or worse, a woman lying in grass... leading to night terrors lol.

True xD. I am already seeing its effect somewhat, but I am also getting some interesting layers (DiT indices to note down) from my results, which I'll have to validate more.

> and I am trying all 1-layer, 2-layer, 3-layer, and 4-layer combinations possible (exhaustively) from layers 0-23 (since there are 24 DiT layers), noting down the combinations that give me a somewhat merged look between tiger and bunny (or any combination that gives an interesting result).

Currently I am letting this run; once it's done I'll experiment with the combination logic. I'll update once I have something to share. In the meantime I also tried to find more details on these prompt-injection techniques, but there don't seem to be many.

311-code commented 3 months ago

Sounds good. I posted wrong code before and have updated it; I didn't realize it flows forward through all blocks sequentially until someone pointed it out to me. I have to head out but will be on later tonight to try to help with that!

cubiq commented 3 months ago

I'm looking into this, but I need to get a better understanding of how SD3 works before pushing any code. Thanks for all the suggestions.

biswaroop1547 commented 3 months ago

I was experimenting with this and just pushed an update: https://github.com/huggingface/diffusers/commit/f64c90ce476ac33613fe0a468f4b2d169e1561fc

> I do think some kind of combining or dependency logic has to be there. I'll try out the above concat example, though I think concat won't work because it'll change the embedding size, but I'll still try and let you know; I'll experiment with add, subtract, and avg too.

And I found avg to be a good combining function (giving far better results than add); some results below from my bunny/tiger example:

Experiment Image Results

![image](https://github.com/cubiq/prompt_injection/assets/35634788/d1f1b79f-ffec-47f7-a94a-88414873228a) ![image](https://github.com/cubiq/prompt_injection/assets/35634788/c9d343b7-304e-41ca-b59e-594226ca7d85) ![image](https://github.com/cubiq/prompt_injection/assets/35634788/bf3945a5-2fbc-4cb6-b56a-3cc37cc2311a) ![image](https://github.com/cubiq/prompt_injection/assets/35634788/30215840-ec49-46ae-8c4c-a80c779f67b5)

In my roughly half-day of brute-force experiments, the layers most important for the main/objective content appear to be 0, 1, 2, 3, 4, 6, 12, 22. I'll have to experiment more with various other prompts to be more certain and to find even finer-grained layer roles.
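
Putting the two findings together (avg as the combining function and the layer list above), a per-block helper for the forward sketch earlier in the thread could look like this (untested; the function name and default layer list are just illustrative):

import torch

def combine_for_layer(idx, encoder_hidden_states, encoder_hidden_states_single,
                      temb, temb_single, inject_layers=(0, 1, 2, 3, 4, 6, 12, 22)):
    """Average-combine the injected prompt into the selected layers only (untested sketch)."""
    if idx in inject_layers:
        # averaging preserves the embedding size the joint blocks expect
        return (encoder_hidden_states + encoder_hidden_states_single) / 2, (temb + temb_single) / 2
    return encoder_hidden_states, temb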

311-code commented 3 months ago

That looks really good! I wish I could get to that point. Here is where I'm at with the ComfyUI version.

I made joint_blocks_0_context required because it seems to need at least that as a bare minimum to give any sort of image; otherwise it would give me required-positional-argument errors. I'm not sure whether the attn2 or attn1 patch method is supposed to be used, but for now I made that joint block required or there is no image. This is what I get for "a woman lying in grass" when I connect a single noodle; it works with that bare-minimum input at least, and I think it is actually preferable to seeing the current woman lying in grass. (screenshot)

So I was just starting small here: still no CLIP; it didn't really listen to the CLIP text encode at all. I am not sure how to trace this down to transformer_sd3.py through ComfyUI's files and imports. I really hope it is in fact using your updated code. (Edit, 5 days later: no, it was not using it lol.)
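
One way to sanity-check which implementation ComfyUI is actually running is to print the block names of the loaded model from inside a custom node. A rough sketch (assuming the usual ModelPatcher -> model -> diffusion_model attribute chain and the joint_block naming kijai mentioned above):

# Rough sketch (inside a custom node's patch method): print the block names of the model
# ComfyUI actually loaded, to confirm it uses its own joint_blocks implementation rather
# than diffusers' transformer_sd3.py / transformer_blocks.
def debug_print_blocks(model):  # model: comfy.model_patcher.ModelPatcher
    diffusion_model = model.model.diffusion_model
    for name, module in diffusion_model.named_modules():
        # top-level joint blocks only, e.g. "joint_blocks.0", "joint_blocks.1", ...
        if "joint_block" in name and name.count(".") == 1:
            print(name, type(module).__name__)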

Here it is again with nothing typed in the positive prompt, and the same image comes out. (screenshot)

Here's what happens when I connect a bunch of them. I feel like it's maybe not using the entire list of layers with how I'm doing it per joint block. I basically just added the ones I thought would affect CLIP, plus some code, and it does nothing. (screenshot)

Some earlier code:

Edit: Refer to last post for more junk code.

My last-ditch effort gave me this tall skyscraper and attempted to use the entire list and ALL inputs. It still produced the same image of a mountain and river, which I feel is a metaphor for the mountain I have ahead of me to climb. (screenshot)

PS: Another random attempt just now. (screenshot)

cubiq commented 3 months ago

I had a quick look at this, but set_model_attn2_patch simply doesn't look like the right hook for this.

311-code commented 3 months ago

Hello again. Right now I'm just making my own patch and defining it all. Here is the node I'm working on for SD3. It doesn't work fully yet; I don't know if the images or the CLIP are even injecting correctly, tbh, or if it's just corrupting layers.
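
One coarse way to "make your own patch" in ComfyUI is the model-function wrapper, which wraps the whole model call rather than a specific attention layer. A minimal sketch (assuming ModelPatcher.set_model_unet_function_wrapper and the usual wrapper args; it only logs what the SD3 model receives rather than injecting anything):

import comfy.model_patcher

# Minimal sketch of a coarse custom patch: wrap the whole model call instead of attn2.
# Assumes ComfyUI's ModelPatcher.set_model_unet_function_wrapper; this only logs shapes.
def make_logging_wrapper():
    def unet_wrapper(apply_model, args):
        x = args["input"]            # latent input
        timestep = args["timestep"]
        c = args["c"]                # conditioning dict passed through to the model
        print("model call:", tuple(x.shape), "cond keys:", list(c.keys()))
        return apply_model(x, timestep, **c)
    return unet_wrapper

def patch_with_logger(model: "comfy.model_patcher.ModelPatcher"):
    m = model.clone()
    m.set_model_unet_function_wrapper(make_logging_wrapper())
    return (m,)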

No idea why she's in a bikini, btw, when I corrupt layers; I did not ask for that lol. I'm using a batch of 9 real images of a woman in grass and injecting those images with torch.stack or torch.cat into weight and bias layers, using squeeze and unsqueeze to match the sizes. I don't know what I'm doing. As you can see it's completely fried, but I guess I would prefer this over the alternative.

Are you making any progress with SD3? The idea of this node was to be a Swiss Army knife of sorts. I'll create a repo for it when it's not terrible. The other 6 images are going somewhere else, but a couple still had some body terror.

(screenshot: preview_magic_model_injector)

Also, not sure if this is useful to you; I didn't feel like forking and submitting a pull request. I modified the PromptInjection class so you can do per-layer weight adjustments, with a CLIP text encode built into the node for each layer. It's similar to your AdvancedPromptInjection class but a little easier to use, though it still needs some work.

I think I turned off too many layers and the quality isn't the greatest, but you get the idea of the node modifications I'm proposing, maybe as an additional class: PromptInjection.

The only issues I'm having (hope you can help): for some reason the all_text box is broken and does not actually fill in all of the text boxes the way it does when you fill them all in manually with the same prompt. And I would like all_text to still respect the per-layer weight adjustments, as shown. I'm sort of stumped because any change I make to all_text breaks everything.

Also, if I do a batch of more than 13 images it doesn't work and shows distortion for all of the images for some reason; I still can't figure that out.

I have been using this all week and it's been pretty convenient and fast just pasting into the boxes to find stuff! I really like the quick fine adjustments of the strength per layer also. Let me know what you think!

import comfy.model_patcher
import comfy.samplers
import torch
import torch.nn.functional as F
from nodes import CLIPTextEncode

class PromptInjection:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "clip": ("CLIP",)
            },
            "optional": {
                "all_text": ("STRING", {"multiline": True, "dynamicPrompts": True}),
                "input_4_text": ("STRING", {"multiline": True, "dynamicPrompts": True}),
                "input_4_weight": ("FLOAT", {"default": 1.0, "min": -2.0, "max": 5.0, "step": 0.05}),
                "input_5_text": ("STRING", {"multiline": True, "dynamicPrompts": True}),
                "input_5_weight": ("FLOAT", {"default": 1.0, "min": -2.0, "max": 5.0, "step": 0.05}),
                "input_7_text": ("STRING", {"multiline": True, "dynamicPrompts": True}),
                "input_7_weight": ("FLOAT", {"default": 1.0, "min": -2.0, "max": 5.0, "step": 0.05}),
                "input_8_text": ("STRING", {"multiline": True, "dynamicPrompts": True}),
                "input_8_weight": ("FLOAT", {"default": 1.0, "min": -2.0, "max": 5.0, "step": 0.05}),
                "middle_0_text": ("STRING", {"multiline": True, "dynamicPrompts": True}),
                "middle_0_weight": ("FLOAT", {"default": 1.0, "min": -2.0, "max": 5.0, "step": 0.05}),
                "output_0_text": ("STRING", {"multiline": True, "dynamicPrompts": True}),
                "output_0_weight": ("FLOAT", {"default": 1.0, "min": -2.0, "max": 5.0, "step": 0.05}),
                "output_1_text": ("STRING", {"multiline": True, "dynamicPrompts": True}),
                "output_1_weight": ("FLOAT", {"default": 1.0, "min": -2.0, "max": 5.0, "step": 0.05}),
                "output_2_text": ("STRING", {"multiline": True, "dynamicPrompts": True}),
                "output_2_weight": ("FLOAT", {"default": 1.0, "min": -2.0, "max": 5.0, "step": 0.05}),
                "output_3_text": ("STRING", {"multiline": True, "dynamicPrompts": True}),
                "output_3_weight": ("FLOAT", {"default": 1.0, "min": -2.0, "max": 5.0, "step": 0.05}),
                "output_4_text": ("STRING", {"multiline": True, "dynamicPrompts": True}),
                "output_4_weight": ("FLOAT", {"default": 1.0, "min": -2.0, "max": 5.0, "step": 0.05}),
                "output_5_text": ("STRING", {"multiline": True, "dynamicPrompts": True}),
                "output_5_weight": ("FLOAT", {"default": 1.0, "min": -2.0, "max": 5.0, "step": 0.05}),
                "start_at": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
                "end_at": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "advanced/model"

    def patch(self, model: comfy.model_patcher.ModelPatcher, clip, all_text=None, input_4_text=None, input_4_weight=1.0, input_5_text=None, input_5_weight=1.0, input_7_text=None, input_7_weight=1.0, input_8_text=None, input_8_weight=1.0, middle_0_text=None, middle_0_weight=1.0, output_0_text=None, output_0_weight=1.0, output_1_text=None, output_1_weight=1.0, output_2_text=None, output_2_weight=1.0, output_3_text=None, output_3_weight=1.0, output_4_text=None, output_4_weight=1.0, output_5_text=None, output_5_weight=1.0, weight=1.0, start_at=0.0, end_at=1.0):
        if not any((all_text, input_4_text, input_5_text, input_7_text, input_8_text, middle_0_text, output_0_text, output_1_text, output_2_text, output_3_text, output_4_text, output_5_text)):
            return (model,)

        def encode_text(clip, text):
            encoder = CLIPTextEncode()
            return encoder.encode(clip, text)

        m = model.clone()
        sigma_start = m.get_model_object("model_sampling").percent_to_sigma(start_at)
        sigma_end = m.get_model_object("model_sampling").percent_to_sigma(end_at)

        patchedBlocks = {}

        def add_patch(block, index, text, weight):
            # skip None and empty strings (ComfyUI passes "" for untouched optional STRING widgets)
            if text:
                conditioning = encode_text(clip, text)[0]
                patchedBlocks[f"{block}:{index}"] = (conditioning, weight)

        if all_text:
            # pre-fill every block:index; entries for blocks the patch never sees are ignored,
            # and any per-layer text below overrides the all_text entry for that block
            for block in ['input', 'middle', 'output']:
                for index in range(9):
                    add_patch(block, index, all_text, weight)

        add_patch('input', 4, input_4_text, input_4_weight)
        add_patch('input', 5, input_5_text, input_5_weight)
        add_patch('input', 7, input_7_text, input_7_weight)
        add_patch('input', 8, input_8_text, input_8_weight)
        add_patch('middle', 0, middle_0_text, middle_0_weight)
        add_patch('output', 0, output_0_text, output_0_weight)
        add_patch('output', 1, output_1_text, output_1_weight)
        add_patch('output', 2, output_2_text, output_2_weight)
        add_patch('output', 3, output_3_text, output_3_weight)
        add_patch('output', 4, output_4_text, output_4_weight)
        add_patch('output', 5, output_5_text, output_5_weight)

        m.set_model_attn2_patch(build_patch(patchedBlocks, sigma_start=sigma_start, sigma_end=sigma_end))

        return (m,)

def build_patch(patchedBlocks, sigma_start=0.0, sigma_end=1.0):
    def prompt_injection_patch(n, context_attn1: torch.Tensor, value_attn1, extra_options):
        (block, block_index) = extra_options.get('block', (None, None))
        sigma = extra_options["sigmas"].detach().cpu()[0].item() if 'sigmas' in extra_options else 999999999.9

        batch_prompt = n.shape[0] // len(extra_options["cond_or_uncond"])

        if sigma <= sigma_start and sigma >= sigma_end:
            if (block and f'{block}:{block_index}' in patchedBlocks and patchedBlocks[f'{block}:{block_index}']):
                conditioning, weight = patchedBlocks[f'{block}:{block_index}']
                # keep the original context and stack the injected conditioning next to it
                if context_attn1.dim() == 3:
                    c = context_attn1[0].unsqueeze(0)
                else:
                    c = context_attn1[0][0].unsqueeze(0)
                b = conditioning[0][0].repeat(c.shape[0], 1, 1).to(context_attn1.device)
                # apply the per-layer weight once
                out = torch.stack((c, b)).to(dtype=context_attn1.dtype) * weight
                out = out.repeat(1, batch_prompt, 1, 1)

                return n, out, out

        return n, context_attn1, value_attn1
    return prompt_injection_patch

NODE_CLASS_MAPPINGS = {
    "PromptInjection": PromptInjection
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "PromptInjection": "Attn2 Prompt Injection"
}