facebookresearch / ToMe

A method to increase the speed and lower the memory footprint of existing vision transformers.

Training with merging #13

Closed WayneCV closed 1 year ago

WayneCV commented 1 year ago

Hi, I want to apply bipartite_soft_matching to my project. During training, you treat the token merging as a pooling operation. Can I directly use bipartite_soft_matching without any change during the training stage?

dbolya commented 1 year ago

Yes, you should be able to use it during training with minimal hyperparameter changes; it will act as a pooling operation. But if you're not using the provided patches, make sure you call it like this:

merge, unmerge = bipartite_soft_matching(___, r=___)  # Replace with your features and r value
x = merge(x, mode="mean")  # Use amax if you want average pooling
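For reference, here is a rough sketch of how that call might slot into a training forward pass. It is illustrative only, not the official patch: x stands in for your token features of shape [B, N, C], k for the merging metric (e.g. attention keys averaged over heads), and r=16 is an arbitrary example value.

import torch
from tome.merge import bipartite_soft_matching

# Illustrative shapes only: B=8 images, N=197 tokens, C=384 channels.
x = torch.randn(8, 197, 384)   # token features
k = torch.randn(8, 197, 384)   # merging metric, e.g. attention keys averaged over heads

merge, unmerge = bipartite_soft_matching(k, r=16)  # removes r tokens per call
x = merge(x, mode="mean")                          # x is now [8, 181, 384]

If your model has a class token, pass class_token=True so it never gets merged away.
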
yiren-jian commented 1 year ago

@dbolya Hi Daniel, thanks for your great work! It seems that

merge, unmerge = bipartite_soft_matching(___, r=___)  # Replace with your features and r value
x = merge(x, mode="mean")  # Use amax if you want average pooling

is equal to

### from https://github.com/facebookresearch/ToMe/blob/main/tome/patch/timm.py
merge, unmerge = bipartite_soft_matching(___, r=___)
x, self._tome_info['size'] = merge_wavg(merge, x, self._tome_info['size'])

right?

Also, regarding the line in the post above (x = merge(x, mode="mean")  # Use amax if you want average pooling): what is amax in your comment?

Thank you

dbolya commented 1 year ago

Hi @yiren-jian, the two are equivalent for the first layer only; after that, the second code block computes a weighted average while the first computes an unweighted one.

The weighted / unweighted distinction is important for using ToMe without training. At the start of the network, each token corresponds to one input patch, so they can all be weighted equally when merging. However, by the next layer some tokens have already been merged, so after the first layer tokens can correspond to different numbers of input patches.

Let's say you had two tokens, one that represents 10 input patches and one that represents just 1. If you averaged those two tokens together, you'd weight them 0.5 each (0.5 × the first + 0.5 × the second), but then the 10 patches in the first token have the same total effect as the single patch in the second. To fix this, we do a "weighted average" by default, giving 10/11 weight to the first token and 1/11 weight to the second.
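As a small numerical sketch of that 10-vs-1 example (feature values are made up; the weighted version mirrors what merge_wavg in tome/merge.py does, i.e. merge x * size and size with mode="sum", then divide):

import torch

x = torch.tensor([[2.0], [8.0]])      # two tokens (illustrative feature values)
size = torch.tensor([[10.0], [1.0]])  # the first token already represents 10 patches

unweighted = x.mean(dim=0)                     # 0.5*2 + 0.5*8 = 5.0
weighted = (x * size).sum(dim=0) / size.sum()  # (10*2 + 1*8) / 11 ≈ 2.55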

However, I found this isn't necessary if you train the model, so I omitted it from the code block I posted above.

Also, by "Use amax if you want average pooling", I misspoke, I meant "amax if you want max pooling". Providing "amax" will use max pooling instead of averaging.

yiren-jian commented 1 year ago

Thank you so much. Really appreciate your detailed explanations.