GuoLanqing / ShadowFormer

ShadowFormer (AAAI2023), PyTorch implementation
MIT License

Question about patch-wise correlation map (sigma) in "Shadow-Interaction Attention" #26

Closed hyunW3 closed 1 year ago

hyunW3 commented 1 year ago

Nice work on solving the shadow degradation problem.

I looked through the code to understand how the patch-wise correlation map $\Sigma$ is obtained. As far as I can tell, $\mathrm{SIA}(\boldsymbol{X}, \Sigma)$ is obtained by multiplying the attention operation by $[\sigma\Sigma + (1-\sigma)\mathbf{1}]$ (Eq. (8) in the paper):

https://github.com/GuoLanqing/ShadowFormer/blob/a55cf042a1f9371d0e56a1299c245604d6663b6f/model.py#L426C24-L426C33

The code referenced above (in the class `WindowAttention`) performs the attention operation. However, I cannot find where $[\sigma\Sigma + (1-\sigma)\mathbf{1}]$ is applied or where $\Sigma$ is obtained.

Could you explain the patch-wise correlation map in more detail, and where the rest of the SIA operation ($[\sigma\Sigma + (1-\sigma)\mathbf{1}]$) is performed? Thank you.
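Written element-wise, and assuming $\Sigma$ is a binary map ($\Sigma_{ij} \in \{0, 1\}$, an assumption not stated in this comment), the bracketed factor in Eq. (8) as described here takes only two values:

```math
\left[\sigma\Sigma + (1-\sigma)\mathbf{1}\right]_{ij} = \sigma\Sigma_{ij} + (1-\sigma) =
\begin{cases}
1, & \Sigma_{ij} = 1,\\
1-\sigma, & \Sigma_{ij} = 0.
\end{cases}
```

So patch pairs with $\Sigma_{ij} = 0$ keep a reduced weight of $1-\sigma$ rather than being masked out entirely.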

GuoLanqing commented 1 year ago

You can refer to Lines 398-405 in model.py.

hyunW3 commented 1 year ago

Thank you. To check that my understanding is right, I have a few questions:

  1. Does the variable `mm` (at `mm = torch.unsqueeze(mm, dim=1)`) correspond to $[\sigma\Sigma + (1-\sigma)\mathbf{1}]$ in Eq. (8)?
  2. Is the XOR operation done by the `torch.where` calls before and after the matrix multiplication? Could you explain how the code below works as an XOR operation?
     ```python
     xm = torch.where(xm < 0.1, one, one*2)
     mm = xm @ xm.transpose(-2, -1)
     one = torch.ones_like(mm)
     mm = torch.where(mm==2, one, one*0.2)
     ```
  3. Is $\sigma$, the weight of shadow-shadow and non-shadow-non-shadow pairs in Eq. (8), equal to 0.8?
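The excerpt from model.py:

```python
# Excerpt from the attention module's forward() in model.py (not standalone)
def forward(self, x, xm, attn_kv=None, mask=None):
    B_, N, C = x.shape
    # x = self.se_layer(x)
    one = torch.ones_like(xm)
    zero = torch.zeros_like(xm)  # not used in this excerpt
    xm = torch.where(xm < 0.1, one, one*2)  # mask < 0.1 -> 1 (non-shadow), else 2 (shadow)
    mm = xm @ xm.transpose(-2, -1)          # pairwise products: 1, 2 or 4
    one = torch.ones_like(mm)
    mm = torch.where(mm==2, one, one*0.2)   # product 2 (mixed pair) -> 1, otherwise 0.2
    mm = torch.unsqueeze(mm, dim=1)         # add a head dimension for broadcasting
    q, k, v = self.qkv(x, attn_kv)
    q = q * self.scale
    attn = (q @ k.transpose(-2, -1)) * mm
    # ... (remainder of forward() omitted)
```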
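The mask-weighting part of the excerpt above can be exercised outside the model with a toy mask. This is a minimal sketch under two assumptions not stated in the thread: the window mask `xm` has shape `(B_, N, 1)` (one value per token), and the attention logits have shape `(B_, num_heads, N, N)`. `num_heads` and the random `logits` are hypothetical placeholders for what `self.qkv` and `self.scale` produce. It shows the values `mm` takes and that its `(B_, 1, N, N)` shape broadcasts over the head dimension.

```python
import torch

torch.manual_seed(0)  # reproducible placeholder logits

# Hypothetical window: B_ = 1, N = 4 tokens; mask values >= 0.1 count as shadow.
xm = torch.tensor([[[0.0], [0.9], [1.0], [0.05]]])  # assumed shape (B_, N, 1)

# Same operations as the excerpt from model.py
one = torch.ones_like(xm)
xm = torch.where(xm < 0.1, one, one * 2)   # non-shadow -> 1, shadow -> 2
mm = xm @ xm.transpose(-2, -1)             # (B_, N, N); entries are 1, 2 or 4
one = torch.ones_like(mm)
mm = torch.where(mm == 2, one, one * 0.2)  # mixed pair -> 1, same-region pair -> 0.2
mm = torch.unsqueeze(mm, dim=1)            # (B_, 1, N, N)

# Hypothetical stand-in for (q * scale) @ k^T with 2 heads: (B_, num_heads, N, N)
num_heads = 2
logits = torch.randn(1, num_heads, 4, 4)
attn = logits * mm  # mm broadcasts across the head dimension

print(mm.shape)    # torch.Size([1, 1, 4, 4])
print(attn.shape)  # torch.Size([1, 2, 4, 4])
# mm[0, 0] rows: [0.2, 1, 1, 0.2], [1, 0.2, 0.2, 1], [1, 0.2, 0.2, 1], [0.2, 1, 1, 0.2]
print(mm[0, 0])
```

The shape assumption matters: if the last dimension of `xm` were larger than 1, `xm @ xm.transpose(-2, -1)` would sum products over that dimension, and the entries would no longer be limited to 1, 2 and 4.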
BlackJoke76 commented 1 year ago

After `xm = torch.where(xm < 0.1, one, one*2)`, the values in `xm` are 1 and 2, and after `xm @ xm.transpose(-2, -1)` the variable `mm` takes only three values: 1, 2 and 4, because 1 × 1 = 1, 2 × 2 = 4 and 1 × 2 = 2. A product of 2 appears exactly when one patch is shadow and the other is non-shadow, so testing `mm == 2` can be seen as an XOR operation.
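This argument can be checked exhaustively with a few lines of plain Python (no PyTorch needed). The sketch is illustrative: `encode` is a hypothetical helper that mirrors `torch.where(xm < 0.1, one, one*2)` for a single binary mask value, and the loop compares `product == 2` with the XOR of the two mask bits for all four pairs.

```python
def encode(is_shadow: bool) -> int:
    """Hypothetical helper mirroring torch.where(xm < 0.1, one, one*2): non-shadow -> 1, shadow -> 2."""
    return 2 if is_shadow else 1

rows = []
for a in (False, True):
    for b in (False, True):
        product = encode(a) * encode(b)  # one entry of xm @ xm^T (single-channel mask)
        rows.append((a, b, product, product == 2, a != b))  # a != b is XOR for bools

for a, b, product, is_two, xor in rows:
    print(f"{int(a)} {int(b)}  product={product}  product==2: {is_two}  XOR: {xor}")
```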

hyunW3 commented 1 year ago

@BlackJoke76 Thank you.

I was confused because `mm` consists of 1 and 0.2, which differs from what I expected (1 and 0). The value 0.2 comes from $[\sigma\Sigma + (1-\sigma)\mathbf{1}]$ with $\sigma = 0.8$: where $\Sigma = 0$, the weight is $1-\sigma = 0.2$. So `mm = torch.where(mm==2, one, one*0.2)` combines the XOR and the $[\sigma\Sigma + (1-\sigma)\mathbf{1}]$ operation in a single step.
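This reading can be checked numerically: with $\Sigma$ taken as the XOR of the binary shadow masks and $\sigma = 0.8$, computing $\sigma\Sigma + (1-\sigma)\mathbf{1}$ directly should reproduce `mm` from the code. This is a hedged sketch: the toy mask, its `(B_, N, 1)` shape, the 0.1 threshold for "shadow", and the XOR definition of $\Sigma$ are assumptions drawn from this thread, not confirmed against the paper.

```python
import torch

sigma = 0.8  # assumption: inferred from the 0.2 constant, since 1 - sigma = 0.2

# Hypothetical soft mask for one window of N = 5 tokens, shape (B_, N, 1).
xm = torch.tensor([[[0.0], [0.8], [1.0], [0.02], [0.5]]])

# Route 1: the encoding used in the excerpt from model.py.
one = torch.ones_like(xm)
enc = torch.where(xm < 0.1, one, one * 2)
prod = enc @ enc.transpose(-2, -1)
one_nn = torch.ones_like(prod)
mm_code = torch.where(prod == 2, one_nn, one_nn * 0.2)

# Route 2: Eq. (8) weight written directly, with Sigma as the XOR of binary masks.
shadow = (xm >= 0.1).squeeze(-1)                                # (B_, N) booleans
Sigma = (shadow.unsqueeze(-1) ^ shadow.unsqueeze(-2)).float()   # (B_, N, N), 1 for mixed pairs
mm_eq8 = sigma * Sigma + (1 - sigma) * torch.ones_like(Sigma)

print(torch.allclose(mm_code, mm_eq8))  # True
```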

Thank you for helping me understand.