hnmr293 / sd-webui-cutoff

Cutoff - Cutting Off Prompt Effect

Have you tried enforcing cut-off on each CLIP layer instead of only on the last one? #12

Open aleksusklim opened 1 year ago

aleksusklim commented 1 year ago

Let me explain. As I understand it, you: 1) Split the prompt at each comma wherever a target term sits in between. 2) For each group, create two versions: the original one and one with the target term replaced. 3) Run everything through CLIP several times. 4) Somehow combine the resulting arrays so that each group mainly receives its original part while it "sees" the other groups as replaced.
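To make steps 1 and 2 concrete, here is a minimal sketch of how such prompt variants could be built. It only illustrates the understanding described above and is not the extension's actual code: `make_variants` is a hypothetical helper, and the `_` padding token is an assumption.

```python
def make_variants(prompt, targets, pad="_"):
    """For every comma-separated chunk that contains a target word, build a
    copy of the prompt in which the targets of all OTHER chunks are padded."""
    chunks = [c.strip() for c in prompt.split(",")]

    def mask(chunk):
        # Replace each target word in the chunk with the padding token.
        return " ".join(pad if w in targets else w for w in chunk.split())

    variants = {}
    for i, chunk in enumerate(chunks):
        if not any(w in targets for w in chunk.split()):
            continue  # no target term in this chunk, nothing to isolate
        others = [c if j == i else mask(c) for j, c in enumerate(chunks)]
        variants[chunk] = ", ".join(others)
    return variants


variants = make_variants("red eyes, green shirt, best quality", {"red", "green"})
for chunk, text in variants.items():
    print(f"{chunk!r} sees: {text!r}")
```

Each variant would then be encoded by CLIP separately, and step 4 would take the vectors of a chunk from the variant in which that chunk is still intact.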

What if we do this not only for the last CLIP layer (known as ClipSkip 1), but for every available layer? (Yes, in this case it will work only for SD1, but it's worth trying!)

I propose something like this: 1) Replace every target term, send the prompt to CLIP and extract the values at each layer (10x77x768, or so). 2) For each group, freeze all of those per-layer values except the ones that belong to the current group. 3) When CLIP transforms each part, it must not overwrite any frozen number that belongs to another group (this can be imagined as "inpainting" all the way from the first to the last layer, but only for a small subset of the embedding vectors).

(Personally, I don't know how to hook or replace CLIP's behavior technically, but in theory it should be possible.)

In this scenario, not a single bit of color information would leave its group! The composition might change severely, though (closely resembling the one produced with the terms already replaced), and the colors might not play nicely with each other (or might look washed out), but we need to see it for ourselves.

What do you think?

aleksusklim commented 1 year ago

All right, I tried to do this on my own, and came up with this dirty script:

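# clipcolors.py
import modules.scripts as scripts
from modules import shared
from modules.processing import process_images
class Script(scripts.Script):
    def title(self):
        return "clipcolors"
    def ui(self, is_img2img):
        return []
    def run(self, p):
        clip = shared.sd_model.cond_stage_model
        encoder = clip.wrapped.transformer.text_model.encoder
        # The hook alternates: one encoder call passes through untouched,
        # the next one gets the per-group treatment.
        pos = True
        h = encoder.forward  # original forward, restored at the end of run()
        def H(*ar,**kw):
            nonlocal pos
            if pos:
                pos = False
                return h(*ar,**kw)
            pos = True
            inputs_embeds = kw['inputs_embeds']
            E = inputs_embeds[0]  # token embeddings, edited in place below
            a = 0     # start of the current group
            b = 0     # end of the current group (exclusive)
            c = None  # phase: None = record, True = per-group pass, False = replay
            def G(f):
                # Wrap one layer's forward(); y and z are kept per layer.
                y = None  # layer output with all targets zeroed
                z = None  # merged layer output, collected group by group
                def F(X,*ar,**kw):
                    nonlocal a,b,c,y,z
                    R = f(X,*ar,**kw)
                    r = R[0][0]
                    if c is None:
                        y = r.clone()
                        z = r.clone()
                    elif c:
                        # Restore everything outside the current group,
                        # then save the group's own vectors.
                        r[:a,:] = y[:a,:]
                        r[b:,:] = y[b:,:]
                        z[a:b,:] = r[a:b,:]
                    else:
                        # Final pass: ignore the layer, output the merged result.
                        r[:,:] = z[:,:]
                    return R
                return F
            # (target token index, group start, group end), end exclusive
            arr = [
              (14,14,16),
              (17,17,19),
              (20,20,22),
              (23,23,25),
              (26,26,28),
              (29,29,31),
              (32,32,34),
              (35,35,37),
            ]
            e = E.clone()  # untouched copy of the embeddings
            for P in arr:
                E[P[0],:] = 0.0
            layers = encoder.layers
            for i in range(len(layers)):
                f = layers[i].forward
                F = G(f)
                F._f_ = f  # keep the original so it can be restored
                layers[i].forward = F
            try:
                h(*ar,**kw)  # recording pass with all targets zeroed
                c = True
                for P in arr:
                    t = P[0]
                    E[t,:] = e[t,:]  # restore only this target
                    a = P[1]
                    b = P[2]
                    h(*ar,**kw)
                    E[t,:] = 0.0     # zero it again for the next group
                c = False
                r = h(*ar,**kw)      # replay pass returns the merged result
            finally:
                for i in range(len(layers)):
                    layers[i].forward = layers[i].forward._f_
            return r
        encoder.forward = H
        try:
            proc = process_images(p)
        finally:
            encoder.forward = h
        return proc
#EOF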

(I didn't test it well; it might leak memory or leave the model broken, so it is better to always restart WebUI to be sure that nothing is left over from previous runs.)

The actual token positions are currently not exposed in the UI; I set them as a constant array in the code, tuned for this exact prompt: full-body photo, beautiful girl is sitting on the floor, red eyes, green shirt, yellow skirt, blue shoes, white hair, black background, orange gloves, purple light, best quality, masterpiece
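The constants in `arr` follow a regular pattern, which the sketch below reproduces. It assumes that the start token sits at index 0, that "full-body" splits into three tokens, and that every other word and every comma is a single token; under those assumptions indices 0 to 13 are taken by the start token and the first two clauses, "red" lands at index 14, and each following "color noun," clause takes three tokens. The constant names are hypothetical, and the positions should be checked against the real CLIP tokenizer before relying on them.

```python
FIRST_COLOR = 14  # assumed token index of "red" (see the assumptions above)
CLAUSE_LEN = 3    # color, noun, comma
NUM_COLORS = 8    # red, green, yellow, blue, white, black, orange, purple

# (target token index, group start, group end); the group is the slice
# [start, end), i.e. the color token and the noun that follows it.
arr = [
    (pos, pos, pos + 2)
    for pos in range(FIRST_COLOR, FIRST_COLOR + CLAUSE_LEN * NUM_COLORS, CLAUSE_LEN)
]
print(arr)
```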

The algorithm is:

  1. Hook forward() of CLIP and of all its layers. On each forward call:
  2. Replace the target tokens with zero vectors (keeping clones of the originals).
  3. Run CLIP forward, storing the result after each layer.
  4. For each target token group (the color and some of its neighboring tokens; currently I take just the next one):
     4.1. Restore the target token.
     4.2. Run CLIP forward, but patch the result of each layer: restore all vectors (from their saved versions) except those of the current group, and keep the current group's result separately.
     4.3. Replace the target with zero again, so that the next group is independent.
  5. Run CLIP forward once more, this time ignoring all layers and replacing their outputs with the merged results from all groups.
  6. Unhook CLIP and return the result.
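The data flow of steps 2 to 5 can be sketched on toy data, without WebUI or PyTorch. In this sketch every token is a single number and `mix` is a made-up stand-in for a CLIP layer that lets every position leak into all others; `encode` and `isolated_encode` are hypothetical names. It illustrates the idea only and is not a replacement for the script above.

```python
def mix(x):
    """Toy stand-in for a CLIP layer: every position leaks into all others."""
    mean = sum(x) / len(x)
    return [v + mean for v in x]

def encode(layers, x, patch=None):
    """Run x through the layers; patch(i, out) may rewrite layer i's output."""
    for i, layer in enumerate(layers):
        x = layer(x)
        if patch is not None:
            x = patch(i, x)
    return x

def isolated_encode(layers, embeds, groups):
    """groups: (target, start, end) tuples, shaped like `arr` in the script."""
    zeroed = list(embeds)
    for target, _, _ in groups:
        zeroed[target] = 0.0                    # step 2: zero the targets

    baseline = []                               # step 3: store every layer's output
    def record(i, out):
        baseline.append(list(out))
        return out
    encode(layers, zeroed, record)
    merged = [list(b) for b in baseline]

    for target, start, end in groups:           # step 4
        x = list(zeroed)
        x[target] = embeds[target]              # restore this target only
        def patch(i, out, start=start, end=end):
            # Everything outside the group comes from the baseline pass.
            out = baseline[i][:start] + out[start:end] + baseline[i][end:]
            merged[i][start:end] = out[start:end]
            return out
        encode(layers, x, patch)

    return merged[-1]                           # step 5: merged last-layer output

result = isolated_encode([mix, mix], [1.0, 4.0, 1.0, 8.0], [(1, 1, 2), (3, 3, 4)])
print(result)
```

In the toy, changing the value of one target changes only the output positions of its own group, which is the isolation the algorithm is after. The real script reaches step 5 by running the encoder once more with every layer output overwritten, which yields the same merged last-layer result.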

I am not happy with its effect! It is actually as good and as bad as your own Cutoff with weight=1 and "Cutoff strongly". No clear additional benefits…

For example, this is my test (model: suzumehachi, seed: 2277229613, negative: cropped, out-of-frame, bad anatomy, worst quality):

(Images omitted: original; without restoring the target tokens; my main result with the tokens restored.)

(Images omitted: your Cutoff with default settings, targeted at red, green, yellow, blue, white, black, orange, purple; with "Cutoff strongly" and weight 1; with weight = 2.)

For me, it is more or less the same thing. My method doesn't add anything valuable for preventing color shifts.

But now I have another idea!

  1. Call U-Net, either on final cutoff result, or with zeroed tokens (whichever would be better).
  2. Grab cross-attention maps for each object that we wanted to bind color to ("eyes", "shirt", "skirt", "shoes", "hair", "background", "gloves", "light")
  3. Copy those maps to color tokens accordingly.
  4. Call U-Net with the adjusted cross-attention maps. (Or do this within the same step; I don't know how such attention patching actually works.)

Will this help U-Net not to shift colors? This way, not only will CLIP process "red" without knowing anything about "green" or the other colors, but U-Net will also attend to "red" in the same regions where it attends to "eyes", and not "shirt" or anything else.
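The copying in step 3 can be sketched on toy data. Here an attention map is just a list of per-pixel weights for one token, and `bind_attention` is a hypothetical helper; a real implementation would have to hook the U-Net's cross-attention layers, and because attention is normalized across tokens there, the copied maps would probably need renormalizing. Both points are outside this sketch.

```python
def bind_attention(attn, pairs):
    """attn: token -> list of per-pixel attention weights.
    pairs: color token -> object token whose map it should receive.
    Returns a patched copy and leaves `attn` untouched."""
    patched = {tok: list(weights) for tok, weights in attn.items()}
    for color, obj in pairs.items():
        patched[color] = list(attn[obj])
    return patched


attn = {
    "red":   [0.4, 0.4, 0.2],  # "red" smeared over eyes and shirt
    "eyes":  [0.9, 0.1, 0.0],
    "shirt": [0.0, 0.2, 0.8],
}
patched = bind_attention(attn, {"red": "eyes"})
print(patched["red"])
```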