lucidrains / routing-transformer

Fully featured implementation of Routing Transformer
MIT License

Usage for image generation #18

Closed: Baran-phys closed this issue 3 years ago

Baran-phys commented 3 years ago

I was wondering how it is possible to use this in place of the old self-attention module in SAGAN?

lucidrains commented 3 years ago

@Hosein47 Hi Hosein! Images are actually sequences too: the channels are the feature dimension, and the number of pixels is the sequence length.

You would end up doing something like this:

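from einops import rearrange  # using einops

# img: a feature map of shape (batch, channels, height, width)
# transformer: any sequence model mapping (batch, seq_len, dim) -> (batch, seq_len, dim)
h, w = img.shape[2:]
seq = rearrange(img, 'b c h w -> b (h w) c')  # one token per pixel
attn_out = transformer(seq)
attended_img = rearrange(attn_out, 'b (h w) c -> b c h w', h = h, w = w)  # back to a feature map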
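If you would rather not depend on einops, the same two reshapes can be written in plain PyTorch. This is a minimal sketch: `image_to_sequence`, `sequence_to_image`, and the `nn.Identity` stand-in for the transformer are illustrative names, not part of routing-transformer. Pixels are flattened row by row, which matches the `(h w)` grouping in the einops pattern.

```python
import torch

def image_to_sequence(img):
    # (b, c, h, w) -> (b, h*w, c), same as rearrange(img, 'b c h w -> b (h w) c')
    return img.flatten(2).transpose(1, 2)

def sequence_to_image(seq, h, w):
    # (b, h*w, c) -> (b, c, h, w), same as rearrange(seq, 'b (h w) c -> b c h w', h=h, w=w)
    b, n, c = seq.shape
    return seq.transpose(1, 2).reshape(b, c, h, w)

img = torch.randn(2, 64, 8, 16)       # a batch of 64-channel 8x16 feature maps
seq = image_to_sequence(img)          # (2, 128, 64): 128 pixels, 64 dims each
transformer = torch.nn.Identity()     # stand-in for any (b, n, c) -> (b, n, c) model
attended_img = sequence_to_image(transformer(seq), *img.shape[2:])
```

With the identity stand-in, the round trip returns the original feature map exactly, which is a quick way to check the reshapes before swapping in a real attention model.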
lucidrains commented 3 years ago

@Hosein47 However, what I have found works great for image self-attention is actually linear attention. I have a specially made attention module for images at https://github.com/lucidrains/linear-attention-transformer#images and have used it successfully in https://github.com/lucidrains/stylegan2-pytorch and https://github.com/lucidrains/lightweight-gan
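For orientation, here is a minimal sketch of what linear attention over an image feature map can look like. It is not the module from linear-attention-transformer; it assumes the "efficient attention" factorization (queries softmaxed over their features, keys softmaxed over pixel positions) as one common linear-attention variant, and `LinearImageAttention`, `heads`, and `dim_head` are hypothetical names.

```python
import torch
from torch import nn

class LinearImageAttention(nn.Module):
    """Hypothetical minimal linear attention over a (b, c, h, w) feature map.

    Computes Q (K^T V) with softmax(Q) over features and softmax(K) over pixels,
    so a (d x d) context replaces the (n x n) attention map, n = h * w.
    """
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.heads = heads
        inner = heads * dim_head
        # 1x1 convs project every pixel independently, like a per-pixel linear layer
        self.to_qkv = nn.Conv2d(dim, inner * 3, kernel_size=1, bias=False)
        self.to_out = nn.Conv2d(inner, dim, kernel_size=1)

    def forward(self, x):
        b, _, h, w = x.shape
        q, k, v = self.to_qkv(x).chunk(3, dim=1)
        # (b, heads * d, h, w) -> (b, heads, d, n)
        q, k, v = (t.reshape(b, self.heads, -1, h * w) for t in (q, k, v))
        q = q.softmax(dim=-2)  # each pixel's query sums to 1 over features
        k = k.softmax(dim=-1)  # each key feature sums to 1 over pixels
        context = torch.einsum('bhdn,bhen->bhde', k, v)    # (b, heads, d, d)
        out = torch.einsum('bhde,bhdn->bhen', context, q)  # (b, heads, d, n)
        return self.to_out(out.reshape(b, -1, h, w))

attn = LinearImageAttention(dim=64)
x = torch.randn(2, 64, 64, 64)  # a 64x64 map: n = 4096 pixels
y = x + attn(x)                 # residual connection around the attention block
```

Memory grows linearly in the number of pixels because no n x n matrix is ever formed. SAGAN itself scales the attention output by a learned gamma initialized to zero before adding the residual; that detail is omitted here for brevity.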

lucidrains commented 3 years ago

@Hosein47 Another type of attention worth considering for images is axial attention (https://github.com/lucidrains/axial-attention). Finally, there is a recent ICLR paper that combines axial and linear attention into a single attention module: https://github.com/lucidrains/global-self-attention-network
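A minimal sketch of the axial idea, assuming plain PyTorch `nn.MultiheadAttention` (the `batch_first` flag needs a reasonably recent PyTorch): attend along each row, then along each column. This is not the axial-attention package's API; `AxialAttention2d` is a hypothetical name.

```python
import torch
from torch import nn

class AxialAttention2d(nn.Module):
    """Hypothetical minimal axial attention for a (b, c, h, w) feature map.

    Runs full attention along each row (length w), then along each column
    (length h), so each pass attends over w or h positions instead of h * w.
    """
    def __init__(self, dim, heads=4):
        super().__init__()
        self.row_attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.col_attn = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, x):
        b, c, h, w = x.shape
        # rows: (b, c, h, w) -> (b * h, w, c), one sequence per image row
        rows = x.permute(0, 2, 3, 1).reshape(b * h, w, c)
        rows, _ = self.row_attn(rows, rows, rows)
        # columns: (b, h, w, c) -> (b * w, h, c), one sequence per image column
        cols = rows.reshape(b, h, w, c).permute(0, 2, 1, 3).reshape(b * w, h, c)
        cols, _ = self.col_attn(cols, cols, cols)
        # (b * w, h, c) -> (b, c, h, w)
        return cols.reshape(b, w, h, c).permute(0, 3, 2, 1)

axial = AxialAttention2d(dim=32)
x = torch.randn(2, 32, 6, 10)  # non-square on purpose, to exercise both axes
y = x + axial(x)               # residual connection
```

After the row pass and the column pass, every output pixel can depend on every input pixel, while the largest attention map per pass is only w x w or h x h.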

Baran-phys commented 3 years ago

But the self-attention in SAGAN is built from convolution layers. So are you saying that using, say, GSA instead of the conv2d-based self-attention in SAGAN or BigGAN-deep is more effective?

lucidrains commented 3 years ago

A 1x1 conv2d is equivalent to a linear layer, except that it acts on the channel dimension (the first dimension after the batch) rather than the last.
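A quick check of this equivalence in PyTorch: copy a 1x1 `nn.Conv2d`'s parameters into an `nn.Linear`, move the channels last for the linear layer, and compare the outputs. A sketch with arbitrary shapes.

```python
import torch
from torch import nn

torch.manual_seed(0)
b, c_in, c_out, h, w = 2, 8, 16, 4, 5
x = torch.randn(b, c_in, h, w)

conv = nn.Conv2d(c_in, c_out, kernel_size=1)
linear = nn.Linear(c_in, c_out)

# Give both layers the same parameters.
# Conv2d weight: (c_out, c_in, 1, 1); Linear weight: (c_out, c_in).
with torch.no_grad():
    linear.weight.copy_(conv.weight.view(c_out, c_in))
    linear.bias.copy_(conv.bias)

conv_out = conv(x)  # mixes channels (dim 1) at every pixel

# nn.Linear mixes the last dim, so move channels last and then back again.
linear_out = linear(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

same = torch.allclose(conv_out, linear_out, atol=1e-5)
```

So the query, key, and value projections in SAGAN's conv-based block are per-pixel linear layers, which is why any sequence attention module can stand in for them once the feature map is flattened.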

SAGAN's self-attention architecture is quite dated. They only used one head, and they also used the full-attention variant, so they couldn't scale to anything larger than 32x32 feature maps. You can easily beat SAGAN's results these days.
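To see why full attention stops scaling, count the entries in a single n x n attention map, where n = h * w. A back-of-the-envelope sketch assuming fp32 scores, one head, and one image, and ignoring the activations kept for backpropagation; `attention_matrix_cost` is a hypothetical helper.

```python
def attention_matrix_cost(side, bytes_per_entry=4):
    """Size of one full attention map for a side x side feature map."""
    n = side * side                       # sequence length = number of pixels
    entries = n * n                       # one score per (query, key) pair
    mib = entries * bytes_per_entry / 2**20
    return n, entries, mib

for side in (32, 64, 128, 256):
    n, entries, mib = attention_matrix_cost(side)
    print(f"{side}x{side}: n={n}, entries={entries:,}, {mib:,.0f} MiB per head per image")
```

Each doubling of the side length multiplies the cost by 16: a 32x32 map needs about 1M scores (4 MiB), while a 128x128 map needs about 268M (1 GiB), before multiplying by heads and batch size.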

Baran-phys commented 3 years ago

I see. That's true. Interesting. I just plugged GSA into ajbrock's BigGAN-deep like this:

if self.arch['attention'][self.arch['resolution'][index]]:
    print('Adding attention layer in G at resolution %d' % self.arch['resolution'][index])
    self.blocks[-1] += [GSA(self.arch['out_channels'][index], norm_queries = True, batch_norm = False)]

However, I get the following error: "forward() takes 2 positional arguments but 3 were given"

[Screenshot of the error, 2020-11-28 at 21:59:24]

Where do you think the error is coming from? Thanks a lot!

lucidrains commented 3 years ago

@Hosein47 I don't know the BigGAN codebase, but judging from the error, it is passing GSA two arguments in forward, while mine only accepts one (the image feature map).

Baran-phys commented 3 years ago

I think the problem was the class labels. I changed your code a little and added "y=None" as an input to forward() in GSA / linear attention. Thanks, it worked.
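The fix above edits the attention module's forward to accept `y=None`. An alternative sketch that leaves the library code untouched is a thin wrapper that accepts the extra argument and drops it. This assumes the calling code passes the feature map first and the class information second, which matches the error message but is not confirmed for the BigGAN codebase; `IgnoreConditioning` and `FeatureMapOnly` are hypothetical names.

```python
import torch
from torch import nn

class FeatureMapOnly(nn.Module):
    """Stand-in for an attention module whose forward takes only the feature map."""
    def forward(self, x):
        return x

class IgnoreConditioning(nn.Module):
    """Hypothetical adapter: accepts (x, y) like a conditional generator block
    expects, and passes only x to the wrapped module."""
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x, y=None):
        # y (e.g. the class embedding) is accepted but deliberately unused
        return self.module(x)

x = torch.randn(1, 64, 8, 8)  # feature map
y = torch.randn(1, 128)       # class-conditioning vector

try:
    FeatureMapOnly()(x, y)    # reproduces the TypeError from the thread
    error = None
except TypeError as e:
    error = str(e)

out = IgnoreConditioning(FeatureMapOnly())(x, y)  # works: y is accepted and ignored
```

In the generator loop, this would mean appending `IgnoreConditioning(GSA(...))` instead of `GSA(...)`, so every layer in the block list can be called the same way.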

Baran-phys commented 3 years ago

@lucidrains I tried both GSA and the linear attention transformer in BigGAN-deep in place of the old self-attention block at ch = 64, for a conditional 256x768 image generation task. However, with exactly the same hyperparameters, the image quality was awful. It was faster and had fewer parameters (especially with GSA), but the results after a fixed number of iterations were very bad. Should I expect this, or am I doing something very wrong?

lucidrains commented 3 years ago

@Hosein47 I don't know the BigGAN architecture well enough to say for certain. You should expect linear attention to be worse than full attention. However, it lets you process higher-resolution feature maps that are too costly for full attention. If your goal is to improve on BigGAN, I would keep the current attention block and introduce linear attention at the higher resolutions.
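As a concrete reading of this advice, here is a sketch of a per-resolution attention plan. The threshold of 32 is an assumption borrowed from the SAGAN limit mentioned earlier in the thread, not a tuned value, and `attention_plan` is a hypothetical helper, not part of BigGAN's arch config.

```python
def attention_plan(resolutions, full_attention_max=32):
    """Hypothetical helper: pick an attention type per generator resolution.

    Keeps the existing full-attention block up to `full_attention_max` and
    uses linear attention above it, where an (h*w x h*w) map gets expensive.
    """
    return {r: 'full' if r <= full_attention_max else 'linear' for r in resolutions}

plan = attention_plan([16, 32, 64, 128, 256])
```

The point of the split is that the low-resolution block keeps the behavior the original hyperparameters were tuned for, while the cheaper linear attention only adds capacity where full attention was never affordable.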

lucidrains commented 3 years ago

@Hosein47 This is somewhat off-topic for Routing Transformer anyway; perhaps contact me by email?