LTH14 / mage

A PyTorch implementation of MAGE: MAsked Generative Encoder to Unify Representation Learning and Image Synthesis
MIT License

Why does the encoder need a few mask tokens? #13

LinB203 closed this issue 1 year ago

LinB203 commented 1 year ago

Wonderful work! If the mask ratio is 0.55 and the input has L tokens, then according to the code, 0.45L tokens are visible and 0.55L tokens are masked, and 0.5L of the masked tokens are dropped. Every encoder input therefore has L/2 + 1 tokens: 0.45L visible tokens, 0.05L mask tokens, and 1 cls token. Why does the encoder need these few mask tokens? Is it to keep the input length fixed?
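The token bookkeeping in the question can be checked with a few lines of Python. This is a minimal sketch, not code from the MAGE repository: the helper name `encoder_token_counts`, the choice of L = 256, and rounding the counts up with `math.ceil` are assumptions made for illustration.

```python
import math


def encoder_token_counts(seq_len: int, mask_ratio: float, drop_ratio: float) -> dict:
    """Hypothetical helper: count the tokens that reach the encoder.

    Assumes masked and dropped counts are rounded up with math.ceil and that
    the dropped tokens are a subset of the masked tokens (drop_ratio <= mask_ratio).
    """
    num_masked = math.ceil(seq_len * mask_ratio)   # tokens hidden from the model
    num_dropped = math.ceil(seq_len * drop_ratio)  # masked tokens removed entirely
    assert num_dropped <= num_masked, "dropped tokens must be a subset of masked tokens"

    visible = seq_len - num_masked
    masked_kept = num_masked - num_dropped         # mask tokens the encoder still sees
    encoder_len = visible + masked_kept + 1        # +1 for the cls token
    return {
        "visible": visible,
        "masked_kept": masked_kept,
        "dropped": num_dropped,
        "encoder_len": encoder_len,
    }


counts = encoder_token_counts(seq_len=256, mask_ratio=0.55, drop_ratio=0.5)
print(counts)
# {'visible': 115, 'masked_kept': 13, 'dropped': 128, 'encoder_len': 129}
```

Because `visible + masked_kept = L - dropped`, the encoder length is L - 0.5L + 1 = L/2 + 1 for any mask ratio of at least 0.5. The roughly 0.05L leftover mask tokens (13 of 256 here) are the few mask tokens the question asks about.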

LTH14 commented 1 year ago

Yes. The reason is that neither PyTorch DDP nor JAX supports variable input lengths across different GPUs.
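The shape constraint can be illustrated by comparing two dropping policies across devices. In this minimal sketch, the per-GPU mask ratios, the 256-token sequence, and ceiling rounding are all assumptions, not values taken from the repository; it also assumes each GPU samples its own mask ratio. Dropping every masked token gives each GPU a different encoder length, while dropping a fixed 0.5L tokens gives every GPU the same L/2 + 1 length.

```python
import math

SEQ_LEN = 256      # assumed number of image tokens per sample
DROP_RATIO = 0.5   # assumed fixed fraction of tokens dropped before the encoder

# Hypothetical mask ratios, one per GPU (each device samples its own ratio).
per_device_mask_ratios = [0.52, 0.55, 0.63, 0.71]


def len_if_all_masked_dropped(mask_ratio: float) -> int:
    """Drop every masked token: the length depends on the sampled mask ratio."""
    return SEQ_LEN - math.ceil(SEQ_LEN * mask_ratio) + 1  # +1 for the cls token


def len_if_fixed_drop(mask_ratio: float) -> int:
    """Drop a fixed number of tokens; the remaining masked tokens become mask tokens."""
    num_masked = math.ceil(SEQ_LEN * mask_ratio)
    num_dropped = math.ceil(SEQ_LEN * DROP_RATIO)
    assert num_dropped <= num_masked, "mask ratio must be at least the drop ratio"
    visible = SEQ_LEN - num_masked
    masked_kept = num_masked - num_dropped
    return visible + masked_kept + 1  # always SEQ_LEN - num_dropped + 1


variable_lengths = [len_if_all_masked_dropped(r) for r in per_device_mask_ratios]
fixed_lengths = [len_if_fixed_drop(r) for r in per_device_mask_ratios]
print("drop all masked tokens:", variable_lengths)  # differs from GPU to GPU
print("fixed drop count:      ", fixed_lengths)     # identical on every GPU
```

With a fixed drop count, the leftover mask tokens act as padding that absorbs the difference between mask ratios, so every device sees one static input shape.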

LinB203 commented 1 year ago

You are right, but averaging the loss across all GPUs is enough for training. Why use mask tokens as padding? Have you run any experiments showing that the accuracy gain does not come from this? I'm worried the comparison could be unfair, since most masked-modeling work keeps mask tokens out of the encoder.
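The loss-averaging alternative suggested above can be sketched in plain Python. The per-token loss values are made up for illustration. Each GPU reduces its own tokens to one scalar, so the GPUs do not need matching sequence lengths; the sketch also shows that averaging per-GPU means weights each GPU equally, which differs from a mean over all tokens when GPUs hold different numbers of tokens.

```python
# Hypothetical per-token losses on two GPUs whose encoder inputs have different lengths.
rank_token_losses = [
    [1.0, 1.0, 1.0, 1.0],  # GPU 0: 4 tokens
    [2.0, 2.0],            # GPU 1: 2 tokens
]

# Each GPU reduces its own tokens to one scalar, so the lengths never need to match.
per_gpu_means = [sum(losses) / len(losses) for losses in rank_token_losses]
mean_over_gpus = sum(per_gpu_means) / len(per_gpu_means)

# For comparison: a mean over every token on every GPU.
all_tokens = [x for losses in rank_token_losses for x in losses]
mean_over_tokens = sum(all_tokens) / len(all_tokens)

print(per_gpu_means, mean_over_gpus, mean_over_tokens)
```

Here the two averages are 1.5 and about 1.33: the per-GPU average treats each GPU as one unit regardless of how many tokens it processed.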

LTH14 commented 1 year ago

Our original JAX implementation does not support variable sequence lengths because it uses a static graph. In the PyTorch version, we kept this behavior to stay consistent with the JAX implementation.

We do not have an ablation experiment for this, but in the literature, adding mask tokens to the encoder typically does not improve performance (see Table 1(c) in MAE, for example, which compares encoders with and without mask tokens). Also, mask tokens are not used during downstream inference such as linear probing.
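The difference between the pre-training input and the downstream input can be sketched with toy token ids. This is an illustrative assumption, not the repository's code: `MASK_ID`, `CLS_ID`, both helper functions, and the 16-token sequence are hypothetical. During pre-training, dropped positions disappear and the remaining masked positions carry the mask id; for linear probing, every token is kept and no mask id appears.

```python
MASK_ID = 1024   # hypothetical token id reserved for the mask token
CLS_ID = 1025    # hypothetical token id for the cls token


def pretrain_encoder_input(token_ids, masked, dropped):
    """Pre-training: dropped positions are removed, other masked positions become MASK_ID."""
    seq = [CLS_ID]
    for pos, tok in enumerate(token_ids):
        if pos in dropped:
            continue  # dropped tokens never reach the encoder
        seq.append(MASK_ID if pos in masked else tok)
    return seq


def inference_encoder_input(token_ids):
    """Downstream use such as linear probing: every token is kept and none is masked."""
    return [CLS_ID] + list(token_ids)


tokens = list(range(100, 116))   # 16 toy token ids
dropped = set(range(0, 16, 2))   # 8 dropped positions (0.5L)
masked = dropped | {1}           # 9 masked positions (about 0.55L)

train_seq = pretrain_encoder_input(tokens, masked, dropped)
infer_seq = inference_encoder_input(tokens)
print(len(train_seq), train_seq.count(MASK_ID))  # 9 1
print(len(infer_seq), infer_seq.count(MASK_ID))  # 17 0
```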

LinB203 commented 1 year ago

Thank you for your patient reply. You are right: even if the encoder sees the mask tokens, their outputs are replaced with a fake cls token in the decoder, so I presume they have no effect.
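The replacement described in this comment can be sketched as follows. This is a hypothetical illustration based on the comment, not the repository's code: the function name, the list-of-floats vectors, and the fill order are assumptions. Every masked position, dropped or kept, is filled with the encoder's cls output, so the encoder's output at a masked-but-kept position never reaches the decoder.

```python
from typing import List, Sequence, Set

Vector = List[float]


def decoder_token_inputs(encoder_out: Sequence[Vector], kept_positions: Sequence[int],
                         masked_positions: Set[int], seq_len: int) -> List[Vector]:
    """Hypothetical decoder-side padding, following the description in the comment above.

    encoder_out[0] is the cls output; encoder_out[1:] are the outputs for the kept
    positions (visible tokens plus the few masked-but-kept tokens), in order.
    Returns one vector per image-token position (the cls slot is left out for brevity).
    """
    cls_out = list(encoder_out[0])
    tokens = [list(cls_out) for _ in range(seq_len)]  # dropped slots get the cls output
    for pos, vec in zip(kept_positions, encoder_out[1:]):
        tokens[pos] = list(vec)                       # scatter kept outputs back
    for pos in masked_positions:
        tokens[pos] = list(cls_out)                   # masked slots: encoder output discarded
    return tokens


# Toy example: 4 positions; position 1 is masked but kept, position 3 is masked and dropped.
encoder_out = [[9.0], [0.0], [1.0], [2.0]]  # cls, then outputs for positions 0, 1, 2
result = decoder_token_inputs(encoder_out, kept_positions=[0, 1, 2],
                              masked_positions={1, 3}, seq_len=4)
print(result)  # [[0.0], [9.0], [2.0], [9.0]]
```

The output `[1.0]` that the encoder produced for the masked-but-kept position 1 is overwritten, which is why the leftover mask tokens are presumed to have little direct effect on reconstruction.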