cloneofsimo / minRF

Minimal implementation of scalable rectified flow transformers, based on SD3's approach
Apache License 2.0

Wrong dimension order in unpatchify? #6

Open Xact-sniper opened 3 months ago

Xact-sniper commented 3 months ago

https://github.com/cloneofsimo/minRF/blob/72feb0c87d435e9f9d220f34f348ed66c0b6ccec/dit.py#L287-L288

Should this not be:

    x = x.reshape(shape=(x.shape[0], h, w, c, p, p))
    x = torch.einsum("nhwcpq->nchpwq", x)

I would expect unpatchify(patchify(image)) == image, but as written that is not the case.
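A minimal sketch of the round-trip check for the proposed fix. It assumes the repo's patchify layout (each token flattened in (C, p, p) order) and a square grid of patches. The function names patchify and unpatchify_fixed are hypothetical stand-ins for illustration, not the repo's API.

```python
import torch

def patchify(x: torch.Tensor, p: int) -> torch.Tensor:
    # (B, C, H, W) -> (B, num_patches, C*p*p); each token is flattened in (C, p, p) order
    B, C, H, W = x.shape
    x = x.reshape(B, C, H // p, p, W // p, p)
    return x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)

def unpatchify_fixed(x: torch.Tensor, p: int, c: int) -> torch.Tensor:
    # Inverse of patchify: read each token back in (C, p, p) order, as proposed above
    h = w = int(x.shape[1] ** 0.5)  # assumes a square grid of patches
    x = x.reshape(x.shape[0], h, w, c, p, p)
    x = torch.einsum("nhwcpq->nchpwq", x)
    return x.reshape(x.shape[0], c, h * p, w * p)

image = torch.randn(2, 3, 32, 32)
tokens = patchify(image, p=4)
print(tokens.shape)  # torch.Size([2, 64, 48])
print(torch.equal(unpatchify_fixed(tokens, p=4, c=3), image))  # True
```

Because the fixed version only permutes and reshapes, the round trip is exact, so torch.equal (not just allclose) holds.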

cloneofsimo commented 3 months ago

You are 100% correct that this is not the case. This is a bug on my side.

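However, it's actually fine, because all the information in a patch gets mapped back into that same patch when unpatchified. The order only gets mixed within the patch, so the output is equivalent up to a fixed permutation, which nn.Linear will learn to recover: the linear layer that produces each patch's values can absorb the permutation simply by reordering its output rows.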
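To illustrate why a fixed permutation is harmless, here is a small sketch (not code from the repo): permuting the outputs of an nn.Linear gives exactly the same function as a second nn.Linear whose weight rows and bias entries are permuted the same way. The sizes are arbitrary; out_dim = p * p * c is an assumption about how the final projection is shaped.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, out_dim = 16, 48  # assumed: out_dim = p * p * c with p = 4, c = 3
final = nn.Linear(d_model, out_dim)
perm = torch.randperm(out_dim)  # stands in for the fixed within-patch shuffle

# A layer with pre-permuted rows produces the permuted output directly
permuted = nn.Linear(d_model, out_dim)
with torch.no_grad():
    permuted.weight.copy_(final.weight[perm])
    permuted.bias.copy_(final.bias[perm])

h = torch.randn(5, d_model)
print(torch.allclose(final(h)[:, perm], permuted(h)))  # True
```

So the set of functions the model can represent is unchanged by the shuffle; only which weight row ends up responsible for which pixel differs.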

What I mean is that set(patch_of(image)) == set(patch_of(unpatchify(patchify(image)))), i.e., pixels don't get mixed across patches.
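The same claim can be checked more directly. In this sketch, unpatchify_buggy is a hypothetical helper that copies the (p, p, C) reshape order from the linked lines, and patchify follows the repo's (C, p, p) token layout. Re-patchifying the buggy output returns the same tokens in the same positions, each with its features shuffled by one fixed index permutation, so no pixel leaves its patch.

```python
import torch

def patchify(x, p):
    # (B, C, H, W) -> (B, num_patches, C*p*p), tokens flattened in (C, p, p) order
    B, C, H, W = x.shape
    x = x.reshape(B, C, H // p, p, W // p, p)
    return x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)

def unpatchify_buggy(x, p, c):
    # Reads each token as (p, p, C) although patchify wrote it as (C, p, p)
    h = w = int(x.shape[1] ** 0.5)
    x = x.reshape(x.shape[0], h, w, p, p, c)
    x = torch.einsum("nhwpqc->nchpwq", x)
    return x.reshape(x.shape[0], c, h * p, w * p)

p, c = 4, 3
tokens = patchify(torch.randn(1, c, 32, 32), p)           # (1, 64, 48)
round_trip = patchify(unpatchify_buggy(tokens, p, c), p)  # same shape

# Push index labels through a single-token round trip to recover the fixed shuffle
labels = torch.arange(c * p * p, dtype=torch.float32).view(1, 1, -1)
perm = patchify(unpatchify_buggy(labels, p, c), p)[0, 0].long()

print(torch.equal(round_trip, tokens[:, :, perm]))  # True: same token, features shuffled
print(torch.equal(perm, torch.arange(c * p * p)))   # False: the shuffle is not the identity
```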

You can see this by running the following code, which always prints True.

    import torch

    class PatchProcessor:
        def __init__(self, patch_size, out_channels):
            self.patch_size = patch_size
            self.out_channels = out_channels

        def unpatchify(self, x):
            # Same order as the linked dit.py lines: each token is read as (p, p, C),
            # even though patchify writes it as (C, p, p)
            c = self.out_channels
            p = self.patch_size
            h = w = int(x.shape[1] ** 0.5)  # assumes a square grid of patches
            x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
            x = torch.einsum("nhwpqc->nchpwq", x)
            imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
            return imgs

        def patchify(self, x):
            # (B, C, H, W) -> (B, num_patches, C * p * p)
            B, C, H, W = x.size()
            x = x.view(
                B,
                C,
                H // self.patch_size,
                self.patch_size,
                W // self.patch_size,
                self.patch_size,
            )
            x = x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
            return x

    patch_size = 4
    out_channels = 3  # Assuming an RGB image
    processor = PatchProcessor(patch_size, out_channels)
    SIZE = 32

    # Every pixel gets a unique value, so the set comparison detects any pixel that changes patch
    image = torch.arange(out_channels * SIZE * SIZE).reshape(1, out_channels, SIZE, SIZE).float()

    patched_image = processor.patchify(image)

    reconstructed_image = processor.unpatchify(patched_image)

    # Visit the top-left pixel of every patch (step by patch_size over the full image)
    for idx in range(0, SIZE, patch_size):
        for jdx in range(0, SIZE, patch_size):
            print(f"Patch ({idx}, {jdx}):")

            # Compare all channels of the patch as unordered sets of values
            sets_bef = set(image[:, :, idx: idx + patch_size, jdx: jdx + patch_size].flatten().tolist())
            sets_aft = set(reconstructed_image[:, :, idx: idx + patch_size, jdx: jdx + patch_size].flatten().tolist())
            print(sets_bef == sets_aft)

However, this was not intended, and what you pointed out is correct. This is an unnecessary shuffle within each patch (across channels and pixel positions) that doesn't need to be here, so I'll remove it in the future.
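To make the channel mixing visible, here is a sketch that fills every pixel of input channel k with the value k, so each output value reveals which input channel it came from. As above, patchify and unpatchify_buggy are hypothetical helpers mirroring the repo's token layout and the linked reshape order.

```python
import torch

def patchify(x, p):
    # (B, C, H, W) -> (B, num_patches, C*p*p), tokens flattened in (C, p, p) order
    B, C, H, W = x.shape
    x = x.reshape(B, C, H // p, p, W // p, p)
    return x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)

def unpatchify_buggy(x, p, c):
    # Reads each token as (p, p, C) although patchify wrote it as (C, p, p)
    h = w = int(x.shape[1] ** 0.5)
    x = x.reshape(x.shape[0], h, w, p, p, c)
    x = torch.einsum("nhwpqc->nchpwq", x)
    return x.reshape(x.shape[0], c, h * p, w * p)

p, c, size = 4, 3, 32
# Channel k is filled with the constant k, so a value identifies its source channel
image = torch.arange(c, dtype=torch.float32).view(1, c, 1, 1).repeat(1, 1, size, size)
out = unpatchify_buggy(patchify(image, p), p, c)

for k in range(c):
    # Each output channel ends up holding values from all three input channels
    print(k, sorted(set(out[0, k].flatten().tolist())))
```

With the fixed unpatchify, each output channel would instead contain only its own constant.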