heykeetae / Self-Attention-GAN

Pytorch implementation of Self-Attention Generative Adversarial Networks (SAGAN)

The code is different from the original paper #54

Open lmisssunl opened 3 years ago

lmisssunl commented 3 years ago

Hello, first of all thank you for your code; it has helped me study SAGAN more deeply. However, after reading the code I have the following questions:

  1. In the original paper, the relationship between $\beta_{j,i}$ and $h(x)$ is $o_j = \sum_{i=1}^{N} \beta_{j,i}\, h(x_i)$, which is not what the torch.bmm matrix multiplication in your code computes. (The paper's formulation is sketched after this list for reference.)
  2. For the third 1×1 convolution in the paper (i.e. h(x)), the output dimension is not C but C//8.
  3. The original paper has a final 1×1 convolution, but it is not reflected in your code (perhaps because of the second point: since you changed the output of h(x) from C//8 to C, the final 1×1 convolution was left out).
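
For reference, this is how I read the attention formulation in the SAGAN paper (the channel counts of h and of the final 1×1 convolution v are exactly the points in question):

    \beta_{j,i} = \frac{\exp(s_{ij})}{\sum_{i=1}^{N} \exp(s_{ij})},
    \qquad s_{ij} = f(x_i)^{\top} g(x_j),
    \qquad f(x) = W_f x,\quad g(x) = W_g x

    o_j = v\!\left( \sum_{i=1}^{N} \beta_{j,i}\, h(x_i) \right),
    \qquad h(x) = W_h x,\quad v(x) = W_v x

    y_i = \gamma\, o_i + x_i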
valillon commented 3 years ago

Yep, I also found those discrepancies! The module below should solve them. Also make sure the softmax iterates over the keys for a given query (dim=1).

import torch
import torch.nn as nn


class Self_Attention(nn.Module):
    def __init__(self, inChannels, k=8):
        super(Self_Attention, self).__init__()
        embedding_channels = inChannels // k  # C_bar
        self.key      = nn.Conv2d(inChannels, embedding_channels, 1)
        self.query    = nn.Conv2d(inChannels, embedding_channels, 1)
        self.value    = nn.Conv2d(inChannels, embedding_channels, 1)
        self.self_att = nn.Conv2d(embedding_channels, inChannels, 1)  # final 1x1 conv back to C channels
        self.gamma    = nn.Parameter(torch.tensor(0.0))               # learnable residual scale, initialized to 0
        self.softmax  = nn.Softmax(dim=1)                             # normalize over keys for each query

    def forward(self,x):
        """
            inputs:
                x: input feature map [Batch, Channel, Height, Width]
            returns:
                out: self attention value + input feature
                attention: [Batch, Channel, Height, Width]
        """
        batchsize, C, H, W = x.size()
        N = H * W                                       # Number of features
        f_x = self.key(x).view(batchsize,   -1, N)      # Keys                  [B, C_bar, N]
        g_x = self.query(x).view(batchsize, -1, N)      # Queries               [B, C_bar, N]
        h_x = self.value(x).view(batchsize, -1, N)      # Values                [B, C_bar, N]

        s =  torch.bmm(f_x.permute(0,2,1), g_x)         # Scores                [B, N, N]
        beta = self.softmax(s)                          # Attention Map         [B, N, N]

        v = torch.bmm(h_x, beta)                        # Value x Softmax       [B, C_bar, N]
        v = v.view(batchsize, -1, H, W)                 # Recover input shape   [B, C_bar, H, W]
        o = self.self_att(v)                            # Self-Attention output [B, C, H, W]

        y = self.gamma * o + x                          # Learnable gamma + residual
        return y, o
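
A quick sanity check (the batch size, channel count and spatial size below are arbitrary, just to verify shapes):

attn = Self_Attention(inChannels=64)    # C_bar = 64 // 8 = 8
x = torch.randn(2, 64, 16, 16)          # [B, C, H, W]
y, o = attn(x)
print(y.shape, o.shape)                 # both torch.Size([2, 64, 16, 16])

Since gamma is initialized to 0, y equals x at the start of training, so the layer can be dropped into an existing generator/discriminator without disturbing it.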
valillon commented 3 years ago

Apparently, as mentioned here, max pooling inside the attention layer is just a design choice to reduce computation/memory overhead. This should close the issue.
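
In case it helps, here is a rough sketch of how that pooling could be added to the forward pass above, downsampling only the keys and values (which is how I understand the official TensorFlow implementation does it; the 2×2 pool size is my own choice, and it assumes H and W are even):

import torch.nn.functional as F   # for max_pool2d

class Self_Attention(nn.Module):
    # __init__ unchanged from the version above

    def forward(self, x):
        batchsize, C, H, W = x.size()
        N = H * W
        g_x = self.query(x).view(batchsize, -1, N)                        # Queries        [B, C_bar, N]

        # Pooling keys and values shrinks the attention map from N x N to (N/4) x N
        f_x = F.max_pool2d(self.key(x),   2).view(batchsize, -1, N // 4)  # Pooled keys    [B, C_bar, N/4]
        h_x = F.max_pool2d(self.value(x), 2).view(batchsize, -1, N // 4)  # Pooled values  [B, C_bar, N/4]

        s = torch.bmm(f_x.permute(0, 2, 1), g_x)                          # Scores         [B, N/4, N]
        beta = self.softmax(s)                                            # Softmax over pooled keys (dim=1)

        v = torch.bmm(h_x, beta).view(batchsize, -1, H, W)                # [B, C_bar, H, W]
        o = self.self_att(v)                                              # [B, C, H, W]
        return self.gamma * o + x, o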