yjh0410 / SAMI

Masked AutoEncoders leveraging Segment-Anything

About the Cross-Attention Decoder #1

Open yeyeyeping opened 6 months ago

yeyeyeping commented 6 months ago

Thank you for your implementation of SAMI's training code. It has been incredibly helpful for me!

However, I have a question about the pretraining process. In the forward pass of the MaeDecoder, I noticed that the masked embeddings are reconstructed through self-attention, which is clever but seems inconsistent with the cross-attention described in the original paper.

I am confused about the Cross-Attention Decoder module. Could you explain how you understand the query, key, and value described in Section 3.2 (page 3) of the paper? And why did you implement the query as a learnable mask_token?

yjh0410 commented 6 months ago

@yeyeyeping Thanks for your comments. The problem you pointed out is my mistake. The MaeDecoder in SAMI is supposed to use Cross Attention (mask tokens will be used as the query), not Self Attention, which is very different from the standard MAE. I will fix this bug as soon as possible.

However, this sentence in the paper confuses me: "In the cross-attention decoder, queries come from masked tokens, and keys and values derive from both unmasked features from encoder and masked features." The "unmasked features from encoder" part is clear, but I am not sure what the "masked features" are.

The idea of MIM (masked image modeling) is to leverage the high-level features extracted from the unmasked patches to reconstruct the masked patches (MAE) or their features (SAMI). So it is reasonable to use the mask tokens as the query and the unmasked patches' features as the key and value. In this way, the cross-attention operations transfer information from the unmasked patches into the mask tokens.
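A minimal single-head sketch of the cross-attention described above, in NumPy: the mask tokens form the query sequence, and the unmasked encoder features provide the keys and values. All names and sizes here (cross_attention, a 16-dim width, 4 visible and 12 masked patches, random weights standing in for learned ones) are hypothetical; this is not the SAMI repository's code. Adding a per-position embedding to the shared mask token is an assumption borrowed from MAE; without it every query would be identical, so every masked slot would receive the same output.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max for numerical stability before exponentiating.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def cross_attention(queries, context, w_q, w_k, w_v):
    """Single-head cross-attention: queries attend to a separate context sequence.

    queries: (M, D) mask tokens, one per masked patch
    context: (N, D) unmasked patch features from the encoder
    """
    q = queries @ w_q                          # (M, D)
    k = context @ w_k                          # (N, D)
    v = context @ w_v                          # (N, D)
    scores = q @ k.T / np.sqrt(q.shape[-1])    # (M, N): each mask token scores every visible patch
    attn = softmax(scores, axis=-1)
    return attn @ v, attn                      # (M, D): visible-patch info pulled into each mask token

rng = np.random.default_rng(0)
D, num_visible, num_masked = 16, 4, 12         # hypothetical sizes (12 of 16 patches masked)
mask_token = rng.normal(size=(1, D))           # one shared "learnable" token (random here)
pos_embed = rng.normal(size=(num_masked, D))   # hypothetical positional embeddings of the masked slots
queries = mask_token + pos_embed               # broadcast the shared token to every masked position
visible_feats = rng.normal(size=(num_visible, D))
w_q, w_k, w_v = (rng.normal(size=(D, D)) / np.sqrt(D) for _ in range(3))

out, attn = cross_attention(queries, visible_feats, w_q, w_k, w_v)
print(out.shape, attn.shape)  # (12, 16) (12, 4)
```

Each row of attn sums to 1 over the visible patches, so every mask token's output is a weighted mix of visible-patch values.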

Of course, self-attention can also achieve this goal, as Kaiming's MAE does. SAMI may use cross-attention partly because it needs the unmasked patches' features to learn the features of SAM ViT-H, and partly because the authors ran ablations and found that it works better under their design.
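For comparison, a sketch of the MAE-style alternative in NumPy: the visible encoder features and the mask tokens are put back into the original patch order, and one self-attention runs over the full sequence (Q = K = V). The sizes, the random masking, and the positional embeddings are hypothetical; following MAE, only the masked positions' outputs would feed the reconstruction loss.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max for numerical stability before exponentiating.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(x, w_q, w_k, w_v):
    """Single-head self-attention: queries, keys and values all come from the same sequence x."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    attn = softmax(q @ k.T / np.sqrt(q.shape[-1]), axis=-1)
    return attn @ v, attn

rng = np.random.default_rng(0)
D, num_patches = 16, 16
is_masked = np.zeros(num_patches, dtype=bool)
is_masked[rng.permutation(num_patches)[:12]] = True  # mask 12 of 16 patches (hypothetical ratio)

visible_feats = rng.normal(size=(int((~is_masked).sum()), D))  # encoder output for visible patches
mask_token = rng.normal(size=(D,))                             # shared "learnable" token (random here)
pos_embed = rng.normal(size=(num_patches, D))                  # hypothetical positional embeddings

# Rebuild the full-length sequence in original patch order, MAE-style:
# visible slots hold encoder features, masked slots hold the mask token.
full = np.empty((num_patches, D))
full[~is_masked] = visible_feats
full[is_masked] = mask_token
full = full + pos_embed

w_q, w_k, w_v = (rng.normal(size=(D, D)) / np.sqrt(D) for _ in range(3))
out, attn = self_attention(full, w_q, w_k, w_v)
pred_masked = out[is_masked]  # only masked positions are compared against the target
print(out.shape, attn.shape, pred_masked.shape)  # (16, 16) (16, 16) (12, 16)
```

Here the visible tokens are also updated by the decoder, which the cross-attention design avoids.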

yeyeyeping commented 6 months ago

Your reply really helped to clarify my understanding. Thank you for your time and help.

yjh0410 commented 6 months ago

@yeyeyeping You are welcome.

By the way, something about the cross-attention in the official SAMI still feels strange to me. As described in the paper, Q is the masked tokens, and K and V come from the unmasked features of the encoder and the masked features. This description is too vague for me to understand why masked features are used, because in my understanding they are the learning objective. Anyway, the official SAMI source code is not publicly available, so for now I prefer to follow Kaiming's MAE decoder based on self-attention (Q = K = V = mask tokens + unmasked features from the encoder).
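One possible reading of the paper's sentence, and it is only an assumption since the official code is unavailable, is that "masked features" means the mask tokens themselves, so K and V are the concatenation of the unmasked encoder features and the mask tokens. This NumPy sketch (hypothetical names, sizes, and random weights) checks what that reading implies for a single layer with shared projection weights: the result equals the masked-position rows of ordinary self-attention over the same tokens, because softmax attention does not depend on the order of the key/value rows.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max for numerical stability before exponentiating.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(queries, context, w_q, w_k, w_v):
    """Single-head attention: `queries` attend over `context` (used for both keys and values)."""
    q, k, v = queries @ w_q, context @ w_k, context @ w_v
    return softmax(q @ k.T / np.sqrt(q.shape[-1]), axis=-1) @ v

rng = np.random.default_rng(0)
D, num_visible, num_masked = 16, 4, 12
visible_feats = rng.normal(size=(num_visible, D))  # unmasked features from the encoder
mask_tokens = rng.normal(size=(num_masked, D))     # mask token + positional embedding per masked slot
w_q, w_k, w_v = (rng.normal(size=(D, D)) / np.sqrt(D) for _ in range(3))

# Assumed reading: Q = mask tokens, K = V = [unmasked encoder features; mask tokens].
kv = np.concatenate([visible_feats, mask_tokens], axis=0)
cross_out = attention(mask_tokens, kv, w_q, w_k, w_v)          # (12, 16)

# MAE-style self-attention over the same tokens, keeping only the masked rows.
full = np.concatenate([visible_feats, mask_tokens], axis=0)
self_out = attention(full, full, w_q, w_k, w_v)[num_visible:]  # (12, 16)

print(np.allclose(cross_out, self_out))  # True (equal up to floating-point rounding)
```

Under this reading, any practical difference from MAE's decoder would have to come from how layers are stacked, for example whether the visible tokens are also updated from layer to layer; that part is speculation.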

zzy0428 commented 5 months ago

As I understand it, cross-attention is used because the unmasked features are already well learned by the encoder, and the decoder's only goal is to reconstruct the masked tokens. Thus, the authors set up separate sequences, with the query sequence containing only the masked tokens and no unmasked tokens. https://github.com/yformer/EfficientSAM/issues/21#issuecomment-1856870412
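A rough way to see what separating the sequences buys is to count attention-score entries per head. This sketch assumes a 224x224 input with 16x16 patches (196 tokens) and MAE's 75% mask ratio; SAMI's actual resolution and ratio may differ, and the two cross-attention entries correspond to the two readings of the paper discussed in this thread.

```python
def score_entries(num_queries, num_keys):
    """Number of entries in one head's attention-score matrix (queries x keys)."""
    return num_queries * num_keys

num_patches = 196                            # assumption: 224x224 input, 16x16 patches
mask_ratio = 0.75                            # assumption: MAE-style 75% masking
num_masked = int(num_patches * mask_ratio)   # 147
num_visible = num_patches - num_masked       # 49

costs = {
    "self-attention, Q=K=V=all tokens": score_entries(num_patches, num_patches),
    "cross-attention, K=V=visible only": score_entries(num_masked, num_visible),
    "cross-attention, K=V=visible+masked": score_entries(num_masked, num_patches),
}
for name, n in costs.items():
    print(f"{name}: {n}")
```

With these assumed numbers it prints 38416, 7203 and 28812: restricting the queries to masked tokens always saves work relative to full self-attention, and the saving is largest when K and V contain only the visible features.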

zzy0428 commented 5 months ago

Still a bit confused about what the "masked features" are.