lucidrains / linear-attention-transformer

Transformer based on a variant of attention that is linear in complexity with respect to sequence length
MIT License

[Question] Merging with Trans-XL? #2

Open gaceladri opened 4 years ago

gaceladri commented 4 years ago

Do you think any of these implementations would be compatible with Transformer-XL? Thanks!

lucidrains commented 4 years ago

@gaceladri I've been thinking about it! Actually, I think this linear attention can be used anywhere you desire full attention but cannot pay the price. I'm actually working on the best implementation of Transformer-XL at the moment: https://github.com/lucidrains/compressive-transformer-pytorch You can make your request there if you want linear attention applied!

lucidrains commented 4 years ago

@gaceladri If your question meant whether one can introduce recurrence into this like Transformer-XL, that would be harder to do, because these sequences are naturally so long that holding all the previous activations in memory would become infeasible at some point.

lucidrains commented 4 years ago

@gaceladri But hmm, now you've got me thinking. It may be possible to compress the hidden activations into something smaller and introduce recurrence on top of these LR transformers. Haha, if you can propose something that makes sense, I'll consider adding it!

gaceladri commented 4 years ago

You're a fucking machine! Haha, I am working on an implementation of Transformer-XL with adaptive softmax and dynamic evaluation, and I found this awesome repo when I was looking for your Linformer. Just thinking and questioning... At this moment I don't know which one would be better... But I read the DeepMind paper, and I don't think compression would be the best way, since you need to add more computation to your model, and I am working towards the SustaiNLP workshop at EMNLP 2020. I am sure there is a better way than compression. But I am not an expert like you or any DeepMinder... :)

gaceladri commented 4 years ago

I only skimmed the paper, but my intuition says they got better results just by adding more computation in one way or another... I haven't read the full paper yet.

gaceladri commented 4 years ago

@lucidrains One thing I would do is first add a bottleneck after the embedding, like in the MobileBERT paper, and then add the linear attention to the Transformer-XL. Would that make sense?

lucidrains commented 4 years ago

@gaceladri Ohh, I believe that linear layer is basically the embedding factorization from ALBERT? It's available as a feature in most of my repos as the emb_dim hyperparameter.
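For anyone following along, a minimal sketch of that factorization (the class and sizes here are illustrative, not this repo's exact API):

import torch
import torch.nn as nn

class FactorizedEmbedding(nn.Module):
    # ALBERT-style factorization: embed into a small dimension, then project up to the model dimension
    def __init__(self, num_tokens, emb_dim, model_dim):
        super().__init__()
        self.token_emb = nn.Embedding(num_tokens, emb_dim)  # small embedding table
        self.project = nn.Linear(emb_dim, model_dim)        # bottleneck -> model dim

    def forward(self, token_ids):
        return self.project(self.token_emb(token_ids))

# e.g. vocab of 30k, emb_dim = 128, model_dim = 512
emb = FactorizedEmbedding(30000, 128, 512)
out = emb(torch.randint(0, 30000, (2, 256)))  # shape: (2, 256, 512)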

lucidrains commented 4 years ago

@gaceladri So, linear attention actually did not yield very good results for me past a certain length. This repository was mainly to explore whether combining it with a local inductive prior (local attention) would improve its performance. At a sequence length of 4096, on enwik8, the causal linear attention alone fails to converge. However, I have had good results using linear attention in places where I cannot use full attention, https://github.com/lucidrains/stylegan2-pytorch#attention So my decision tree is now basically to use linear attention only where I cannot use full attention, and perhaps combine it with local attention for some cheap, weaker, global attention.

lucidrains commented 4 years ago

@gaceladri Also, I am far from an expert, I just find it easier to learn concepts if I build them :)

lucidrains commented 4 years ago

@gaceladri I find the Compressive Transformer interesting because it successfully combined a memory write mechanism with a recurrent attention network.

After learning that Nvidia used an LSTM with Neural Turing Machine-like memory write/read to learn the dynamics of Pac-Man from pixels, I've become a lot more interested in memory in general.

gaceladri commented 4 years ago

Well, thanks for your answer! I am not looking for longer sequences, just efficiency and some performance. Do you think it can offer a good performance/efficiency trade-off at sequence lengths around 512 or so? Just to clarify: will the linear attention work by simply swapping in this function:

import torch

def max_neg_value(tensor):
    # most negative finite value for the tensor's dtype (stand-in for the repo's helper)
    return -torch.finfo(tensor.dtype).max

def linear_attn(q, k, v, kv_mask = None, one_kv_head = False):
    if kv_mask is not None:
        # kv_mask is boolean: True = keep, False = mask out before the key softmax
        mask_value = max_neg_value(q)
        mask = kv_mask[:, :, None] if one_kv_head else kv_mask[:, None, :, None]
        k = k.masked_fill_(~mask, mask_value)

    dim = q.shape[-1]
    (q, k) = map(lambda x: x * (dim ** -0.25), (q, k))

    q = q.softmax(dim=-1)   # softmax over the feature dimension
    k = k.softmax(dim=-2)   # softmax over the sequence dimension

    context_einsum_eq = 'bhnd,bhne->bhde' if not one_kv_head else 'bnd,bne->bde'
    context = torch.einsum(context_einsum_eq, k, v)

    attn_einsum_eq = 'bhnd,bhde->bhne' if not one_kv_head else 'bhnd,bde->bhne'
    attn = torch.einsum(attn_einsum_eq, q, context)

    return attn.reshape(*q.shape)

Instead of the normal attention computation?
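For a quick sanity check, this is roughly how I would call it (a hedged sketch with random tensors and a boolean padding mask; shapes are illustrative):

import torch

b, h, n, d = 2, 8, 512, 64
q, k, v = (torch.randn(b, h, n, d) for _ in range(3))
kv_mask = torch.ones(b, n).bool()    # True = keep, False = padding

out = linear_attn(q, k, v, kv_mask = kv_mask)
print(out.shape)                     # torch.Size([2, 8, 512, 64])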

gaceladri commented 4 years ago

Sure, interesting! But maybe not so much when you are looking for low-computation models?

lucidrains commented 4 years ago

@gaceladri yes, actually linear attention worked very well for me at lengths below 2048!

lucidrains commented 4 years ago

@gaceladri Ohh, I understand now, you are primarily interested in efficiency. Got it. Have you read the Performer paper from DeepMind yet? https://arxiv.org/pdf/2006.03555.pdf

gaceladri commented 4 years ago

I have also been looking at Product Key Memory layers, but I am working in TensorFlow and the official implementation is in PyTorch. The embedding-bag equivalent in TensorFlow is not optimized, so it is very slow compared to PyTorch.

gaceladri commented 4 years ago

Sounds good! I am going to read it! Thanks!

lucidrains commented 4 years ago

@gaceladri PKM works great for me! My researcher friend @AranKomat is actually studying conditional computation and clued me in on it.

lucidrains commented 4 years ago

@gaceladri So what you should know is that the auto-regressive flavor of linear attention actually incurs a much greater memory cost, but EPFL wrote up a CUDA kernel that alleviates that issue. https://github.com/idiap/fast-transformers I also have a non-CUDA solution, but it requires pairing it with local (QK)V attention.

Otherwise, for the non-autoregressive case, it is very efficient, and multiple people (myself included) have found it works for their problems. But it isn't as good as full attention. The only paper that claims it is as good is the Linformer paper, but they've only benchmarked it against RoBERTa at a length of 4096.
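To illustrate why the causal case is costlier, here is a naive, hedged sketch in the style of the Transformers-are-RNNs formulation (not this repo's code and not the fast-transformers kernel): every position needs its own prefix sum of key-value outer products, so materializing them all at once carries an extra d x d state per position.

import torch

def naive_causal_linear_attn(q, k, v, eps = 1e-8):
    # q, k, v: (batch, heads, seq, dim); assumes q and k have already been made
    # non-negative by some feature map (e.g. elu + 1)
    # prefix sums of k_i (outer product) v_i: shape (batch, heads, seq, dim, dim)
    context = torch.einsum('bhnd,bhne->bhnde', k, v).cumsum(dim = 2)
    normalizer = k.cumsum(dim = 2)                                  # (batch, heads, seq, dim)
    out = torch.einsum('bhnd,bhnde->bhne', q, context)
    denom = torch.einsum('bhnd,bhnd->bhn', q, normalizer).clamp(min = eps)
    return out / denom.unsqueeze(-1)

At n = 4096 and d = 64, that intermediate is n x d x d floats per head, which is the kind of memory blow-up the CUDA kernel (or pairing with local attention) works around.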

lucidrains commented 4 years ago

@gaceladri I also offer the Linformer linear attention in this repository, so feel free to try it and let me know what you discover :)

gaceladri commented 4 years ago

I follow Aran on Twitter :P. The thing is, I would like to avoid ad-hoc CUDA kernels, since ultimately I would like to deploy the model on a mobile device. Regarding "linear attention actually incurs a much greater memory cost": could it be because it is doing a softmax over the hidden space? Maybe an adaptive softmax could alleviate that? https://arxiv.org/abs/1809.10853

lucidrains commented 4 years ago

Only the auto-regressive (GPT-like) linear attention incurs the greater cost. If you are building BERT-like models, there is no extra cost.

lucidrains commented 4 years ago

@gaceladri Knowing you are most attentive to efficiency, the DeepMind paper is the most relevant: they claim you can take models pre-trained with full attention and fine-tune them into linear-attention models with little loss in accuracy.

gaceladri commented 4 years ago

"They claim you can take pre-trained models on full-attention, and fine-tune them into linear-attention" Awesome! But I don't think that I could reproduce their result in a short time that I have at this moment... Great discussion by the way! I am going to sleep on it and I will let you know if I find the linear attention useful on my Trans-xl or not! :=) Have a good night! :+1:

lucidrains commented 4 years ago

@gaceladri Ok! Let me know if you need any modifications to this repository so you can use any of the functions that are not exposed! Good night!

gaceladri commented 4 years ago

Hi, did you read this already? Multi-Head Attention: Collaborate Instead of Concatenate. How do you see that being mixed with the linear attention function? See also the collaborative attention code.

gaceladri commented 4 years ago

> @gaceladri So, linear attention actually did not yield very good results for me past a certain length. This repository was mainly to explore whether combining it with a local inductive prior (local attention) would improve its performance. At a sequence length of 4096, on enwik8, the causal linear attention alone fails to converge. However, I have had good results using linear attention in places where I cannot use full attention, https://github.com/lucidrains/stylegan2-pytorch#attention So my decision tree is now basically to use linear attention only where I cannot use full attention, and perhaps combine it with local attention for some cheap, weaker, global attention.

I have been reading the linear transformer paper and am watching the Yannic video about it now. I have a question, and maybe I am missing something, but do you normalize the linear attention? And could that be why it is not converging on enwik8?

lucidrains commented 4 years ago

@gaceladri yup! Aran actually sent me the head collab paper! The key line seems to be https://github.com/epfml/collaborative-attention/blob/53ca19deebf62581b412b557da3455974afc7549/src/collaborative_attention/collaborative_attention.py#L96

Yup, multiple different papers seem to approach the normalization in different ways. I went with what was proposed in the original efficient attention paper. https://github.com/lucidrains/linear-attention-transformer/blob/master/linear_attention_transformer/linear_attention_transformer.py#L228-L229 and https://github.com/lucidrains/linear-attention-transformer/blob/master/linear_attention_transformer/linear_attention_transformer.py#L263 (causal)

lucidrains commented 4 years ago

I'll give the Transformers-are-RNNs paper's approach a try too, but I think they are all roughly the same. The speech synthesis results were not very encouraging. I don't think this approach can win against full attention, but it can probably serve as a weaker global attention that integrates local attention results in later layers (which is what this repo tries to do).
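For reference, a hedged side-by-side of the two normalizations being discussed, non-causal case only (simplified sketches, not the repo's exact code):

import torch
import torch.nn.functional as F

def efficient_attention_style(q, k, v):
    # roughly what this repo's non-causal path does: softmax q over the feature
    # dimension, softmax k over the sequence dimension, then two einsums
    q = q.softmax(dim = -1)
    k = k.softmax(dim = -2)
    context = torch.einsum('bhnd,bhne->bhde', k, v)
    return torch.einsum('bhnd,bhde->bhne', q, context)

def transformers_are_rnns_style(q, k, v, eps = 1e-8):
    # elu(x) + 1 feature map with an explicit per-query denominator
    q, k = F.elu(q) + 1, F.elu(k) + 1
    context = torch.einsum('bhnd,bhne->bhde', k, v)
    denom = torch.einsum('bhnd,bhd->bhn', q, k.sum(dim = 2)).clamp(min = eps)
    return torch.einsum('bhnd,bhde->bhne', q, context) / denom.unsqueeze(-1)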

lucidrains commented 4 years ago

@gaceladri I think the collab head attention is actually two ideas in one. There are already papers showing you can get away with one set of key/value heads (here they do one set of queries). But the 'collab' aspect of mixing the heads seems reminiscent of the Talking Heads paper https://arxiv.org/abs/2003.02436 from Shazeer

gaceladri commented 4 years ago

@lucidrains I am a bit confused about the dimensions, since I am working with TF and I am aware that PyTorch treats the dimensions differently.

At this line, are the dimensions of q, k, v respectively:

q = [batch, from, heads, size_per_head]
k = [batch, to, heads, size_per_head]
v = [batch, to, heads, size_per_head]

or

q = [batch, heads, from, size_per_head]
k = [batch, heads, to, size_per_head]
v = [batch, heads, to, size_per_head]

lucidrains commented 4 years ago

@gaceladri The latter! And if one_kv_head = True, then:

q = [batch, heads, from, size_per_head]
k = [batch, to, size_per_head]
v = [batch, to, size_per_head]
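A quick shape check of both conventions against the linear_attn function quoted earlier (sizes are illustrative):

import torch

b, h, n_from, n_to, d = 2, 8, 128, 128, 64

# one_kv_head = False: q, k and v all carry a heads dimension
q = torch.randn(b, h, n_from, d)
k = torch.randn(b, h, n_to, d)
v = torch.randn(b, h, n_to, d)
out = linear_attn(q, k, v)                         # (b, h, n_from, d)

# one_kv_head = True: k and v share a single head
k1 = torch.randn(b, n_to, d)
v1 = torch.randn(b, n_to, d)
out1 = linear_attn(q, k1, v1, one_kv_head = True)  # (b, h, n_from, d)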

lucidrains commented 4 years ago

@gaceladri what are you building? lol

AranKomat commented 4 years ago

@lucidrains @gaceladri I'd like to note that the collab head was not evaluated on other important datasets such as WebText (which is both more diverse and longer than the ones tested), so I'm not sure whether it works on WebText (or word-level WikiText-103) or not. As stated in this tweet of mine, a model with a smaller budget on self-attention (h=2) performs on par with the baseline on some datasets, while it does not on WebText. I think it is always important to test any model on either WikiText-103 or WebText, as other datasets may not use the self-attention module as much.

gaceladri commented 4 years ago

@AranKomat Thanks! What I am looking for is a good trade-off between accuracy (ppl) and efficiency, since what I care about most is scalability. Thanks for the point. My idea was to train my model with h=4, purely on intuition, nothing theoretical.

@lucidrains I think I must have some bug. At sequence length 128, the memory consumption of the einsum attention coming from MobileBERT is the same as the linear_trans that I implemented... I know that at length 128 the difference should be negligible, but this first try is suspicious. Tomorrow I will check what is happening... I am also thinking about throwing everything away and starting from scratch in PyTorch... Tomorrow is going to be an awesome day... :dancers: :smile:

Edit: Aran, if checking this is on your roadmap, please let me know. Thanks!

AranKomat commented 4 years ago

By "this," do you mean the scalability of h=4 in the context of trade-off between ppl and computes? From my experience, on the datasets like Webtext, fixing h wasn't worth the saved computes, since the small h becomes more and more of a bottleneck to the performance if you increase d_model, d_ff or any other hyperparameter budget.

gaceladri commented 4 years ago

@AranKomat I expressed myself badly. I meant that I am not going to scale to big compute... What I want is the best trade-off on a GTX 1070 Ti. So I am not going to scale up much, but I want the best performance, maybe sacrificing some ppl but gaining "scalability" (ridiculous, I know) at my budget size. Does that make sense? So maybe I don't need to go bigger than h=4 with my number of parameters, I think.

AranKomat commented 4 years ago

Makes sense now :)

gaceladri commented 4 years ago

@lucidrains I am struggling to understand the differences between your masking implementation and the masking mechanism in the normal attention in the transformers library.

The confusing point is that when I try to use your masking mechanism with the tuple, I get a slicing error here. And when I try to take the mask coming from the native transformers library, I get an error with the ~ inverse telling me that, in PyTorch, this only works with booleans and integers, so I assume I am doing something wrong.

Could you enlighten me a little or give me some insight into what I might be doing wrong? The other point I am a bit confused about is that you apply the masking just to the "K", while they apply it to the whole score; I assume this difference comes from the particular way the attention is computed.

Again, could you enlighten me a little on how to integrate your linear attention into "transformers"?

Thanks a lot!

lucidrains commented 4 years ago

@gaceladri could you show me your code where you get the slicing error?

The inverse error is due to the fact that in the transformers library, they use floating-point 0s and 1s to denote masking, while I use booleans. If you simply take their mask and call .bool() on it, it should work.

Yup, and because linear attention does the dot product between keys and values along the sequence dimension, we don't need to do any masking for the queries.
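Concretely, the conversion being described, assuming the usual transformers-style attention_mask of 1s and 0s and the linear_attn function quoted earlier:

import torch

# attention_mask as produced by a transformers tokenizer: 1 = real token, 0 = padding
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])

b, h, n, d = 2, 4, 5, 16
q, k, v = (torch.randn(b, h, n, d) for _ in range(3))

kv_mask = attention_mask.bool()                # 0/1 mask -> booleans, so ~kv_mask works
out = linear_attn(q, k, v, kv_mask = kv_mask)  # shape (2, 4, 5, 16)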

gaceladri commented 4 years ago

@lucidrains Thanks a lot for your answer! Your time is very valuable. I have fixed it, but I have not tested it yet, hence my delay in answering.

Do you know the logic behind doing the masking this way? Why -1000.0 instead of simple 0s and 1s like "always"?

At the moment, my code is just the transformers MobileBERT with your linear attention and the RNN-style attention from its authors. I have added a dense linear layer built on einsum to optimize performance; it worked well in TF, but I have not tested it in PyTorch.

NOT TESTED yet, so surely there is some bug:

import torch
import torch.nn as nn

class FastLinear3d(nn.Module):
    def __init__(self,
                 input_tensor,
                 num_attention_heads,
                 size_per_head):
        super(FastLinear3d, self).__init__()

        _, _, last_dim = input_tensor.shape
        # nn.Parameter is a special kind of Tensor that will get
        # automatically registered as a Module parameter once it's assigned
        # as an attribute. Parameters and buffers need to be registered, or
        # they won't appear in .parameters() (doesn't apply to buffers), and
        # won't be converted when e.g. .cuda() is called. You can use
        # .register_buffer() to register buffers.
        # nn.Parameters require gradients by default.
        self.weight = nn.Parameter(torch.Tensor(last_dim, num_attention_heads, size_per_head))
        self.bias = nn.Parameter(torch.Tensor(num_attention_heads, size_per_head))

        # Not a very smart way to initialize weights
        nn.init.xavier_uniform_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x, activation=None):
        # Project (batch, seq, hidden) -> (batch, seq, heads, size_per_head) in a single einsum
        ret = torch.einsum("abc,cde->abde", x, self.weight)
        ret = ret + self.bias
        if activation is not None:
            return activation(ret)
        else:
            return ret

    def extra_repr(self):
        # (Optional) Set the extra information about this module. You can test
        # it by printing an object of this class.
        return 'weight_shape={}, bias_shape={}'.format(
            tuple(self.weight.shape), tuple(self.bias.shape))
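For reference, the intended usage and shapes (illustrative sizes):

import torch

x = torch.randn(2, 128, 512)                             # (batch, seq, hidden)
layer = FastLinear3d(x, num_attention_heads = 8, size_per_head = 64)
print(layer(x).shape)                                    # torch.Size([2, 128, 8, 64])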

I will continue implementing the adaptive softmax (does it make sense to you to add it to the architecture, or not, due to some collision with something else?). I will also keep adding dynamic evaluation if it does not conflict with the current architecture. Since I am working on top of MobileBERT, which does distillation, I think I have to be cautious with the optimizer.

Would you like to prepare something for the SustaiNLP 2020 conference?