facebookresearch / ToMe

A method to increase the speed and lower the memory footprint of existing vision transformers.

Why do we need this code snippet for training MAE? #44

Open redagavin opened 2 months ago

redagavin commented 2 months ago

[Screenshot of the code snippet]
Hi, could you please explain why we need this code snippet when training MAE? Why isn't apply_patch from timm.py enough? Thank you!

dbolya commented 2 months ago

MAE models are trained with global average pooling at the end instead of a class token. Merging tokens changes the size of each token, that is, how many original patches it represents. So this global average pool has to weight each token in proportion to its size.

This is also what merge_wavg does when merging tokens together (just this time it's global).
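The sketch below is a minimal NumPy illustration of the idea. It is not ToMe's actual code, which is written in PyTorch and merges many tokens per layer using bipartite matching. It assumes each token carries a size count that starts at 1 per patch. The helpers `size_weighted_global_pool` and `merge_wavg_pair` are hypothetical names used only for illustration. After two merges, the size-weighted pool still matches the plain mean over the original, unmerged patches. An unweighted mean over the merged tokens does not.

```python
import numpy as np


def size_weighted_global_pool(x, size):
    """Global average pool where each token counts in proportion to its size.

    x:    (batch, tokens, channels) token features
    size: (batch, tokens, 1) number of original patches each token represents
    """
    return (x * size).sum(axis=1) / size.sum(axis=1)


def merge_wavg_pair(x, size, i, j):
    """Hypothetical helper: merge token j into token i by size-weighted average.

    Mimics the idea behind ToMe's merge_wavg for a single pair of tokens;
    it is NOT the library's implementation.
    """
    x = x.copy()
    size = size.copy()
    total = size[:, i] + size[:, j]                       # (batch, 1)
    x[:, i] = (x[:, i] * size[:, i] + x[:, j] * size[:, j]) / total
    size[:, i] = total                                    # merged token now covers both
    keep = [k for k in range(x.shape[1]) if k != j]       # drop the absorbed token
    return x[:, keep], size[:, keep]


rng = np.random.default_rng(0)
patches = rng.normal(size=(2, 6, 4))      # 2 images, 6 patches, 4 channels
size = np.ones((2, 6, 1))                 # every patch starts with size 1
reference = patches.mean(axis=1)          # global average pool without merging

# Merge twice so token 0 ends up representing 3 original patches.
merged, merged_size = merge_wavg_pair(patches, size, 0, 1)
merged, merged_size = merge_wavg_pair(merged, merged_size, 0, 1)

weighted = size_weighted_global_pool(merged, merged_size)
naive = merged.mean(axis=1)               # ignores token sizes

print(np.allclose(weighted, reference))   # True
print(np.allclose(naive, reference))      # False
```

Weighting by size makes the pooled feature identical to what the unmerged model would compute. The plain mean over-counts patches that were never merged, because it gives each one the same weight as a token covering several patches.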