Open lmisssunl opened 3 years ago
Yep, I also found those discrepancies! This could solve it. Also make sure the softmax normalizes over the keys for a given query (dim=1).
import torch
import torch.nn as nn


class Self_Attention(nn.Module):
    def __init__(self, inChannels, k=8):
        super(Self_Attention, self).__init__()
        embedding_channels = inChannels // k  # C_bar
        self.key = nn.Conv2d(inChannels, embedding_channels, 1)
        self.query = nn.Conv2d(inChannels, embedding_channels, 1)
        self.value = nn.Conv2d(inChannels, embedding_channels, 1)
        self.self_att = nn.Conv2d(embedding_channels, inChannels, 1)
        self.gamma = nn.Parameter(torch.tensor(0.0))
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """
        inputs:
            x: input feature map [Batch, Channel, Height, Width]
        returns:
            y: self-attention value + input feature [Batch, Channel, Height, Width]
            o: self-attention feature map [Batch, Channel, Height, Width]
        """
        batchsize, C, H, W = x.size()
        N = H * W                                   # Number of spatial positions
        f_x = self.key(x).view(batchsize, -1, N)    # Keys    [B, C_bar, N]
        g_x = self.query(x).view(batchsize, -1, N)  # Queries [B, C_bar, N]
        h_x = self.value(x).view(batchsize, -1, N)  # Values  [B, C_bar, N]

        s = torch.bmm(f_x.permute(0, 2, 1), g_x)    # Scores s[b, i, j] = key_i . query_j  [B, N, N]
        beta = self.softmax(s)                      # Attention map: softmax over keys (dim=1) per query  [B, N, N]

        v = torch.bmm(h_x, beta)                    # Values weighted by attention  [B, C_bar, N]
        v = v.view(batchsize, -1, H, W)             # Recover spatial shape  [B, C_bar, H, W]
        o = self.self_att(v)                        # Self-attention output  [B, C, H, W]
        y = self.gamma * o + x                      # Learnable gamma + residual connection
        return y, o
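
For what it's worth, here is a minimal smoke test (not from the original post; the input sizes and channel count are arbitrary assumptions) that runs the block on a random tensor, checks the output shapes, and confirms that with dim=1 the attention weights over the keys sum to 1 for every query:

# Minimal smoke test (assumed sizes, for illustration only).
import torch

att = Self_Attention(inChannels=64, k=8)           # 64 channels is an arbitrary choice
x = torch.randn(2, 64, 16, 16)                     # [B, C, H, W]
y, o = att(x)
assert y.shape == x.shape and o.shape == x.shape   # both outputs keep the input shape

# Reproduce the attention map to confirm the normalization axis:
f = att.key(x).flatten(2)                          # [B, C_bar, N]
g = att.query(x).flatten(2)                        # [B, C_bar, N]
s = torch.bmm(f.permute(0, 2, 1), g)               # s[b, i, j] = key_i . query_j
beta = att.softmax(s)                              # softmax over dim=1 (the keys)
print(beta.sum(dim=1))                             # ~1.0 for every query column

Note that because gamma is initialized to 0, y equals x at initialization and the attention branch is only phased in as gamma is learned during training.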
Hello, first of all, thank you for your code; it let me study SAGAN more deeply. After reading it, though, I have the following questions: