Hayoung93 closed this issue 3 years ago.
Thanks again for reporting potential errors in the code.
I agree with you on how the rolling should be computed, and that rolling over the patches would result in a wrong feature map. However, I think the code already does this. In `WindowAttention`, the input is the output of the `PatchMerging` module, a feature map of shape (B x H x W x D), where B is the batch size, H the height, W the width, and D the hidden dimension. We first apply the cyclic shift to this feature map (rolling along the H and W dimensions), and only afterwards, in

```python
q, k, v = map(lambda t: rearrange(t, 'b (nw_h w_h) (nw_w w_w) (h d) -> b h (nw_h nw_w) (w_h w_w) d', h=h, w_h=self.window_size, w_w=self.window_size), qkv)
```

do we create the windows and move to a (B x Heads x Windows x Elements in Window x Head Dimension) layout. The same holds for the reverse shift, since we only shift back after

```python
out = rearrange(out, 'b h (nw_h nw_w) (w_h w_w) d -> b (nw_h w_h) (nw_w w_w) (h d)', h=h, w_h=self.window_size, w_w=self.window_size, nw_h=nw_h, nw_w=nw_w)
```

which moves from (B x Heads x Windows x Elements in Window x Head Dimension) back to the (B x H x W x D) layout, so the shift again operates on the full feature map. It is therefore the same as your `A_rearranged` example, just with an additional dimension for the features.
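A minimal NumPy sketch of the order described above, assuming `np.roll` as a stand-in for `torch.roll` (same shift/axis semantics) and leaving out the heads split for brevity: the whole (B x H x W x D) map is rolled first, windows are formed afterwards, and the merged windows are rolled back. The helpers `window_partition` and `window_merge` are hypothetical, and the negative (top-left) shift direction is an assumption based on the paper's description of the cyclic shift.

```python
import numpy as np


def window_partition(x, ws):
    """(B, H, W, D) -> (B, num_windows, ws*ws, D).

    Same grouping as einops 'b (nw_h w_h) (nw_w w_w) d -> b (nw_h nw_w) (w_h w_w) d'
    (the heads split of the original pattern is omitted in this sketch).
    """
    b, h, w, d = x.shape
    nw_h, nw_w = h // ws, w // ws
    x = x.reshape(b, nw_h, ws, nw_w, ws, d)  # split H and W into (window, in-window)
    x = x.transpose(0, 1, 3, 2, 4, 5)         # -> (b, nw_h, nw_w, w_h, w_w, d)
    return x.reshape(b, nw_h * nw_w, ws * ws, d)


def window_merge(x, ws, nw_h, nw_w):
    """Inverse of window_partition: (B, num_windows, ws*ws, D) -> (B, H, W, D)."""
    b, _, _, d = x.shape
    x = x.reshape(b, nw_h, nw_w, ws, ws, d)
    x = x.transpose(0, 1, 3, 2, 4, 5)         # -> (b, nw_h, w_h, nw_w, w_w, d)
    return x.reshape(b, nw_h * ws, nw_w * ws, d)


B, H, W, D, ws = 1, 8, 8, 2, 4
disp = ws // 2
x = np.arange(B * H * W * D).reshape(B, H, W, D)

# 1) cyclic shift on the full feature map (axes 1, 2 = H, W)
shifted = np.roll(x, shift=(-disp, -disp), axis=(1, 2))
# 2) only now split into windows
windows = window_partition(shifted, ws)
# (window attention would run on `windows` here; identity in this sketch)
# 3) merge windows back to (B, H, W, D), then undo the shift
merged = window_merge(windows, ws, H // ws, W // ws)
restored = np.roll(merged, shift=(disp, disp), axis=(1, 2))

assert np.array_equal(restored, x)
# each token's D-dimensional feature vector moves as a single unit
assert np.array_equal(shifted[0, 0, 0], x[0, disp, disp])
```

Because the roll only touches axes 1 and 2, the feature axis is never mixed, which is the "additional dimension for the features" mentioned above.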
Thanks for the perfect reply. I really should study harder.
Thank you! XD
P.S. I think I confused dividing the image into patches (4x4 initially) with attention windowing (7x7). Sorry for bothering you.
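To make the patch/window distinction concrete, here is a small sketch of the first-stage arithmetic. It assumes a 224x224 input (an assumption, not stated in this thread) together with the 4x4 patches and 7x7 windows mentioned above; `swin_stage1_layout` is a hypothetical helper.

```python
def swin_stage1_layout(img_size=224, patch_size=4, window_size=7):
    """Hypothetical helper: token and window counts for the first Swin stage."""
    tokens_per_side = img_size // patch_size           # each 4x4 pixel patch -> one token
    windows_per_side = tokens_per_side // window_size  # each window groups 7x7 tokens
    return {
        "tokens_per_side": tokens_per_side,
        "num_tokens": tokens_per_side ** 2,
        "num_windows": windows_per_side ** 2,
        "tokens_per_window": window_size ** 2,
        "window_side_in_pixels": window_size * patch_size,
        "displacement": window_size // 2,              # shift used for shifted windows
    }


layout = swin_stage1_layout()
print(layout["num_windows"], layout["tokens_per_window"])  # 64 49
```

So the shift is measured in tokens (3 tokens here), not in pixels inside a patch.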
Hello, sir. Unfortunately, another question has come up.
I've followed your shifting code, and it seems to differ from (my understanding of) the paper. I understood the window shifting in the original paper as the black arrow in the image below (self-attention is computed among the elements inside the bold lines). The left red arrow points to the result of patch-wise rolling, and the right red arrow points to the result of rolling the entire feature map.
In my opinion, self-attention should be computed according to the top-right figure; therefore, the boxes in the bottom-right figure should be used (the green dotted lines separate the sub-windows), since they preserve each region of the top-right figure.
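A small NumPy sketch of what rolling the entire feature map does to window membership, assuming a 4x4 token map, 2x2 windows, and a displacement of `ws // 2` (all chosen only for illustration). Each token id records its original position, so the windows formed after the roll show which original tokens end up in the same window.

```python
import numpy as np

H = W = 4          # tiny token map for illustration
ws = 2             # window size
disp = ws // 2     # displacement

ids = np.arange(H * W).reshape(H, W)                       # id = original position
shifted = np.roll(ids, shift=(-disp, -disp), axis=(0, 1))  # roll the whole map

# partition the rolled map into non-overlapping ws x ws windows
windows = (shifted.reshape(H // ws, ws, W // ws, ws)
                  .transpose(0, 2, 1, 3)
                  .reshape(-1, ws * ws))
print(windows.tolist())
# [[5, 6, 9, 10], [7, 4, 11, 8], [13, 14, 1, 2], [15, 12, 3, 0]]
```

The first window `[5, 6, 9, 10]` is the shifted window in the centre of the original map, while the others combine tokens that were not adjacent before the roll (for example columns 3 and 0). In the paper, such wrapped-around sub-windows are kept apart by an attention mask.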
Please let me know if I misunderstood your code or something in the paper. Thanks a lot!
Additionally, this is how I mimicked your code:
where `A` can be considered a 4x4 feature map (though its element order does not match the image above), `A_patched` is a patch-divided version of `A`, and `A_patched_rolled` is a patch-wise shifted version of `A_patched`, following `torch.roll(x, shifts=(self.displacement, self.displacement), dims=(1, 2))` in your code. `A_rearranged` is rearranged to match the image above.
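The snippet itself is not included here, so the following is a hypothetical NumPy reconstruction of the `A` → `A_patched` → `A_patched_rolled` steps, assuming 2x2 patches and `np.roll` in place of `torch.roll`. When each patch is flattened into the last (feature) dimension, rolling over axes 1 and 2 moves whole patches, which equals rolling the pixel map by the patch size rather than by one pixel. `to_patches` and `from_patches` are illustrative helpers.

```python
import numpy as np


def to_patches(a, p):
    """(H, W) -> (1, H//p, W//p, p*p): each p x p patch becomes one token
    whose feature vector is its flattened pixels."""
    h, w = a.shape
    x = a.reshape(h // p, p, w // p, p).transpose(0, 2, 1, 3)
    return x.reshape(1, h // p, w // p, p * p)


def from_patches(x, p):
    """Inverse of to_patches: (1, Hp, Wp, p*p) -> (Hp*p, Wp*p)."""
    _, hp, wp, _ = x.shape
    return x.reshape(hp, wp, p, p).transpose(0, 2, 1, 3).reshape(hp * p, wp * p)


A = np.arange(16).reshape(4, 4)
A_patched = to_patches(A, 2)
# roll over the token grid (axes 1, 2), analogous to torch.roll(..., dims=(1, 2))
A_patched_rolled = np.roll(A_patched, shift=(1, 1), axis=(1, 2))
A_rearranged = from_patches(A_patched_rolled, 2)

print(A_rearranged.tolist())
# [[10, 11, 8, 9], [14, 15, 12, 13], [2, 3, 0, 1], [6, 7, 4, 5]]
```

Read this way, a roll over dims (1, 2) of a (B x H x W x D) map shifts whole tokens with their features intact; whether that matches the intended figure depends on whether one entry of the map is a pixel or an already-embedded patch.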