facebookresearch / mae

PyTorch implementation of MAE: https://arxiv.org/abs/2111.06377
Other
7.2k stars, 1.2k forks

Shouldn't the patch embeddings be trained only on the patches that survived masking? (Rather than the original image) #145

Closed: Eduard6421 closed this issue 1 year ago

Eduard6421 commented 1 year ago

In the code, it seems that the patch embeddings are trained on the full original image rather than only on the patches that survive the masking process. Does this mean the implementation does not follow the paper? From the paper:

"MAE encoder. Our encoder is a ViT [16] but applied only on visible, unmasked patches. Just as in a standard ViT, our encoder embeds patches by a linear projection with added positional embeddings, and then processes the resulting set via a series of Transformer blocks. However, our encoder only operates on a small subset (e.g., 25%) of the full set. Masked patches are removed; no mask tokens are used. This allows us to train very large encoders with only a fraction of compute and memory. The full set is handled by a lightweight decoder, described next."

As I interpret it, the patch embedding should also be computed only on that small subset of patches.

MaxChu719 commented 1 year ago

I have the same question. Have you found the answer?

daisukelab commented 1 year ago

@Eduard6421 @MaxChu719 Hi. Just out of curiosity, would you mind explaining your question in a bit more detail (ideally pointing to the code you think is the problem)?

I read your phrase "trained on the original image rather than learning only on the patches that survive the masking process" as meaning that the loss is calculated over all the patches.

As far as I can tell, the loss is calculated over the masked patches only. The following lines keep only the visible patches for the encoder:

https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/models_mae.py#L139-L140

And this line calculates the loss on the masked patches only:

https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/models_mae.py#L213
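To make the masking and loss behaviour described above concrete, here is a minimal sketch in NumPy (used instead of PyTorch only so the snippet runs without extra dependencies). It is written from the description in this thread, not copied from models_mae.py: the function names `random_masking` and `masked_mse`, the (N, L, D) token layout, and the 0 = kept / 1 = removed mask convention are assumptions for illustration.

```python
import numpy as np

def random_masking(x, mask_ratio, rng):
    """Hypothetical sketch: keep a random subset of patch tokens per sample.

    x: (N, L, D) patch tokens. Returns the visible tokens (N, len_keep, D),
    a mask (N, L) in original patch order with 0 = kept and 1 = removed,
    and ids_restore, which undoes the random shuffle.
    """
    N, L, D = x.shape
    len_keep = int(L * (1 - mask_ratio))
    noise = rng.random((N, L))                      # one random score per patch
    ids_shuffle = np.argsort(noise, axis=1)         # lowest scores are kept
    ids_restore = np.argsort(ids_shuffle, axis=1)   # inverse permutation
    ids_keep = ids_shuffle[:, :len_keep]
    x_visible = np.take_along_axis(x, ids_keep[:, :, None], axis=1)
    mask = np.ones((N, L))
    mask[:, :len_keep] = 0                          # mask in shuffled order
    mask = np.take_along_axis(mask, ids_restore, axis=1)  # back to patch order
    return x_visible, mask, ids_restore

def masked_mse(pred, target, mask):
    """Per-patch MSE, averaged over the removed patches only (mask == 1)."""
    loss = ((pred - target) ** 2).mean(axis=-1)     # (N, L)
    return (loss * mask).sum() / mask.sum()

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 16, 8))                 # 2 images, 16 patches each
vis, mask, ids_restore = random_masking(x, 0.75, rng)
print(vis.shape)  # (2, 4, 8): only 25% of the patches reach the encoder
```

Because visible patches get weight 0 in `masked_mse`, errors on them never contribute to the loss; only the reconstruction of removed patches is penalized.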

Eduard6421 commented 1 year ago

@daisukelab @MaxChu719 Absolutely! Thank you for your response.

My concern is not related to the loss function used by the model but to the data it is fed. The convolutional layer in the patch embedding receives the full, unmasked image; the embeddings are masked only afterwards, before being fed to the Transformer. My question is: was the patch embedding supposed to receive the original image? Shouldn't the convolutional layer also run only on the unmasked parts of the image?

As I see it, applying the masking after the patch embedding layer rather than before it would allow the convolutional layer to generate embeddings that contain global information. Is this the reasoning behind this decision?

daisukelab commented 1 year ago

@Eduard6421 Hi, thanks for explaining the detail.

I think this is because running the Conv embedding before masking does not affect the loss calculation at all: each patch is embedded independently, and the embeddings of the masked patches are discarded before the encoder. Therefore, the code doesn't need to go out of its way to avoid embedding the masked portions of the original image; it can simply follow the vanilla ViT:

https://github.com/rwightman/pytorch-image-models/blob/e9aac412de82310e6905992e802b1ee4dc52b5d1/timm/models/vision_transformer.py#L382-L390
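The argument that embedding before masking changes nothing can be checked numerically. The sketch below assumes a hypothetical per-patch linear projection (`W_proj`, `b`) applied to patchified images and compares two orders: embed all patches and then keep the visible tokens, versus keep the visible patches and then embed them. NumPy stands in for PyTorch; `patchify`, `embed`, and the kept indices are illustrative, not the repository's code.

```python
import numpy as np

def patchify(imgs, p):
    """(N, C, H, W) -> (N, L, p*p*C): non-overlapping p x p patches, row-major."""
    N, C, H, W = imgs.shape
    h, w = H // p, W // p
    x = imgs.reshape(N, C, h, p, w, p)
    x = x.transpose(0, 2, 4, 3, 5, 1)          # (N, h, w, p, p, C)
    return x.reshape(N, h * w, p * p * C)

def embed(patches, W_proj, b):
    """Hypothetical per-patch linear projection: each token sees only its own patch."""
    return patches @ W_proj + b

rng = np.random.default_rng(0)
imgs = rng.standard_normal((2, 3, 32, 32))
p, D = 8, 16
W_proj = rng.standard_normal((p * p * 3, D))      # hypothetical embedding weights
b = rng.standard_normal(D)
keep = np.array([[0, 5, 9, 14], [2, 3, 11, 15]])  # hypothetical visible patch ids

patches = patchify(imgs, p)                       # (2, 16, 192)

# Order A (as in the code): embed every patch, then drop the masked tokens
tokens_a = np.take_along_axis(embed(patches, W_proj, b), keep[:, :, None], axis=1)

# Order B (strict reading of the paper): drop masked patches, then embed the rest
visible = np.take_along_axis(patches, keep[:, :, None], axis=1)
tokens_b = embed(visible, W_proj, b)

print(np.allclose(tokens_a, tokens_b))  # True
```

Because each token depends only on its own patch, selecting tokens commutes with the projection; the only cost of order A is the compute spent on patches that are later thrown away.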

"As I see it the fact that the masking is applied after the patch embedding layer and not before it will allow the convolutional layer to generate embeddings which contain global information."

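A Conv layer is a local operation. As the MAE paper says, "Just as in a standard ViT, our encoder embeds patches by a linear projection with ...". Using a Conv layer whose kernel size and stride both equal the patch size is a clever way to implement that projection: it patchifies the input image and linearly embeds each patch in a single step, and each output token depends only on the pixels of its own patch.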

https://github.com/rwightman/pytorch-image-models/blob/e9aac412de82310e6905992e802b1ee4dc52b5d1/timm/layers/patch_embed.py#L46
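The locality claim can be illustrated the same way. This sketch implements a convolution whose kernel size and stride both equal the patch size (assumed to be 8 here, with hypothetical random weights) as an explicit loop in NumPy, then corrupts one patch and checks which tokens change. `patch_embed` is a stand-in for a Conv2d patch embedding, not timm's code.

```python
import numpy as np

def patch_embed(imgs, weight, bias):
    """Hypothetical Conv2d with kernel_size == stride == p, no padding, then tokens.

    imgs: (N, C, H, W); weight: (D, C, p, p); bias: (D,).
    Returns tokens of shape (N, (H // p) * (W // p), D) in row-major patch order.
    """
    N, C, H, W = imgs.shape
    D, _, p, _ = weight.shape
    h, w = H // p, W // p
    out = np.empty((N, D, h, w))
    for i in range(h):
        for j in range(w):
            # Because stride == kernel size, each window is exactly one patch
            window = imgs[:, :, i * p:(i + 1) * p, j * p:(j + 1) * p]
            out[:, :, i, j] = np.einsum("ncab,dcab->nd", window, weight) + bias
    # Equivalent of x.flatten(2).transpose(1, 2) in PyTorch: (N, D, h, w) -> (N, L, D)
    return out.reshape(N, D, h * w).transpose(0, 2, 1)

rng = np.random.default_rng(0)
imgs = rng.standard_normal((1, 3, 32, 32))
weight = rng.standard_normal((16, 3, 8, 8))   # hypothetical weights, D = 16, p = 8
bias = rng.standard_normal(16)
tokens = patch_embed(imgs, weight, bias)       # (1, 16, 16)

# Replace the top-left patch (token 0) with fresh noise and re-embed
corrupted = imgs.copy()
corrupted[:, :, 0:8, 0:8] = rng.standard_normal((1, 3, 8, 8))
tokens2 = patch_embed(corrupted, weight, bias)

print(np.allclose(tokens[:, 0], tokens2[:, 0]))    # False: its own patch changed
print(np.allclose(tokens[:, 1:], tokens2[:, 1:]))  # True: no other token is affected
```

Since no window overlaps another, content from a masked patch cannot leak into the embedding of a visible one.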

So I don't think you need to worry about global information contaminating the embeddings. :)

Eduard6421 commented 1 year ago

@daisukelab Thank you very much for taking the time to respond to my question. You have given me a comprehensive answer, and it is now clear to me why the model is not contaminated by global information.

Kind regards, Eduard