facebookresearch / ToMe

A method to increase the speed and lower the memory footprint of existing vision transformers.

Why do we need this code snippet for training MAE? #44

Open redagavin opened 3 weeks ago

redagavin commented 3 weeks ago

[screenshot of the referenced code snippet]
Hi, could you please explain why we need this code snippet when training MAE? Why is `apply_patch` from timm.py not enough? Thank you!

dbolya commented 3 weeks ago

MAE models are trained with global average pooling at the end instead of a class token. Since we're changing the size of each token by merging them together, we need to perform this global average pool with a weight proportional to the size of each token.

This is the same weighting that `merge_wavg` applies when merging tokens together, just applied globally here.
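For reference, a minimal sketch of such a size-weighted global average pool might look like the following, assuming `x` holds the token features and `size` holds the per-token patch counts that ToMe tracks; the function name and shapes here are illustrative, not the repo's exact code.

```python
import torch

def sized_global_pool(x: torch.Tensor, size: torch.Tensor) -> torch.Tensor:
    """Global average pool where each token counts as `size` original patches.

    x:    token features, shape [B, N, C]
    size: number of patches each merged token represents, shape [B, N, 1]
    """
    # Weight each token by how many patches it represents, then normalize
    # by the total patch count, i.e. a size-weighted mean over tokens.
    return (x * size).sum(dim=1) / size.sum(dim=1)
```

Without the size weighting, a token that was merged from many patches would count the same as a token representing a single patch, which would bias the pooled feature that MAE's head is trained on.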