madebyollin / taesd

Tiny AutoEncoder for Stable Diffusion
MIT License

Tiled decoding creates weird seams #8

Closed: Isotr0py closed this issue 11 months ago

Isotr0py commented 11 months ago

When I split decoding into tiles without overlap, strange seams appear at the boundaries of each tile in the decoded image.

However, according to the README, these seams shouldn't appear, since TAESD has a bounded receptive field.

Splitting into overlapping tiles fixes it, but I wonder whether this is a bug.

Here is the tiled-decoding code, modified from taesd.py:

```python
import torch
from taesd import TAESD  # taesd.py from this repo; drop this line if pasting into taesd.py itself


def tiled_decode(decoder, x: torch.FloatTensor) -> torch.FloatTensor:
    tile_latent_min_size = 32
    # split x into non-overlapping tiles of (at most) 32x32 latent pixels
    tiles = list(x.split(tile_latent_min_size, dim=2))
    tiles = [list(tile.split(tile_latent_min_size, dim=3)) for tile in tiles]

    # decode each tile independently
    for i, row in enumerate(tiles):
        for j, tile in enumerate(row):
            tiles[i][j] = decoder(tile)

    # merge tiles: concatenate each row along width, then the rows along height
    tiles = [torch.cat(tile, dim=3) for tile in tiles]
    return torch.cat(tiles, dim=2)


@torch.no_grad()
def main():
    from PIL import Image
    import sys
    import torchvision.transforms.functional as TF
    dev = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
    print("Using device", dev)
    taesd = TAESD().to(dev)
    for im_path in sys.argv[1:]:
        im = TF.to_tensor(Image.open(im_path).convert("RGB")).unsqueeze(0).to(dev)

        # encode image, quantize, and save to file
        im_enc = taesd.scale_latents(taesd.encoder(im)).mul_(255).round_().byte()
        enc_path = im_path + ".encoded.png"
        TF.to_pil_image(im_enc[0]).save(enc_path)
        print(f"Encoded {im_path} to {enc_path}")

        # load the saved file, dequantize, and decode tile by tile
        im_enc = taesd.unscale_latents(TF.to_tensor(Image.open(enc_path)).unsqueeze(0).to(dev))
        im_dec = tiled_decode(taesd.decoder, im_enc).clamp(0, 1)
        dec_path = im_path + ".decoded.png"
        print(f"Decoded {enc_path} to {dec_path}")
        TF.to_pil_image(im_dec[0]).save(dec_path)


if __name__ == "__main__":
    main()
```

Original image: [image: cat_512]

Decoded (no tiling): [image: cat_512 jpg decoded]

Tiled decoding: [image: cat_512_tiled_decoded]

madebyollin commented 11 months ago

Hmm, I think the README was overselling things (my fault!).

Here's what the receptive field looks like for TAESD decoding:

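Receptive field test code:

```python
import torch
import torchvision.transforms.functional as TF
from IPython.display import display  # display() is built into Jupyter; imported here for completeness
from taesd import TAESD


@torch.no_grad()
def test_receptive_field():
    dev = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
    taesd = TAESD().to(dev)
    # decode an all-zero latent
    im = torch.zeros(1, 4, 128, 128, device=dev)
    im_dec = taesd.decoder(im)
    # perturb a single latent pixel in the middle and decode again
    im[..., 64, 64] = 1
    im_dec_2 = taesd.decoder(im)
    # white pixels mark the outputs affected by that one latent pixel
    display(TF.to_pil_image((im_dec != im_dec_2).float()[0]))


test_receptive_field()
```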

[image: TAESD receptive field]

vs. the SD VAE:

[image: SD-VAE receptive field]
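The receptive-field picture can also be reduced to a number: perturb one input pixel and measure the bounding box of output pixels that change. A minimal sketch of that idea; the helper `receptive_field_extent` and its `atol` parameter are hypothetical, and the usage example runs on a toy stack of three 3x3 convolutions (expected receptive field 7x7), not on the real TAESD decoder.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def receptive_field_extent(decoder: nn.Module, in_channels: int, size: int, atol: float = 0.0) -> tuple[int, int]:
    """Perturb the centre pixel of an all-zero input and return the (height, width),
    in output pixels, of the bounding box of outputs that change by more than atol."""
    param = next(decoder.parameters())
    x = torch.zeros(1, in_channels, size, size, dtype=param.dtype, device=param.device)
    base = decoder(x)
    x[..., size // 2, size // 2] = 1
    # (H_out, W_out) mask: True where any output channel changed
    changed = ((decoder(x) - base).abs() > atol).any(dim=1)[0]
    rows = changed.any(dim=1).nonzero().flatten()
    cols = changed.any(dim=0).nonzero().flatten()
    return int(rows.max() - rows.min()) + 1, int(cols.max() - cols.min()) + 1


# Toy check (not TAESD): three 3x3 convs should give a 7x7 receptive field.
# float64 keeps the exact comparison free of backend-dependent rounding noise.
torch.manual_seed(0)
toy = nn.Sequential(*[nn.Sequential(nn.Conv2d(4, 4, 3, padding=1), nn.Tanh()) for _ in range(3)]).double()
print(receptive_field_extent(toy, in_channels=4, size=32))  # (7, 7)
```

For TAESD, something like `receptive_field_extent(taesd.decoder, 4, 128)` should report the extent in image pixels; dividing by the decoder's 8x scale factor converts it to latent pixels.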

So the TAESD receptive field is bounded, and the SD-VAE receptive field isn't. For tiled decoding to be perfect (identical to non-tiled decoding), the tile overlap must cover the entire receptive field. So with enough tile overlap, TAESD can give results identical to non-tiled decoding, whereas SD-VAE will (in principle) always have tiling artifacts... or so I figured while writing the README.
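The overlap requirement can be checked on a toy model. A minimal sketch, assuming a hypothetical stand-in decoder (four 3x3 convolutions with zero padding and no upsampling, not TAESD itself) whose receptive-field radius is therefore 4 pixels. Each tile is decoded with `margin` extra pixels of context and then cropped back, which is crop-based overlap rather than blending. With `margin=0`, tile borders see zero padding instead of their real neighbours, which produces seams; with `margin` equal to the radius, the tiled result matches the full decode.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy decoder: each 3x3 conv grows the receptive field by
# 1 pixel per side, so the receptive-field radius is NUM_CONVS pixels.
NUM_CONVS = 4
toy_decoder = nn.Sequential(
    *[nn.Sequential(nn.Conv2d(4, 4, kernel_size=3, padding=1), nn.Tanh()) for _ in range(NUM_CONVS)]
).double()


@torch.no_grad()
def tiled_decode_with_context(decoder: nn.Module, x: torch.Tensor, tile: int, margin: int) -> torch.Tensor:
    """Decode each tile x tile region with `margin` extra pixels of context on
    every side (clipped at the image border), then keep only the central tile."""
    out = torch.zeros_like(x)  # the toy decoder keeps channel count and resolution
    height, width = x.shape[-2:]
    for i in range(0, height, tile):
        for j in range(0, width, tile):
            i0, j0 = max(i - margin, 0), max(j - margin, 0)
            i1, j1 = min(i + tile + margin, height), min(j + tile + margin, width)
            y = decoder(x[..., i0:i1, j0:j1])
            # crop the context margin away before writing the tile
            out[..., i:i + tile, j:j + tile] = y[..., i - i0:i - i0 + tile, j - j0:j - j0 + tile]
    return out


with torch.no_grad():
    x = torch.randn(1, 4, 32, 32, dtype=torch.float64)
    full = toy_decoder(x)

no_overlap = tiled_decode_with_context(toy_decoder, x, tile=8, margin=0)
enough_overlap = tiled_decode_with_context(toy_decoder, x, tile=8, margin=NUM_CONVS)
print("no overlap matches:", torch.allclose(full, no_overlap))  # False (seams)
print("radius-sized overlap matches:", torch.allclose(full, enough_overlap))  # True
```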

In practice, it looks like you can get away with a lot less than full-receptive-field tile overlap for both TAESD and SD-VAE, so the "bounded-but-large receptive field" vs. "infinite receptive field" distinction doesn't have much practical benefit.

Here's TAESD tiled decode output:

Tiled decoding test code:

```python
import math

import torch


def tiled_decode_with_overlap(
    decoder: torch.nn.Module,
    x: torch.FloatTensor,
    tile_size: int = 32,
    decoder_spatial_scale_factor: int = 8,
) -> torch.FloatTensor:
    # scale of decoder output relative to input
    sf = decoder_spatial_scale_factor
    # number of tiles - plus one, for overlap
    nti = math.ceil(x.shape[-2] / tile_size) + 1
    ntj = math.ceil(x.shape[-1] / tile_size) + 1
    # number of input pixels to traverse between tiles
    sti = (x.shape[-2] - tile_size) / (nti - 1)
    stj = (x.shape[-1] - tile_size) / (ntj - 1)
    # number of pixels to blend
    blend_i = int(tile_size - sti)
    blend_j = int(tile_size - stj)
    # masks for blending: ramps from 0 to 1 across the overlap (rows, then columns)
    blend_masks = torch.stack(
        torch.meshgrid(
            [torch.arange(tile_size * sf) / (blend_i * sf - 1),
             torch.arange(tile_size * sf) / (blend_j * sf - 1)],
            indexing="ij",
        ),
        0,
    )
    blend_masks = blend_masks.clamp(0, 1).to(x.device)
    # output array (height from dim -2, width from dim -1)
    out = torch.zeros(x.shape[0], 3, x.shape[-2] * sf, x.shape[-1] * sf, device=x.device)
    for i in range(nti):
        for j in range(ntj):
            ti, tj = round(sti * i), round(stj * j)
            # tile in / out regions (tile_out is a view into out)
            tile_in = x[..., ti:ti + tile_size, tj:tj + tile_size]
            tile_out = out[..., ti * sf:(ti + tile_size) * sf, tj * sf:(tj + tile_size) * sf]
            # tile result
            tile = decoder(tile_in)
            # blend tile result into output; the first row/column of tiles is copied as-is
            blend_mask_i = 1 if i == 0 else blend_masks[0]
            blend_mask_j = 1 if j == 0 else blend_masks[1]
            blend_mask = blend_mask_i * blend_mask_j
            tile_out.copy_(blend_mask * tile + (1 - blend_mask) * tile_out)
    return out
```
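The tile-count, stride, and overlap formulas in `tiled_decode_with_overlap` can be checked by hand. A small sketch; the helper `tile_layout` is hypothetical and simply repeats the function's arithmetic along one axis. Assuming a 512x512 image, i.e. a 64x64 latent at the 8x scale factor, it reproduces the "3x3 32x32 tiles, 16 (latent) pixels of overlap" setting.

```python
import math


def tile_layout(latent_size: int, tile_size: int = 32) -> tuple[int, list[int], int]:
    """Repeat tiled_decode_with_overlap's arithmetic along one axis.
    Returns (number of tiles, tile start offsets, overlap in latent pixels)."""
    n_tiles = math.ceil(latent_size / tile_size) + 1  # one extra tile for overlap
    stride = (latent_size - tile_size) / (n_tiles - 1)  # may be fractional
    starts = [round(stride * i) for i in range(n_tiles)]
    overlap = int(tile_size - stride)
    return n_tiles, starts, overlap


# 512x512 image -> 64x64 latent at 8x downsampling
print(tile_layout(64))  # (3, [0, 16, 32], 16)
```

When the latent size is not a multiple of the tile size, the stride is fractional, so the tile starts are rounded and the overlap is truncated to an integer.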

TAESD tiled decode, 3x3 grid of 32x32 tiles, 16 (latent) pixels of overlap: [image: tiled_decoding]

vs. SD-VAE tiled decode output:

SD-VAE tiled decode, 3x3 grid of 32x32 tiles, 16 (latent) pixels of overlap: [image: tiled_decoding]

To me, they both look free of perceptible tiling artifacts once you add the overlap, so I'll update the README.
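Whether tiling artifacts are perceptible is a visual call, but the tiled output can also be compared numerically with the non-tiled decode of the same latent. A minimal sketch, assuming both decodes are image tensors in [0, 1] with the same shape; the helper name `tiling_error` is hypothetical.

```python
import math

import torch


def tiling_error(reference: torch.Tensor, tiled: torch.Tensor) -> tuple[float, float]:
    """Compare a tiled decode with a non-tiled decode of the same latent.

    Both inputs are image tensors in [0, 1] with the same shape.
    Returns (maximum absolute pixel difference, PSNR in dB)."""
    diff = reference.clamp(0, 1).double() - tiled.clamp(0, 1).double()
    max_abs = diff.abs().max().item()
    mse = diff.pow(2).mean().item()
    # identical images have zero error, i.e. infinite PSNR
    psnr = math.inf if mse == 0 else 10 * math.log10(1.0 / mse)
    return max_abs, psnr
```

For example, one might pass the outputs of `taesd.decoder(x)` and `tiled_decode_with_overlap(taesd.decoder, x)` for the same latent `x`. A small maximum difference is the clearest sign that the overlap hides the seams.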

Isotr0py commented 11 months ago

Got it! Thx for the detailed explanation!

swamidass commented 3 months ago

What about the encoder? Does it have any issues with seams? What is its receptive field?