heykeetae / Self-Attention-GAN

Pytorch implementation of Self-Attention Generative Adversarial Networks (SAGAN)

Missing one 1x1 conv on output from attention layer? #50

Closed harrygcoppock closed 4 years ago

harrygcoppock commented 4 years ago

The paper seems to indicate that a 4th 1x1 conv is applied to the output of the attention layer. Please see the difference in the images below:

[Two screenshots comparing the attention module as described in the paper with this repository's implementation]
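
If I am reading the paper correctly, the module it describes is (my transcription of the equations in the screenshots above):

    s_{ij} = f(x_i)^T g(x_j)
    \beta_{j,i} = \exp(s_{ij}) / \sum_i \exp(s_{ij})
    o_j = v( \sum_i \beta_{j,i} h(x_i) ),  with f(x) = W_f x, g(x) = W_g x, h(x) = W_h x, v(x) = W_v x
    y_i = \gamma o_i + x_i

So v (i.e. W_v) is a 4th 1x1 conv that re-projects the attended features back to C channels before the residual connection, and it is this conv that I cannot find in the implementation.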
harrygcoppock commented 4 years ago

Happy to implement the changes if this is a mistake.

ColorDiff commented 4 years ago

It is certainly not the way it is described in the paper. The meaningless per-pixel masks in the README.md also point to some inaccuracy in the implementation. Implementing the attention layer for a B x C x N shaped input using 1d convolutions is probably the way to go.

It is confusing that the authors speak of 1x1 convolutions when what they actually use are 1-dimensional convolutions with kernel size 1. This can be inferred from the fact that their convolution reduces the channel dimension while having a kernel size of 1: it is the number of kernels that reduces this dimension, not the padding, which could only be the case for a 2d convolution. Below is my implementation of the self-attention layer, assuming 3d inputs (batch, channels, features) and using 1d convolutions for channel compression.

import torch
import torch.nn as nn


class SelfAttention(nn.Module):

    def __init__(self, in_channels: int, compression_factor: int = 8):
        super().__init__()
        assert (in_channels % compression_factor) == 0

        self.q = nn.Conv1d(in_channels=in_channels, out_channels=in_channels // compression_factor, kernel_size=1)  # f
        self.k = nn.Conv1d(in_channels=in_channels, out_channels=in_channels // compression_factor, kernel_size=1)  # g
        self.v = nn.Conv1d(in_channels=in_channels, out_channels=in_channels // compression_factor, kernel_size=1)  # h
        self.o = nn.Conv1d(in_channels=in_channels // compression_factor, out_channels=in_channels, kernel_size=1)  # v
        self.gamma = nn.Parameter(torch.zeros(1), requires_grad=True)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
            inputs :
                x : input feature maps (B x C x N)
            returns :
                out : self attention value + input feature
                attention: B x N x N
        """
        query = self.q(x).permute(0, 2, 1)  # f(x)^T | B x N x C//k
        key = self.k(x)  # g(x) | B x C//k x N
        similarity = torch.bmm(query, key)  # batch matrix multiplication -> B x N x N
        attention = self.softmax(similarity)  # softmax over the last dim: row i holds the weights with which position i attends to every position j

        value = self.v(x)  # h(x) | B x C//k x N
        out = torch.bmm(value, attention.permute(0, 2, 1))  # B x C//k x N; transpose so that out[:, :, i] = sum_j attention[:, i, j] * value[:, :, j]
        out = self.o(out)  # v(...) | B x C x N

        out = self.gamma * out + x
        return out, attention
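
If your feature maps are B x C x H x W, a minimal usage sketch would be the following (the flatten/reshape around the layer is just one way to do it, and the sizes are arbitrary):

b, c, h, w = 4, 64, 32, 32  # arbitrary example sizes
feature_map = torch.randn(b, c, h, w)

attn = SelfAttention(in_channels=c, compression_factor=8)
out, attention = attn(feature_map.view(b, c, h * w))  # out: B x C x N, attention: B x N x N
out = out.view(b, c, h, w)  # back to B x C x H x W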

Suggestions and corrections are welcome.

harrygcoppock commented 4 years ago

Thank you for your response. Yes, this is the same as my implementation (using .view(batch, -1, W*H) to get the feature vectors). I was wondering whether there was a particular reason the owner of this repo left out the output (4th) conv?
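
For reference, a 2D version with the 4th conv could look roughly like this. This is only my own sketch using Conv2d and the .view trick mentioned above, not the repo's code, and the compression factor is illustrative:

import torch
import torch.nn as nn


class SelfAttention2d(nn.Module):

    def __init__(self, in_channels: int, compression_factor: int = 8):
        super().__init__()
        hidden = in_channels // compression_factor
        self.f = nn.Conv2d(in_channels, hidden, kernel_size=1)  # query
        self.g = nn.Conv2d(in_channels, hidden, kernel_size=1)  # key
        self.h = nn.Conv2d(in_channels, hidden, kernel_size=1)  # value
        self.v = nn.Conv2d(hidden, in_channels, kernel_size=1)  # the 4th 1x1 conv, back to C channels
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        b, c, height, width = x.shape
        n = height * width
        query = self.f(x).view(b, -1, n).permute(0, 2, 1)  # B x N x C//k
        key = self.g(x).view(b, -1, n)  # B x C//k x N
        attention = torch.softmax(torch.bmm(query, key), dim=-1)  # B x N x N
        value = self.h(x).view(b, -1, n)  # B x C//k x N
        out = torch.bmm(value, attention.permute(0, 2, 1))  # B x C//k x N
        out = self.v(out.view(b, -1, height, width))  # B x C x H x W
        return self.gamma * out + x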

ColorDiff commented 4 years ago

I don't know. Again, I feel the formulation in the paper could be a little clearer, which would probably have avoided the confusion, but the 4th conv is clearly visible in the paper. Maybe there was an earlier pre-print version, and in the final published version they added the 4th re-projection convolution after realizing that you can reduce the channels without a noticeable loss of model capacity or performance.