redagavin opened 2 months ago
MAE models are trained with global average pooling at the end instead of a class token. Since we're changing the size of each token by merging them together, we need to perform this global average pool with a weight proportional to the size of each token.
This is also what `merge_wavg` does when merging tokens together (just this time it's global).
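For reference, a minimal sketch of what such a size-weighted global average pool could look like, assuming a `size` tensor of shape `[B, N, 1]` that tracks how many original patches each merged token represents (the same quantity `merge_wavg` maintains). The helper name `weighted_global_pool` is just illustrative, not the repo's actual API:

```python
import torch

def weighted_global_pool(x: torch.Tensor, size: torch.Tensor) -> torch.Tensor:
    """Size-weighted global average pool over merged tokens.

    x:    [B, N, C] merged token features
    size: [B, N, 1] number of original patches each merged token represents
    """
    # Weight each token by how many patches it represents, then normalize by
    # the total patch count, so the result matches the global average pool
    # that would have been computed over the original, unmerged tokens.
    return (x * size).sum(dim=1) / size.sum(dim=1)
```

Without this weighting, a token that absorbed many patches would count the same as a token representing a single patch, which would bias the pooled feature that MAE's classifier head sees.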
Hi, could you please explain why we need this code snippet when training MAE? Why isn't `apply_patch` from `timm.py` enough? Thank you!