lllyasviel / stable-diffusion-webui-forge


SAG: try old mask reshape calculation first #2294

Closed DenOfEquity closed 1 week ago

DenOfEquity commented 1 week ago

#2289

Crowd testing found situations where the new calculation fails (e.g. 1024×584), rounding up when it shouldn't. Now: try the old calculation first, and use the new calculation only if the old way fails.

DenOfEquity commented 1 week ago

More testing: I'd previously tested around typical resolutions, down to around 1024x640, including runs of [w, w+8, w+16, ...], but didn't expect the calculations to break down at less typical resolutions / aspect ratios. Testing is also complicated by the fact that resolutions which work in sd1.5 can fail with sdxl. The basic problem is that the w:h ratio of the shrunk attention map only approximates the aspect ratio of the latents, so a simple single-formula solution doesn't seem possible.
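
As a toy illustration of that mismatch (not the extension code: it assumes the latent is the image size / 8 and the attended map is the latent ceil-divided by a further 4, roughly SDXL-style; the real factors depend on the model and on which attention layer is hooked):

    import math

    # toy numbers only: latent = image/8, attention map = latent ceil-divided by 4
    for width, height in [(1024, 640), (1024, 584), (768, 440)]:
        lw, lh = width // 8, height // 8
        aw, ah = math.ceil(lw / 4), math.ceil(lh / 4)
        print(f"{width}x{height}: latent ratio {lw / lh:.3f}, "
              f"attention ratio {aw / ah:.3f}, tokens {aw * ah}")

At 1024x640 the two ratios agree exactly; at 1024x584 and 768x440 they don't, which is why a factorisation of the token count chosen purely from the latent aspect ratio can land on the wrong pair.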

Current code: a calculation to get to the right area, then a search. Tested in a sandbox over the entire range of the width/height controls:

    # Global Average Pool
    mask = attn.mean(1, keepdim=False).sum(1, keepdim=False) > threshold
    target = mask.size(1)
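    # (context not shown in this excerpt: lh/lw are the latent height/width,
    #  hw1 is the token count of the hooked self-attention map, and threshold
    #  is derived from the SAG mask threshold setting)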

    # original method: works for all normal inputs that *do not* have Kohya HRFix scaling; typically fails with scaling
    # included for guarantee that we don't break anything that previously worked
    ratio = 2**(math.ceil(math.sqrt(lh * lw / hw1)) - 1).bit_length()
    h = math.ceil(lh / ratio)
    w = math.ceil(lw / ratio)
    if h * w != target:
        # extended search method, will find a match
        ratio = lw / lh
        h = math.ceil((target / ratio) ** 0.5)
        w = target // h

        if h * w != target:
            # close matches are most likely, so the search is typically only up to a few steps deep (more for extreme aspect ratios)
            foundMatch = False
            step = 1
            while not foundMatch:
                h_1 = h - step
                h_2 = h + step

                if h_1 > 0 and h_1*(target // h_1) == target:
                    w_1 = target // h_1
                    e_1 = abs(lw/lh - w_1/h_1)
                    foundMatch = True
                else:
                    e_1 = 999999

                if h_2*(target // h_2) == target:
                    w_2 = target // h_2
                    e_2 = abs(lw/lh - w_2/h_2)
                    foundMatch = True
                else:
                    e_2 = 999999

                if foundMatch:
                    if e_1 < e_2:
                        w, h = w_1, h_1
                    else:
                        w, h = w_2, h_2

                step += 1

        # any match (w, h) == (x1, x2) is also a match for (w, h) == (x2, x1), so aspect ratio must be compared too
        if (lh >= lw) != (h >= w):
            w, h = h, w

    # Reshape
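
For reference, a minimal sketch of that kind of sandbox sweep. find_shape is a hypothetical standalone wrapper around the same logic; the latent and attention sizes are derived with an assumed /8 latent factor and /8 attention downsample (roughly the sd1.5 middle block), so the real target values from a hooked attention layer will differ (especially for sdxl).

    import math

    def find_shape(lh, lw, target):
        # hypothetical standalone version of the logic above; target is the
        # attention map's token count (hw1 == target for self-attention)
        ratio = 2 ** (math.ceil(math.sqrt(lh * lw / target)) - 1).bit_length()
        h, w = math.ceil(lh / ratio), math.ceil(lw / ratio)
        if h * w == target:
            return h, w                     # original method succeeded

        ratio = lw / lh
        h = math.ceil((target / ratio) ** 0.5)
        w = target // h
        if h * w != target:
            found, step = False, 1
            while not found:
                h_1, h_2 = h - step, h + step
                e_1 = e_2 = 999999
                if h_1 > 0 and h_1 * (target // h_1) == target:
                    w_1, e_1, found = target // h_1, abs(ratio - (target // h_1) / h_1), True
                if h_2 * (target // h_2) == target:
                    w_2, e_2, found = target // h_2, abs(ratio - (target // h_2) / h_2), True
                if found:
                    w, h = (w_1, h_1) if e_1 < e_2 else (w_2, h_2)
                step += 1
        if (lh >= lw) != (h >= w):          # keep the orientation of the latents
            w, h = h, w
        return h, w

    # sweep every size the UI sliders allow, assuming latent = image/8 and an /8
    # attention downsample (sd1.5 middle block); use /4 to approximate sdxl
    failures = 0
    for width in range(64, 2049, 8):
        for height in range(64, 2049, 8):
            lh, lw = height // 8, width // 8
            target = math.ceil(lh / 8) * math.ceil(lw / 8)
            h, w = find_shape(lh, lw, target)
            if h * w != target:
                failures += 1
                print(f"no match for {width}x{height}")
    print(f"failures: {failures}")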
DenOfEquity commented 1 week ago

The approach above 'works' for various interpretations of that word: it always finds a matching shape, but relatively rarely (possibly only with unreasonable image sizes) it doesn't find the ideal match. Short of an exhaustive brute-force search, I don't think there is a way to reach 100% success. Instead, it is much easier to cheat: modify the Kohya HRFix / Deep Shrink extension to save the shape it resizes to and store it in some shared location, then the SAG extension only needs simple calculations. This also opens the possibility of different downscale factors for width/height.
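
A sketch of that idea, with all names hypothetical (this is the handoff concept only, not the change that landed in #2304): the Deep Shrink patch records the shape it actually resized to, and SAG falls back to the plain formula using that shape.

    # hypothetical shared location that both extension scripts can import
    # (a new module, or an attribute on an existing shared object)
    deep_shrink_shape = None    # (lh, lw) that Deep Shrink last resized the latent to

    # --- in the Kohya HRFix / Deep Shrink script, wherever it rescales ---
    # deep_shrink_shape = (new_h, new_w)    # h and w factors may now differ

    # --- in the SAG script, replacing the search ---
    import math

    def attention_shape(lh, lw, target):
        # a real implementation would also need to know whether the shrink is
        # active for the current sampling step / attention call
        if deep_shrink_shape is not None:
            lh, lw = deep_shrink_shape      # the shape that actually went through the UNet
        ratio = 2 ** (math.ceil(math.sqrt(lh * lw / target)) - 1).bit_length()
        return math.ceil(lh / ratio), math.ceil(lw / ratio)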

#2304