Open dvd42 opened 1 year ago
I had the same problem. You would have to change `get_2d_sincos_pos_embed` so it can take a tuple instead of an int.
It's not too hard to do, and it would look something like this:
```python
import numpy as np

def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int or tuple (H, W)
    return:
    pos_embed: [H*W, embed_dim] or [1+H*W, embed_dim] (w/ or w/o cls_token)
    """
    # normalize to a (H, W) tuple so the reshape below works in both cases
    if isinstance(grid_size, int):
        grid_size = (grid_size, grid_size)
    grid_h = np.arange(grid_size[0], dtype=np.float32)
    grid_w = np.arange(grid_size[1], dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # w goes first, as in the original
    grid = np.stack(grid, axis=0)
    grid = grid.reshape([2, 1, grid_size[0], grid_size[1]])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed
```
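For completeness, here's a minimal self-contained sketch you can run to check the shapes for a non-square grid. It redefines the function so the snippet stands alone, and the two `*_from_grid` helpers follow the upstream `pos_embed` utilities; the 8x12 grid is just an illustrative example.

```python
import numpy as np

def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    # standard 1D sin-cos embedding, as in the upstream MAE utilities
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000 ** omega             # (D/2,)
    pos = pos.reshape(-1)                   # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2)
    return np.concatenate([np.sin(out), np.cos(out)], axis=1)  # (M, D)

def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    # half of the dimensions encode H, the other half W
    assert embed_dim % 2 == 0
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])
    return np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)

def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    # grid_size: int or (H, W) tuple
    if isinstance(grid_size, int):
        grid_size = (grid_size, grid_size)
    grid_h = np.arange(grid_size[0], dtype=np.float32)
    grid_w = np.arange(grid_size[1], dtype=np.float32)
    grid = np.stack(np.meshgrid(grid_w, grid_h), axis=0)  # (2, H, W)
    grid = grid.reshape([2, 1, grid_size[0], grid_size[1]])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed

# non-square grid: 8 rows x 12 cols -> 96 patches (+1 cls token)
pe = get_2d_sincos_pos_embed(768, (8, 12), cls_token=True)
print(pe.shape)  # (97, 768)
```

Passing a plain int still works, so existing square-grid call sites don't need to change.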
FYI -- https://github.com/nttcslab/msm-mae/blob/main/msm_mae/patch_msm_mae.diff#L82
We've been handling non-square grids this way for two years; our patch may also be a useful reference.
How would one generate positional embeddings given a non-square number of patches?
Right now this cannot be done, since a shape mismatch occurs when the embeddings are copied:
https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/models_mae.py#L71-L72
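To illustrate the mismatch: the linked lines assume a square grid by taking `int(num_patches ** .5)`, which drops patches whenever the grid isn't square. A quick sketch with hypothetical numbers:

```python
# Hypothetical non-square input: an 8 x 12 grid of patches.
num_patches = 8 * 12  # 96

# The upstream call assumes a square grid:
grid_size = int(num_patches ** .5)  # int(96 ** .5) = 9
square_positions = grid_size * grid_size  # 81

# The table built for 81 (+1 cls) positions cannot be copied into a
# parameter sized for 96 (+1 cls) tokens -> shape mismatch on copy_.
print(square_positions, num_patches)  # 81 96
```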
Thanks in advance!