facebookresearch / hiera

Hiera: A fast, powerful, and simple hierarchical vision transformer.
Apache License 2.0

How do we drop tokens? #5

Closed · luke-mcdermott-mi closed 1 year ago

luke-mcdermott-mi commented 1 year ago

How exactly are we dropping tokens? In the "Applying MAE" part of the paper, it says "we shift the mask units to the batch dimension to separate them for pooling (effectively treating each mask unit as an “image”)".

If we have a batch of 2 images $[a, b]$ that look like $a = \begin{bmatrix} 1 & M & 3\\ M & M & 6 \end{bmatrix}, b = \begin{bmatrix} A & M & M\\ D & E & M \end{bmatrix}$, where the M's are masked tokens, then do the new images look like $a' = \begin{bmatrix} 1 & A & 3\\ D & E & 6 \end{bmatrix}, b' = \begin{bmatrix} M & M & M\\ M & M & M \end{bmatrix}$, where we drop the $b'$ image? Or do they look like $a' = \begin{bmatrix} 1 & 3 & 6 \end{bmatrix}, b' = \begin{bmatrix} A & D & E \end{bmatrix}, \ldots$, where there also exist images $c'$, $d'$, etc. made up of the M tokens, which are then dropped?

Or am I looking in the completely wrong direction? I see that the max-pooling kernel = stride, so that it can skip masked units. So the pooling just iterates across the image and skips masked units? That reasoning makes sense; however, I am confused about how it relates to the quote above, "we shift the mask units to the batch dimension to separate them for pooling".

Thanks in advance

dbolya commented 1 year ago

The shift-to-batch (or "separate and pad") trick is only necessary for the intermediate MViTv2 ablations we did for Table 1 of the paper (because of the kernel overlap). The final Hiera model doesn't actually use it at all, since, as you said, we can just skip masked units.

For MAE, we enforce that the same number of units is masked in each image. That way, if we mask as in your example, every image in the batch will always be left with 3 units. Then, to answer your question (note that the shift-to-batch trick is not implemented in this repo because it's not necessary for Hiera), let's say we have 4 images with w=96, h=64, and 3 channels.

Then our input tensor would look like:
`input_image: shape = [4, 3, 64, 96]`

Each token is 4x4 pixels, so once we tokenize the image, we're down to:
`tokenized_image: shape = [4, 3, 64, 96] -> tokenizer (patch embed) -> [4, 144, 16, 24]`
Note that the tokenizer also ups the channel dim to 144 (e.g., for L models).
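For concreteness, here's a rough PyTorch sketch of that tokenizer step, using a plain non-overlapping 4x4 strided conv just to match the shapes above (the repo's actual patch embed may use a different kernel; the names here are illustrative):

```python
import torch
import torch.nn as nn

# Sketch of the tokenizer (patch embed) step: a strided conv that turns
# [4, 3, 64, 96] pixels into [4, 144, 16, 24] tokens (4x4 pixels per token).
# Kernel/projection details are assumptions, not the repo's exact layer.
patch_embed = nn.Conv2d(in_channels=3, out_channels=144, kernel_size=4, stride=4)

images = torch.randn(4, 3, 64, 96)   # [B, C, H, W]
tokens = patch_embed(images)         # [4, 144, 16, 24]
print(tokens.shape)                  # torch.Size([4, 144, 16, 24])
```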

Then we extract the mask units, which are each 8x8 tokens:
`tokenized_image_mu: shape = [4, 144, (2, 8), (3, 8)] -> permute -> [4, 144, (2, 3), (8, 8)] -> reshape -> [4, 144, 6, 64]`
Here, each image contains 6 mask units in a 2x3 arrangement, as in your example above, where each mask unit is 8x8 (64) tokens.
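In tensor ops, that extraction is just a view/permute/reshape. A minimal sketch of the shape bookkeeping (variable names are illustrative, not the repo's):

```python
import torch

# Group the 16x24 token grid into 2x3 mask units of 8x8 tokens each:
# [4, 144, 16, 24] -> [4, 144, 2, 8, 3, 8] -> [4, 144, 2, 3, 8, 8] -> [4, 144, 6, 64]
tokens = torch.randn(4, 144, 16, 24)
B, C, H, W = tokens.shape
mu = 8  # mask unit side length, in tokens

x = tokens.view(B, C, H // mu, mu, W // mu, mu)      # [4, 144, 2, 8, 3, 8]
x = x.permute(0, 1, 2, 4, 3, 5)                      # [4, 144, 2, 3, 8, 8]
x = x.reshape(B, C, (H // mu) * (W // mu), mu * mu)  # [4, 144, 6, 64]
print(x.shape)                                       # torch.Size([4, 144, 6, 64])
```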

Now, we delete the same number of mask units from each image, so if the masking ratio is 50%, we would choose 3 of the 6 in each of the 4 images to discard:
`masked_image_mu: shape = [4, 144, 6, 64] -> discard 3 mask units from each image -> [4, 144, 3, 64]`
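A sketch of that discard step, using the standard MAE-style per-image random shuffle to pick which 3 units each image keeps (the gather-based indexing here is an illustration, not necessarily how this repo implements it):

```python
import torch

x = torch.randn(4, 144, 6, 64)                  # [B, C, #mask units, tokens per unit]
mask_ratio = 0.5
num_keep = int(6 * (1 - mask_ratio))            # 3 kept units per image

# Random permutation of the 6 mask units per image; keep the first num_keep of each.
keep_idx = torch.argsort(torch.rand(4, 6), dim=1)[:, :num_keep]   # [4, 3]

# Gather the kept mask units along the mask-unit dimension.
idx = keep_idx[:, None, :, None].expand(-1, 144, -1, 64)          # [4, 144, 3, 64]
kept = torch.gather(x, dim=2, index=idx)                          # [4, 144, 3, 64]
print(kept.shape)                                                 # torch.Size([4, 144, 3, 64])
```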

Then, finally, the shift-to-batch trick: just move that "3" dimension into the batch dimension.
`shifted_to_batch: shape = [4, 144, 3, 64] -> permute -> [(4, 3), 64, 144] -> reshape -> [12, 64, 144]`
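That last step is just a permute and a reshape; a minimal sketch:

```python
import torch

# Shift-to-batch: fold the kept-mask-unit dim into the batch dim and move the
# channel dim last: [4, 144, 3, 64] -> [4, 3, 64, 144] -> [12, 64, 144].
kept = torch.randn(4, 144, 3, 64)   # [B, C, kept units, tokens per unit]
x = kept.permute(0, 2, 3, 1)        # [4, 3, 64, 144]
x = x.reshape(4 * 3, 64, 144)       # [12, 64, 144]
print(x.shape)                      # torch.Size([12, 64, 144])
```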

Then this is the familiar `[batch, tokens, embed_dim]` shape that you can pass into any transformer. The pooling and window attention can be done on the "64" dimension (i.e., 8x8 -> 4x4 -> 2x2, etc.), and you can do 3x3 kernels like MViT if you pad it (which is why we also call it "separate and pad").
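To picture the pooling on that "64" dimension: each mask unit can be viewed as its own little 8x8 image and pooled with kernel = stride, so masked and unmasked units never mix. A rough sketch of just the shape handling (the real model also changes the channel dim between stages, which is omitted here):

```python
import torch
import torch.nn.functional as F

# Pool within each mask unit: view the 64-token dim as an 8x8 grid, max-pool
# with kernel = stride = 2, and flatten back (8x8 -> 4x4, i.e. 64 -> 16 tokens).
x = torch.randn(12, 64, 144)                          # [batch * kept units, tokens, C]
B, N, C = x.shape
side = int(N ** 0.5)                                  # 8

grid = x.transpose(1, 2).reshape(B, C, side, side)    # [12, 144, 8, 8]
pooled = F.max_pool2d(grid, kernel_size=2, stride=2)  # [12, 144, 4, 4]
x = pooled.flatten(2).transpose(1, 2)                 # [12, 16, 144]
print(x.shape)                                        # torch.Size([12, 16, 144])
```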

This is a somewhat long-winded explanation, but hopefully it made sense.

luke-mcdermott-mi commented 1 year ago

Thank you, yes that makes sense.