Open Xact-sniper opened 5 months ago
You are 100% correct that this is not the case. This is a bug on my side.
However, it's actually fine, because all the information in a patch gets mapped back to the same patch after unpatchifying. The order only gets mixed within the patch, so the result is equivalent up to a permutation, which nn.Linear will learn to recover.
What I mean is that set(patch_of(image)) == set(patch_of(unpatchify(patchify(image)))), i.e., pixels don't get mixed across patches.
You can see that by running the following code, which always prints True.
import torch
class PatchProcessor:
    def __init__(self, patch_size, out_channels):
        self.patch_size = patch_size
        self.out_channels = out_channels

    def unpatchify(self, x):
        # (N, L, p*p*c) -> (N, c, H, W); reads each token's features in (p, q, c) order
        c = self.out_channels
        p = self.patch_size
        h = w = int(x.shape[1] ** 0.5)
        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum("nhwpqc->nchpwq", x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
        return imgs

    def patchify(self, x):
        # (B, C, H, W) -> (B, L, C*p*p); writes each token's features in (c, p, q) order
        B, C, H, W = x.size()
        x = x.view(
            B,
            C,
            H // self.patch_size,
            self.patch_size,
            W // self.patch_size,
            self.patch_size,
        )
        x = x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)
        return x

patch_size = 4
out_channels = 3 # Assuming an RGB image
processor = PatchProcessor(patch_size, out_channels)
SIZE = 32
image = torch.arange(out_channels * SIZE * SIZE).reshape(1, out_channels, SIZE, SIZE).float()
patched_image = processor.patchify(image)
reconstructed_image = processor.unpatchify(patched_image)
# idx and jdx are pixel offsets, so step over the full image to visit every patch
for idx in range(0, SIZE, patch_size):
    for jdx in range(0, SIZE, patch_size):
        print(f"Patch ({idx}, {jdx}):")
        sets_bef = set(image[:, :, idx : idx + patch_size, jdx : jdx + patch_size].flatten().tolist())
        sets_aft = set(reconstructed_image[:, :, idx : idx + patch_size, jdx : jdx + patch_size].flatten().tolist())
        print(sets_bef == sets_aft)
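A slightly stronger check than the set comparison, as a minimal sketch: going from tokens to an image and back applies one fixed permutation to every token's feature vector, and a linear layer with its output rows permuted the same way produces exactly that permuted output. The free functions restate the two methods above; `perm`, `hidden`, `linear`, and `linear_perm` are illustrative names that are not in the repository. This only shows that a weight setting absorbing the permutation exists, not that training finds it.

```python
import torch
import torch.nn as nn

p, c = 4, 3        # patch size and channels, as in the example above
d = p * p * c      # features per token

def patchify(x):
    # (B, C, H, W) -> (B, L, C*p*p); writes each token's features in (c, p, q) order
    B, C, H, W = x.shape
    x = x.view(B, C, H // p, p, W // p, p)
    return x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)

def unpatchify(x):
    # (B, L, p*p*c) -> (B, c, H, W); reads each token's features in (p, q, c) order
    h = w = int(x.shape[1] ** 0.5)
    x = x.reshape(x.shape[0], h, w, p, p, c)
    x = torch.einsum("nhwpqc->nchpwq", x)
    return x.reshape(x.shape[0], c, h * p, w * p)

# Entry k of perm (k counted in (c, p, q) order) is the matching (p, q, c) flat index.
perm = torch.arange(d).reshape(p, p, c).permute(2, 0, 1).reshape(-1)

torch.manual_seed(0)
tokens = torch.randn(2, 64, d)
round_trip = patchify(unpatchify(tokens))
same_permutation = torch.equal(round_trip, tokens[..., perm])

# Hypothetical output projection and a copy whose output rows are permuted by perm.
hidden = 16
linear = nn.Linear(hidden, d)
linear_perm = nn.Linear(hidden, d)
with torch.no_grad():
    linear_perm.weight.copy_(linear.weight[perm])
    linear_perm.bias.copy_(linear.bias[perm])

feats = torch.randn(2, 64, hidden)
absorbed = torch.allclose(linear(feats)[..., perm], linear_perm(feats), atol=1e-5)
print(same_permutation, absorbed)
```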
However, this was not intended, and what you pointed out is correct. This is an unnecessary channel-wise shuffle that doesn't need to be here, so I'll remove it in the future.
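One way the shuffle could be removed, as a sketch and not necessarily the fix that will land: have unpatchify read each token's features in the same (c, p, q) order that patchify writes them. `unpatchify_consistent` is a hypothetical name, and the code assumes square images whose side is divisible by the patch size.

```python
import torch

def patchify(x, patch_size):
    # (B, C, H, W) -> (B, L, C*p*p); features written in (c, p, q) order
    B, C, H, W = x.shape
    p = patch_size
    x = x.view(B, C, H // p, p, W // p, p)
    return x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)

def unpatchify_consistent(x, patch_size, out_channels):
    # (B, L, C*p*p) -> (B, C, H, W); features read in the same (c, p, q) order
    c, p = out_channels, patch_size
    h = w = int(x.shape[1] ** 0.5)
    x = x.reshape(x.shape[0], h, w, c, p, p)
    x = torch.einsum("nhwcpq->nchpwq", x)
    return x.reshape(x.shape[0], c, h * p, w * p)

image = torch.arange(3 * 32 * 32).reshape(1, 3, 32, 32).float()
restored = unpatchify_consistent(patchify(image, 4), 4, 3)
round_trip_ok = torch.equal(restored, image)
print(round_trip_ok)
```

The only change relative to the version above is the reshape target and the einsum input subscripts; the output layout is untouched.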
https://github.com/cloneofsimo/minRF/blob/72feb0c87d435e9f9d220f34f348ed66c0b6ccec/dit.py#L287-L288
Should this not be:
I would expect unpatchify(patchify(image)) == image, but as is, that is not the case.
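A minimal reproduction, as a sketch: the two functions are assumed to mirror the patchify and unpatchify logic discussed in this thread (they are restated from the snippet here, not copied from the linked lines), with patch size 4 and 3 channels chosen for illustration.

```python
import torch

p, c = 4, 3  # illustrative patch size and channel count

def patchify(x):
    # (B, C, H, W) -> (B, L, C*p*p)
    B, C, H, W = x.shape
    x = x.view(B, C, H // p, p, W // p, p)
    return x.permute(0, 2, 4, 1, 3, 5).flatten(-3).flatten(1, 2)

def unpatchify(x):
    # (B, L, p*p*c) -> (B, c, H, W)
    h = w = int(x.shape[1] ** 0.5)
    x = x.reshape(x.shape[0], h, w, p, p, c)
    x = torch.einsum("nhwpqc->nchpwq", x)
    return x.reshape(x.shape[0], c, h * p, w * p)

image = torch.arange(c * 32 * 32).reshape(1, c, 32, 32).float()
restored = unpatchify(patchify(image))
round_trip_equal = torch.equal(restored, image)
print(round_trip_equal)
```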