berniwal / swin-transformer-pytorch

Implementation of the Swin Transformer in PyTorch.
https://arxiv.org/pdf/2103.14030.pdf
MIT License

Cyclic shift with masking #9

Closed · Hayoung93 closed this issue 3 years ago

Hayoung93 commented 3 years ago

Hello sir, I'm trying to understand the "efficient batch computation" that the authors propose. Probably because of my limited knowledge, it was hard for me to see how it works. Your implementation really helped me understand its mechanism, thanks a lot!

Here's my question: it seems the masked area of q * k / sqrt(d) vanishes during the computation of self-attention. I'm not sure I understood the code correctly, but is this what the paper originally intends? I'm wondering whether each subwindow's self-attention should be computed before the reverse shift.

[image attachment]

Apologies if I misunderstood something, and thanks again!

berniwal commented 3 years ago

Thank you very much for your interest in the code.

If I understood your question correctly, you are asking why, later on, there are zero values in the masked area of the attention matrix resulting from q * k / sqrt(d)?

If you look at the lower-right window in the resulting matrix after the cyclic shift, it consists of 4 subwindows that are not next to each other in the original image. Within each window, however, we compute the self-attention between all values in the window (even between values across subwindows in this case). If we don't want the model to think that the subwindows are adjacent (and to be able to pay attention across their borders), we need to set the attention for those values (the masked values) to zero. For the grey subwindow, for example, these are all values in the blue, yellow and green areas. This way we only compute the self-attention within each subwindow, as desired.
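To see why the masked values end up as zeros, here is a minimal, self-contained sketch (toy sizes, not the actual repository code): an additive mask with -inf at the blocked positions is added to the attention logits before the softmax, which turns the corresponding attention weights into exact zeros.

```python
import torch
import torch.nn.functional as F

# Toy example: 4 tokens in one shifted window; the first 2 belong to one
# subwindow and the last 2 were wrapped in from elsewhere by the cyclic shift.
d = 8
q = torch.randn(4, d)
k = torch.randn(4, d)

scores = q @ k.transpose(-2, -1) / d ** 0.5   # (4, 4) attention logits

# Additive mask: -inf wherever a token would attend across a subwindow border.
mask = torch.zeros(4, 4)
mask[:2, 2:] = float('-inf')
mask[2:, :2] = float('-inf')

attn = F.softmax(scores + mask, dim=-1)       # masked entries become exactly 0
print(attn)  # zeros in the off-diagonal 2x2 blocks, each row still sums to 1
```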

Hayoung93 commented 3 years ago

Thanks for your kind answer.

However, it seems my explanation was not sufficient (due to my poor English). Let me take the bottom-right box, the one that contains the grey, green, blue, and yellow regions and the letter 'A', as an example: it seems to me that self-attention is computed only for the grey region, but I'm curious about the green, blue, and yellow regions. Didn't they come from three corners of the original shifted window, and shouldn't they also be used for computing self-attention (maybe by generating four different masks in this example box)?

Again, thank you!

berniwal commented 3 years ago

Yes, you are right that they come from three corners of the original shifted window, and we do want to compute self-attention for them and then shift them back to their original location. The code should not do this only for the grey area; otherwise there would be a major bug in the code.

How I tried to solve this is by defining two different masks. For the window in the lower-left corner (containing just grey and 'C'), the upper_lower mask masks the upper part from the lower part of the attention matrix. For the window in the upper-right corner (containing just grey and 'B'), the left_right mask masks the left part from the right part of the attention matrix. The window in the lower-right corner simply applies both masks together, masking the left part from the right part as well as the upper part from the lower part. The masks are defined here:

[screenshot: mask definitions in the code]
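Roughly, building such masks can look like this (a simplified sketch in plain PyTorch; the name create_mask and its arguments are illustrative rather than the exact repository code):

```python
import torch

def create_mask(window_size, displacement, upper_lower, left_right):
    # Additive mask over flattened window positions: (window_size**2, window_size**2)
    mask = torch.zeros(window_size ** 2, window_size ** 2)

    if upper_lower:
        # The last `displacement` rows of the window were wrapped in by the cyclic
        # shift: block attention between them and the remaining upper rows.
        cut = displacement * window_size
        mask[-cut:, :-cut] = float('-inf')
        mask[:-cut, -cut:] = float('-inf')

    if left_right:
        # Same idea along the horizontal axis: view the flat indices as
        # (row, col) pairs and block attention between the last `displacement`
        # columns (wrapped in) and the remaining left columns.
        mask = mask.reshape(window_size, window_size, window_size, window_size)
        mask[:, -displacement:, :, :-displacement] = float('-inf')
        mask[:, :-displacement, :, -displacement:] = float('-inf')
        mask = mask.reshape(window_size ** 2, window_size ** 2)

    return mask
```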

Example: we have a window size of 7x7 and therefore shift the window by 3x3 pixels in the upper-left direction. In the lower-left window we then have 4 rows of 'grey' pixels and 3 rows of 'green' pixels that were shifted across the border. The attention map is 49x49, as we compute the self-attention between all possible pixel combinations and 49 is the flattened window size (7x7). Since we want the upper four rows (grey, indices 0-27) not to attend to the pixels in the lower three rows (green, indices 28-48) and vice versa, we mask, for all rows from index 28 on, the columns of the attention matrix with index strictly smaller than 28, and, for all rows up to and including index 27, the columns with index strictly larger than 27. However, the self-attention is still computed for all rows and therefore for both subwindows ('grey' and 'green'); the final resulting 49xD matrix (where D is the hidden dimension) is then reshaped back into 7x7xD and reverse-shifted to its original location. Similar reasoning applies for the left_right mask, where we however need to reshape the values first to create the desired mask more easily.
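Using the create_mask sketch from above, the numbers of this example can be checked directly (again just illustrative):

```python
import torch

m = create_mask(window_size=7, displacement=3, upper_lower=True, left_right=False)

# Rows 0-27 ('grey') are blocked from columns 28-48 ('green') and vice versa ...
assert torch.isinf(m[0, 28]) and torch.isinf(m[28, 0])
# ... while attention within each subwindow stays unmasked.
assert m[0, 27] == 0 and m[48, 28] == 0
```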

I hope this makes sense and let me know if you have further questions.

Hayoung93 commented 3 years ago

This taught me a lot! All my questions are answered.

Thank you XD