heykeetae / Self-Attention-GAN

Pytorch implementation of Self-Attention Generative Adversarial Networks (SAGAN)

about the gamma parameter #1

Open Entonytang opened 6 years ago

Entonytang commented 6 years ago

In your code, the shape of gamma is [batchsize,1,1,1]. I think the shape should be [1]. Besides, the attention score you compute seems to differ from the one in Han's paper. Did you calculate the attention score using the same equation as Eqn. (1) in the paper?

heykeetae commented 6 years ago

Thank you very much for your comment! I think you are right about gamma; I will correct it and update the code. As for the attention map, it has dimension batchsize x number_of_features (o in the paper, which I interpreted as the total number of pixels), which is the same as batchsize x H x W. In the code, H = W (= f in the code, which is perhaps the source of confusion). Since each pixel owns an attention map, the total required dimension is batchsize x f^2 x f x f. Sorry for the confusing notation. Please point out any other mistakes.

Entonytang commented 6 years ago

Based on my understanding, the self-attention operation should look like the code below. Why did you choose your method to calculate the attention scores?

import torch
import torch.nn as nn
import torch.nn.functional as F


class SelfAttention(nn.Module):
    def __init__(self, in_channel):
        super().__init__()
        # 1x1 convolutions applied along the flattened spatial axis
        self.query = nn.Conv1d(in_channel, in_channel // 8, 1)
        self.key = nn.Conv1d(in_channel, in_channel // 8, 1)
        self.value = nn.Conv1d(in_channel, in_channel, 1)
        # Learnable scalar, initialized to 0
        self.gamma = nn.Parameter(torch.tensor(0.0))

    def forward(self, input):
        shape = input.shape
        flatten = input.view(shape[0], shape[1], -1)  # (B, C, N), N = H * W
        query = self.query(flatten).permute(0, 2, 1)  # (B, N, C // 8)
        key = self.key(flatten)                       # (B, C // 8, N)
        value = self.value(flatten)                   # (B, C, N)
        query_key = torch.bmm(query, key)             # (B, N, N)
        attn = F.softmax(query_key, 1)                # normalize over dim 1
        attn = torch.bmm(value, attn)                 # (B, C, N)
        attn = attn.view(*shape)
        out = self.gamma * attn + input
        return out

heykeetae commented 6 years ago

Great suggestion! I'll sleep on that. However, that approach is similar to what I tried at first. I then realized it makes more sense for each pixel to look at different locations of the previous layer through its own attention map, since there are n resulting features o_j.

Entonytang commented 6 years ago

Also: f_ready = f_x.contiguous().view(b_size, -1, f_size ** 2, f_size, f_size).permute(0, 1, 2, 4, 3). Why do you transpose f_ready before multiplying it with g_ready?

heykeetae commented 6 years ago

That part is to reflect f(x)^T * g(x) in the paper :)

Entonytang commented 6 years ago

This operation aims to produce a scalar value (vector^T * vector = scalar), but the transpose operation in your code doesn't have this effect. This is just my understanding.

heykeetae commented 6 years ago

That's a very good point. The calculation involves the depth of a feature map, so the multiplication alone does not end up with a scalar per pixel. But looking at the line attn_dist = torch.mul(f_ready, g_ready).sum(dim=1).contiguous().view(-1, f_size ** 2), there is a .sum(dim=1) following the multiplication, which sums depth-wise and makes the result a scalar per pixel.
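To illustrate the point about .sum(dim=1), here is a minimal sketch. The tensor sizes are made up, and f and g are stand-ins rather than the repository's f_ready and g_ready. It only shows that an element-wise product followed by a sum over the depth axis yields one dot product per pixel, and that these values sit on the diagonal of the full pairwise score matrix.

```python
import torch

torch.manual_seed(0)

# Hypothetical sizes for illustration: batch 2, depth 8, 16 pixels.
batch, depth, pixels = 2, 8, 16
f = torch.randn(batch, depth, pixels)  # stand-in for f(x), one column per pixel
g = torch.randn(batch, depth, pixels)  # stand-in for g(x)

# Element-wise product, then a sum over the depth axis (dim=1):
# one scalar per pixel, i.e. the dot product f(x_i)^T g(x_i).
per_pixel = torch.mul(f, g).sum(dim=1)            # shape: (batch, pixels)

# The same dot products written explicitly with einsum.
reference = torch.einsum("bcn,bcn->bn", f, g)

# All pairwise scores s_ij = f(x_i)^T g(x_j) from one batched matmul.
pairwise = torch.bmm(f.transpose(1, 2), g)        # shape: (batch, pixels, pixels)
diagonal = torch.diagonal(pairwise, dim1=1, dim2=2)

same_as_einsum = torch.allclose(per_pixel, reference, atol=1e-5)
same_as_diagonal = torch.allclose(per_pixel, diagonal, atol=1e-5)
print(same_as_einsum, same_as_diagonal)
```

The batched matmul form computes every (i, j) pair without materializing the depth-wise product first.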

leehomyc commented 6 years ago

If every pixel has its own attention map, the memory will be consumed quickly as the image size goes up. I agree with @Entonytang's interpretation.
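A rough back-of-the-envelope sketch of that growth, assuming float32 storage and a square f x f feature map (the function names are hypothetical and only for illustration). A full attention map holds one weight for every pair of pixels, so its size grows with (H * W)^2.

```python
def attention_entries(height: int, width: int) -> int:
    """Number of entries in one full attention map for a single sample."""
    n = height * width   # number of pixels (feature locations)
    return n * n         # every pixel attends to every pixel


def attention_mib(height: int, width: int, batch: int = 1,
                  bytes_per_value: int = 4) -> float:
    """Approximate size in MiB, assuming float32 (4 bytes per value)."""
    return batch * attention_entries(height, width) * bytes_per_value / 2**20


for size in (16, 32, 64):
    print(size, attention_entries(size, size), attention_mib(size, size))
```

Under these assumptions a single sample needs 0.25 MiB at 16 x 16, 4 MiB at 32 x 32, and 64 MiB at 64 x 64. If the element-wise product quoted earlier in the thread is materialized before its .sum(dim=1), that intermediate would be larger again by the depth factor.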

heykeetae commented 6 years ago

@leehomyc I'm still not sure that having only one attention score map is justified. Looking at the paper, Figs. 1 and 5 show attention results where a 'particular' area takes hints from different regions. In @Entonytang's implementation, that sort of visualization is not possible.

hythbr commented 6 years ago

I think the attention score by @Entonytang agrees with Han's paper. But, based on my understanding, attn = torch.bmm(value, attn) should look like this:

value = value.permute(0, 2, 1)
attn = torch.bmm(attn, value)
attn = attn.permute(0, 2, 1)

What do you think? @Entonytang, @heykeetae

leehomyc commented 6 years ago

Why permute, @hythbr?

hythbr commented 6 years ago

According to Eqn. (2) in the paper, I think the matrix-matrix product after the permute represents the equation. However, I am not sure it is right. Please point out any errors. @leehomyc
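One way to compare the two variants numerically is sketched below, with made-up sizes and random stand-in tensors. It relies only on the shapes used in the SelfAttention snippet above: value is (B, C, N) and attn is (B, N, N) after a softmax over dim 1. The permuted variant turns out to be the original product applied to the transposed attention matrix, so the two differ in which index of attn is summed over.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical sizes: batch 2, 4 channels, 6 pixels.
batch, channels, pixels = 2, 4, 6
value = torch.randn(batch, channels, pixels)
scores = torch.randn(batch, pixels, pixels)   # stand-in for query_key
attn = F.softmax(scores, dim=1)               # each column sums to 1

# Variant A: out[:, :, j] = sum_i value[:, :, i] * attn[:, i, j]
out_a = torch.bmm(value, attn)

# Variant B (with permutes): out[:, :, i] = sum_j attn[:, i, j] * value[:, :, j]
out_b = torch.bmm(attn, value.permute(0, 2, 1)).permute(0, 2, 1)

# Variant B equals variant A applied to the transposed attention matrix.
out_a_transposed = torch.bmm(value, attn.transpose(1, 2))

column_sums = attn.sum(dim=1)   # all ones: variant A sums over the normalized index
row_sums = attn.sum(dim=2)      # generally not ones: variant B sums over the other index

print(torch.allclose(out_b, out_a_transposed, atol=1e-5))
print(torch.allclose(out_a, out_b, atol=1e-5))
```

Which variant corresponds to Eqn. (2) therefore depends on which index the softmax normalizes. With softmax over dim 1, variant A forms a normalized weighted sum of the value columns, while variant B would need the softmax over dim 2 to do the same.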

heykeetae commented 6 years ago

We have updated the whole self-attention module. Please check it out! The memory problem is solved, and we are convinced it now agrees with the paper too.

Entonytang commented 6 years ago

With this implementation, do you get better performance than with the previous method? And can you tell me the final gamma value after training?

heykeetae commented 6 years ago

The performance, honestly, is not distinguishable by eye. We should try IS or FID to quantify it. As for gamma, the original authors' intent is unclear; under this implementation it keeps increasing (or decreasing). It does not seem to converge for now, but one can try longer training to find out!

liangbh6 commented 6 years ago

@Entonytang @heykeetae Hi, I read the code and am unsure how gamma changes during training. It is defined as self.gamma = nn.Parameter(torch.zeros(1)) on line 39 of sagan_model.py.

liangbh6 commented 6 years ago

Well, I have figured out that gamma is treated as a learnable parameter.
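A minimal sketch of what "learnable" means here, using a hypothetical toy module (ScaledResidual is not from the repository) with the same self.gamma = nn.Parameter(torch.zeros(1)) definition. Because gamma is an nn.Parameter, it is returned by module.parameters(), receives a gradient in backward(), and is changed by the optimizer step like any other weight, with no manual update.

```python
import torch
import torch.nn as nn


class ScaledResidual(nn.Module):
    """Hypothetical toy module: out = gamma * branch(x) + x."""

    def __init__(self, channels: int):
        super().__init__()
        self.branch = nn.Conv2d(channels, channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))  # registered as a parameter

    def forward(self, x):
        return self.gamma * self.branch(x) + x


torch.manual_seed(0)
module = ScaledResidual(channels=3)
optimizer = torch.optim.SGD(module.parameters(), lr=0.1)

# Random data and a plain MSE loss, chosen only for illustration.
x = torch.randn(4, 3, 8, 8)
target = torch.randn(4, 3, 8, 8)

gamma_before = module.gamma.item()
loss = ((module(x) - target) ** 2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()            # updates gamma together with the conv weights
gamma_after = module.gamma.item()

print(gamma_before, gamma_after)
```

Since gamma starts at 0, the module initially passes its input through unchanged, and the optimizer moves gamma away from 0 as training proceeds.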

valillon commented 3 years ago

Related

thd-ux commented 4 weeks ago

How is gamma updated? Do I need to update it manually myself?