Closed LinB203 closed 1 year ago
Yes, the reason is that neither PyTorch DDP nor JAX supports variable input lengths across GPUs.
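To illustrate the constraint, here is a minimal sketch (hypothetical names, not the repo's code) of why every rank must produce the same tensor shape: if each GPU kept a different number of tokens, the per-rank shapes would differ, which DDP's bucketed all-reduce and JAX's static graphs cannot handle. Padding the kept tokens with copies of a learned mask token up to a fixed count keeps all ranks identical.

```python
import numpy as np

def pad_to_fixed(visible, mask_token, fixed_len):
    """Pad a (n_visible, dim) token array with copies of mask_token
    so the result is always (fixed_len, dim)."""
    n_pad = fixed_len - visible.shape[0]
    pad = np.tile(mask_token, (n_pad, 1))
    return np.concatenate([visible, pad], axis=0)

rng = np.random.default_rng(0)
dim = 8
mask_token = rng.normal(size=(1, dim))

# Two "GPUs" that happened to keep different numbers of visible tokens
rank0 = pad_to_fixed(rng.normal(size=(9, dim)), mask_token, 12)
rank1 = pad_to_fixed(rng.normal(size=(11, dim)), mask_token, 12)

# After padding, every rank's input has the same static shape
assert rank0.shape == rank1.shape == (12, dim)
```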
You are right, but averaging the loss across all GPUs is enough to train. Why do you want the mask token to act as padding? Have you run any experiments to verify that the accuracy gain does not come from this? I am worried this could be unfair, since most mask-based work follows the convention that the encoder does not see the mask token.
In our original JAX implementation, variable sequence lengths are not supported because it uses a static graph. In the PyTorch version, we keep this behavior to stay consistent with the JAX implementation.
We do not have an ablation experiment for this, but from the literature, adding mask tokens to the encoder typically does not improve performance (see Table 1(c) in MAE, for example). Also, mask tokens are not used during downstream inference such as linear probing.
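For reference, the MAE convention being discussed can be sketched as follows (illustrative names, not the paper's actual code): the encoder receives only the visible tokens, gathered by index after a random shuffle, and mask tokens are only introduced later at the decoder.

```python
import numpy as np

def random_masking(tokens, mask_ratio, rng):
    """MAE-style masking sketch. tokens: (L, dim).
    Returns the visible subset and the kept indices; the encoder
    never sees the masked tokens at all."""
    L = tokens.shape[0]
    len_keep = int(L * (1 - mask_ratio))
    ids_shuffle = rng.permutation(L)
    ids_keep = ids_shuffle[:len_keep]
    return tokens[ids_keep], ids_keep

rng = np.random.default_rng(0)
tokens = rng.normal(size=(196, 16))   # e.g. 14x14 patch tokens
visible, ids_keep = random_masking(tokens, 0.75, rng)

# With mask ratio 0.75, the encoder sees only 49 of 196 tokens
assert visible.shape == (49, 16)
```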
Thank you for your patient replies. You are right: even if the encoder sees the mask token, it is replaced with a fake cls token at decoder time, so I presume it has no effect.
Wonderful work! If the mask ratio is 0.55 and the input has L tokens, then according to the code there are 0.45L visible tokens and 0.55L invisible tokens, and 0.5L of the invisible tokens are dropped. This ensures that each encoder input has 0.5L + 1 tokens (the +1 is the cls token). So why does the encoder need a small number of mask tokens? Is it to keep the input length fixed?
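The arithmetic in the question can be checked directly (a worked example with a hypothetical L = 200, not taken from the repo):

```python
# Mask ratio 0.55 over L tokens, per the question above
L = 200
n_visible = int(0.45 * L)              # tokens the encoder must see: 90
n_invisible = int(0.55 * L)            # masked tokens: 110
n_dropped = int(0.5 * L)               # invisible tokens dropped outright: 100
n_mask_kept = n_invisible - n_dropped  # mask tokens kept as padding: 10 (0.05L)

encoder_len = n_visible + n_mask_kept + 1  # +1 for the cls token

# Encoder input length is always 0.5L + 1, regardless of which tokens are visible
assert encoder_len == L // 2 + 1
```

So the handful of mask tokens (0.05L of them) acts exactly as padding that fixes the encoder input length at 0.5L + 1.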