facebookresearch / mae

PyTorch implementation of MAE https://arxiv.org/abs/2111.06377

Code: Make patchify and unpatchify compatible with any number of channels #192

Closed: zhongruiHuangDMRI closed this issue 4 months ago

zhongruiHuangDMRI commented 4 months ago

Dear authors, thanks for your great work. My only suggestion: in some cases (e.g., medical imaging) we use single-channel (grayscale) images instead of color (RGB) images. Here is revised code for the patchify and unpatchify functions that works with any number of channels (written by @CH2-Carbene):

```python
import torch

# Methods of MaskedAutoencoderViT (models_mae.py). They assume the constructor
# stores the channel count, e.g. `self.img_channel = in_chans`.

def patchify_v2(self, imgs):
    """
    imgs: (N, self.img_channel, H, W)
    x: (N, L, patch_size**2 * self.img_channel)
    """
    p = self.patch_embed.patch_size[0]
    assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

    h = w = imgs.shape[2] // p
    # split each spatial axis into (num_patches, patch_size)
    x = imgs.reshape(shape=(imgs.shape[0], self.img_channel, h, p, w, p))
    # move channels last so each patch flattens in (p, p, C) order
    x = torch.einsum('nchpwq->nhwpqc', x)
    x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * self.img_channel))
    return x

def unpatchify_v2(self, x):
    """
    x: (N, L, patch_size**2 * self.img_channel)
    imgs: (N, self.img_channel, H, W)
    """
    p = self.patch_embed.patch_size[0]
    h = w = int(x.shape[1]**.5)
    assert h * w == x.shape[1]

    x = x.reshape(shape=(x.shape[0], h, w, p, p, self.img_channel))
    # inverse of the einsum in patchify_v2
    x = torch.einsum('nhwpqc->nchpwq', x)
    imgs = x.reshape(shape=(x.shape[0], self.img_channel, h * p, w * p))
    return imgs
```
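One possible alternative to a stored channel attribute is to read the channel count from the tensors themselves: `imgs.shape[1]` in patchify, and `last_dim // p**2` in unpatchify. The following is a minimal standalone sketch under that assumption. The free functions `patchify`/`unpatchify` taking `p` explicitly are hypothetical helpers, not part of the MAE repo. They use the same reshape/einsum order as the code above, and the check at the bottom confirms that the two functions are exact inverses for a grayscale batch.

```python
import torch


def patchify(imgs: torch.Tensor, p: int) -> torch.Tensor:
    """(N, C, H, W) -> (N, L, p*p*C); C is read from the input tensor."""
    n, c, h_px, w_px = imgs.shape
    assert h_px == w_px and h_px % p == 0, "expects square images divisible by p"
    h = w = h_px // p
    x = imgs.reshape(n, c, h, p, w, p)
    x = torch.einsum("nchpwq->nhwpqc", x)   # channels last within each patch
    return x.reshape(n, h * w, p * p * c)


def unpatchify(x: torch.Tensor, p: int) -> torch.Tensor:
    """(N, L, p*p*C) -> (N, C, H, W); C is recovered from the last dim."""
    n, num_patches, dim = x.shape
    h = w = int(num_patches ** 0.5)
    assert h * w == num_patches, "expects a square grid of patches"
    c = dim // (p * p)
    assert c * p * p == dim, "last dim must be a multiple of p*p"
    x = x.reshape(n, h, w, p, p, c)
    x = torch.einsum("nhwpqc->nchpwq", x)   # inverse of the patchify einsum
    return x.reshape(n, c, h * p, w * p)


if __name__ == "__main__":
    gray = torch.randn(2, 1, 224, 224)      # e.g. single-channel images
    patches = patchify(gray, 16)
    print(patches.shape)                     # torch.Size([2, 196, 256])
    assert torch.equal(unpatchify(patches, 16), gray)
```

Because the functions only rearrange values, the round trip is exact, so `torch.equal` (rather than `allclose`) is appropriate. The trade-off of inferring `C` is that a mismatch between the data and the model's configured `in_chans` is no longer caught here; it would surface elsewhere, e.g. in the loss against the decoder output.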
hugoWR commented 1 month ago

Why was this closed? It appears that the patchify and unpatchify functions still hardcode RGB (3 channels) instead of using the number of input channels.

daisukelab commented 1 month ago

FYI: the code linked below is mainly intended for audio spectrograms, although we have also applied it to images.

https://github.com/nttcslab/m2d/blob/4fd394607b7ab1403c3863bf69f942d1354bbae4/patch_m2d.diff#L1867-L1906
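Spectrograms are usually not square (frequency bins × time frames), so the `H == W` assertion above no longer holds, and unpatchify cannot recover the patch grid from `L` alone. The following is a hypothetical sketch (independent of the linked m2d diff, whose contents are not reproduced here). It handles rectangular inputs and patches with any channel count by taking the grid size `(h, w)` as an explicit argument. The names `patchify_2d`/`unpatchify_2d` and the example shape of 80 × 608 are illustrative assumptions.

```python
import torch


def patchify_2d(imgs: torch.Tensor, ph: int, pw: int) -> torch.Tensor:
    """(N, C, H, W) -> (N, L, ph*pw*C); H and W only need to be divisible by ph, pw."""
    n, c, h_px, w_px = imgs.shape
    assert h_px % ph == 0 and w_px % pw == 0, "input not divisible by patch size"
    h, w = h_px // ph, w_px // pw
    x = imgs.reshape(n, c, h, ph, w, pw)
    x = torch.einsum("nchpwq->nhwpqc", x)   # channels last within each patch
    return x.reshape(n, h * w, ph * pw * c)


def unpatchify_2d(x: torch.Tensor, grid: tuple, ph: int, pw: int) -> torch.Tensor:
    """(N, L, ph*pw*C) -> (N, C, h*ph, w*pw); grid=(h, w) is required because
    L = h*w does not determine h and w uniquely for non-square inputs."""
    n, num_patches, dim = x.shape
    h, w = grid
    assert h * w == num_patches, "grid does not match number of patches"
    c = dim // (ph * pw)
    assert c * ph * pw == dim, "last dim must be a multiple of ph*pw"
    x = x.reshape(n, h, w, ph, pw, c)
    x = torch.einsum("nhwpqc->nchpwq", x)   # inverse of the patchify einsum
    return x.reshape(n, c, h * ph, w * pw)


if __name__ == "__main__":
    spec = torch.randn(4, 1, 80, 608)        # hypothetical: 80 bins x 608 frames
    patches = patchify_2d(spec, 16, 16)
    print(patches.shape)                     # torch.Size([4, 190, 256])
    assert torch.equal(unpatchify_2d(patches, (5, 38), 16, 16), spec)
```

In a full model, the patch embedding and positional embeddings would also have to support the rectangular grid; this sketch covers only the patch rearrangement used for the reconstruction target.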