Closed — WayneCV closed this issue 1 year ago
Yes, you should be able to use it during training with minimal hyperparameter difference. It will act as a pooling operation. But if you're not using the patches, then make sure you call it like this:
```python
merge, unmerge = bipartite_soft_matching(___, r=___)  # Replace with your features and r value
x = merge(x, mode="mean")  # Use amax if you want average pooling
```
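As a rough illustration of how `merge` acts as a pooling operation, here is a toy NumPy sketch of the bipartite matching idea (this is illustrative only, not the library's actual implementation; `bipartite_soft_matching_sketch` is a made-up name): tokens are split into two alternating sets, each token in the first set is matched to its most similar token in the second, and the `r` strongest matches are merged.

```python
import numpy as np

def bipartite_soft_matching_sketch(x, r):
    """Toy sketch of ToMe-style bipartite matching (illustrative only).
    x: (n, d) token features; r: number of tokens to remove."""
    # Split tokens into two alternating sets A and B.
    a, b = x[0::2], x[1::2]
    # Cosine similarity between every token in A and every token in B.
    a_n = a / np.linalg.norm(a, axis=1, keepdims=True)
    b_n = b / np.linalg.norm(b, axis=1, keepdims=True)
    scores = a_n @ b_n.T                       # (|A|, |B|)
    best = scores.argmax(axis=1)               # each A-token's best partner in B
    vals = scores.max(axis=1)                  # how good that match is
    # Merge the r A-tokens with the strongest matches; keep the rest.
    merged_idx = np.argsort(-vals)[:r]
    kept_idx = np.argsort(-vals)[r:]

    def merge(x, mode="mean"):
        a, b = x[0::2].copy(), x[1::2].copy()
        for i in merged_idx:
            j = best[i]
            if mode == "mean":
                b[j] = (b[j] + a[i]) / 2       # unweighted average
            else:                              # "amax" -> max pooling
                b[j] = np.maximum(b[j], a[i])
        return np.concatenate([a[kept_idx], b], axis=0)

    return merge

x = np.random.randn(8, 4).astype(np.float32)
merge = bipartite_soft_matching_sketch(x, r=2)
out = merge(x, mode="mean")
print(out.shape)  # (6, 4): 8 tokens reduced by r=2
```

Each call removes `r` tokens, so stacking this across layers shrinks the sequence progressively, exactly like a pooling schedule.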
@dbolya Hi Daniel, thanks for your great work! It seems that
```python
merge, unmerge = bipartite_soft_matching(___, r=___)  # Replace with your features and r value
x = merge(x, mode="mean")  # Use amax if you want average pooling
```
is equal to
```python
# from https://github.com/facebookresearch/ToMe/blob/main/tome/patch/timm.py
merge, unmerge = bipartite_soft_matching(___, r=___)
x, self._tome_info["size"] = merge_wavg(merge, x, self._tome_info["size"])
```
right?
Also, in the post above, regarding `x = merge(x, mode="mean")  # Use amax if you want average pooling`: what is `amax` in your comment?
Thank you
Hi @yiren-jian, the two are equivalent for one layer only, after which the second code block computes a weighted average while the first computes an unweighted average.
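To make the weighted/unweighted distinction concrete, here is a rough sketch of what a size-weighted merge like `merge_wavg` does (illustrative only; `toy_merge` is a made-up stand-in for the real `merge` callable, and `merge_wavg_sketch` is not the library's code): weight each token by the number of input patches it represents, merge with a sum, then renormalize.

```python
import numpy as np

# Toy "merge": fold the last token into the first (stand-in for ToMe's merge).
def toy_merge(x, mode="sum"):
    out = x[:-1].copy()
    out[0] = out[0] + x[-1] if mode == "sum" else np.maximum(out[0], x[-1])
    return out

def merge_wavg_sketch(merge, x, size=None):
    # Weighted average: weight each token by how many patches it represents.
    if size is None:
        size = np.ones((x.shape[0], 1))   # every token starts as 1 patch
    x = merge(x * size, mode="sum")       # patch-weighted feature sum
    size = merge(size, mode="sum")        # total patches per merged token
    return x / size, size

x = np.array([[4.0], [1.0], [0.0]])
x1, s1 = merge_wavg_sketch(toy_merge, x)       # first layer: all sizes are 1
x2, s2 = merge_wavg_sketch(toy_merge, x1, s1)  # second layer: sizes now differ
print(x2, s2)  # token 0 now averages its 3 patches: (4 + 0 + 1) / 3
```

After the first call the surviving tokens carry different sizes, so the second call no longer matches a plain `mode="mean"` merge; that is exactly the one-layer-only equivalence described above.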
The weighted / unweighted distinction is important for using ToMe without training. At the start of the network, each token corresponds to one input patch, so they can all be equally considered when merging. However, in the next layer some tokens have already been merged. Thus, after the first layer each token corresponds to a different number of input patches.
Let's say you had two tokens: one that represents 10 input patches, and one that represents just 1. If you averaged those two tokens together, you'd assign 0.5 weight to the first + 0.5 weight to the second, which means the 10 patches in the first token have the same total influence as the single patch in the second. To fix this, we do a "weighted average" by default, giving 10/11 weight to the first token and 1/11 weight to the second token.
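The arithmetic behind this, with made-up feature values for the two tokens:

```python
t1, s1 = 4.0, 10   # token feature value, number of patches it represents
t2, s2 = 0.0, 1

# Unweighted average: both tokens count equally.
unweighted = (t1 + t2) / 2

# Weighted average: each token counts in proportion to its patches.
weighted = (s1 * t1 + s2 * t2) / (s1 + s2)

print(unweighted)  # 2.0
print(weighted)    # 40/11 ~= 3.64, dominated by the 10-patch token
```

The weighted result stays close to `t1` because 10 of the 11 underlying patches vote for it, which is the behaviour you want when running ToMe without training.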
However, I found this not to be necessary if you train the model, so I omitted it in this code block.
Also, by "Use amax if you want average pooling" I misspoke; I meant "use amax if you want max pooling". Passing `mode="amax"` will use max pooling instead of averaging.
Thank you so much. Really appreciate your detailed explanations.
Hi, I want to apply `bipartite_soft_matching` to my project. During training, you treat the token merging as a pooling operation. Can I directly use `bipartite_soft_matching` without any change during the training stage?