facebookresearch / mae

PyTorch implementation of MAE: https://arxiv.org/abs/2111.06377

Non-square number of patches #151

Open dvd42 opened 1 year ago

dvd42 commented 1 year ago

How would one generate positional embeddings given a non-square number of patches?

Right now this cannot be done, since a shape mismatch occurs when the embeddings are copied into the model:

https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/models_mae.py#L71-L72
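To illustrate (a minimal sketch, assuming the repo's util/pos_embed.py is importable and that the linked lines derive the grid side as int(num_patches ** .5); the 224x112 input is a made-up example):

import numpy as np
from util.pos_embed import get_2d_sincos_pos_embed  # helper shipped with this repo

embed_dim = 1024
# A 224x112 image with 16x16 patches gives a 14x7 patch grid, i.e. 98 patches.
num_patches = 14 * 7

# The linked code assumes a square grid: int(98 ** .5) == 9, so the table ends
# up with 9*9 = 81 position rows (plus the cls token) instead of the 98 needed.
pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches ** .5), cls_token=True)
print(pos_embed.shape)  # (82, 1024), but self.pos_embed expects (99, 1024)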

Thanks in advance!

hugoWR commented 12 hours ago

I had the same problem. You would have to change get_2d_sincos_pos_embed so that it accepts a tuple (H, W) instead of a single int.

It's not too hard to do, and it would look something like this:

# Drop-in replacement for get_2d_sincos_pos_embed in util/pos_embed.py
# (np and get_2d_sincos_pos_embed_from_grid are already defined in that file).
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int, or tuple (H, W) of the patch grid
    return:
    pos_embed: [H*W, embed_dim] or [1+H*W, embed_dim] (w/ or w/o cls_token)
    """
    if isinstance(grid_size, int):
        # Keep backward compatibility with the original square-grid call sites.
        grid_size = (grid_size, grid_size)
    grid_h = np.arange(grid_size[0], dtype=np.float32)
    grid_w = np.arange(grid_size[1], dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size[0], grid_size[1]])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed
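With that change, the call site in models_mae.py would also need the actual (H, W) grid rather than int(num_patches ** .5). A hypothetical usage, with made-up image and patch sizes (the grid_hw helper below is not part of the repo):

# Example: non-square 224x112 input with 16x16 patches -> 14x7 grid.
img_size = (224, 112)
patch_size = 16
grid_hw = (img_size[0] // patch_size, img_size[1] // patch_size)  # (14, 7)

pos_embed = get_2d_sincos_pos_embed(embed_dim=1024, grid_size=grid_hw, cls_token=True)
print(pos_embed.shape)  # (99, 1024) == (14*7 + 1, 1024)

# In the model, this grid tuple would replace the int(num_patches ** .5)
# argument when initializing self.pos_embed (and the decoder pos_embed).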

daisukelab commented 11 hours ago

FYI -- https://github.com/nttcslab/msm-mae/blob/main/msm_mae/patch_msm_mae.diff#L82

We've been doing this for two years; our patch might also be informative.