lucidrains / x-transformers

A concise but complete full-attention transformer with a set of promising experimental features from various papers

NTK-aware Scaled RoPE #171

Closed · Jingyu-Fan closed this 1 year ago

Jingyu-Fan commented 1 year ago

Hi @lucidrains, I'm wondering whether you'd be interested in incorporating scaled RoPE, which lets models handle an extended context size without any fine-tuning.

The first related discussion was here, where bloc97 found that by changing just 3 lines of the RoPE code, the LLaMA 7B model can handle an 8k+ context without any fine-tuning or performance loss! Related code can be found on HF and in a Colab.
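
For reference, the core of the NTK-aware trick is a small change to how the rotary frequencies are built: instead of interpolating positions, the rotary base is rescaled so the high-frequency dimensions are barely touched while the low-frequency ones are stretched. A minimal sketch of that idea (the function name and the `alpha` parameter here are mine for illustration, not taken from the linked code):

```python
import torch

def ntk_scaled_rope_frequencies(dim, base=10000, alpha=8, device=None):
    # NTK-aware scaling: rescale the base so the usable context grows by roughly `alpha`
    base = base * alpha ** (dim / (dim - 2))
    inv_freq = 1. / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))
    return inv_freq  # plugged into the usual sin / cos rotation angles
```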

Furthermore, the author of the RoPE paper (Jianlin Su) became interested and posted two blogs on this (blog1 and blog2, both in Chinese), where he found performance can be further improved by: 1) log n scaling of the attention (see here, equations 4 and 5); and 2) a mixed-radix view of the scaling (see blog2, equation 9).

Curious about your thoughts?

lucidrains commented 1 year ago

@Jingyu-Fan hey Jingyu, i believe i have that too as rotary_interpolation_factor https://github.com/lucidrains/x-transformers/blob/main/x_transformers/x_transformers.py#L901 - you would simply set this to a value greater than 1 on a trained model and fine-tune on a longer context
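
roughly, all that flag does is squeeze the positions fed into the rotary embedding back into the trained range. a rough sketch of the idea (illustrative only, not the actual code at that line):

```python
import torch

def interpolated_positions(seq_len, interpolation_factor=2., device=None):
    # position interpolation: a factor of 2. maps positions 0..4095 onto 0..2047.5,
    # so an 8k sequence looks like a 4k one to the rotary embedding
    assert interpolation_factor >= 1.
    return torch.arange(seq_len, device=device).float() / interpolation_factor
```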

lucidrains commented 1 year ago

so i have been following Jianlin Su's blog off and on. the log n scaling was notably used in Omegafold (an LLM approach to protein folding). however, i gave it a try in some contrived experiments and did not see much of an improvement. willing to change my mind if you have experiments to share (and i can replicate)

1. i haven't seen this yet and will read it later today! can you tell me the gist of it?

lucidrains commented 1 year ago

for point 1, i want to make sure this is absolutely necessary before adding it, since calculating the sequence length for AR models (as well as variable-length bidirectional models) introduces complexity

lucidrains commented 1 year ago

@Jingyu-Fan another way to handle 1. is that i could check in all the necessary code in a separate branch with a draft pull request, and once a few people share positive experimental results, i can merge it. writing code isn't that hard for me, but running experiments is time consuming and usually not a good use of my time, given how many turn out negative

lucidrains commented 1 year ago

@Jingyu-Fan oh i misread the original post. w/o fine tuning is news to me, let me take a closer look today

lucidrains commented 1 year ago

ok, i really like it, adding it now

lucidrains commented 1 year ago

ok done

i'll read what improvements Jianlin Su came up with in his blog post later today! have you tried it?

lucidrains commented 1 year ago

also, if you know of a better way to cite reddit posts, let me know lol

Jingyu-Fan commented 1 year ago

Wow, it's really impressive how responsive you are :) I will try it out and let you know how the experiments go. Also trying to find a good way to cite a Reddit post lol

Jingyu-Fan commented 1 year ago

@lucidrains very preliminary experiments tell me it is actually useful in my setting! Without fine-tuning, the model works better than before at extended context sizes. However, the effective context size seems shorter than expected. I'm wondering whether that is because only a portion of the dims (32/64) use the rotary embedding by default. Will test to see.

lucidrains commented 1 year ago

@Jingyu-Fan hey, glad to hear! also got your email; i'm interested in the same topics as you are

if you look at the charts in the reddit post, there seems to be a limit to the extrapolation. i also went through Jianlin Su's blog posts, but the improvements seem very incremental

i think your best bet is to carry out some fine tuning at this point

lucidrains commented 1 year ago

@Jingyu-Fan so this isn't followed up on much in the literature, but some of us in the open source community have had a lot of luck using dynamic positional bias for length extrapolation. however, in another project, even that hit its limit at around 4x the trained sequence length

if you are training models from scratch, you should give it a try
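
for reference, the gist of dynamic positional bias is a tiny MLP that maps each relative distance to a per-head bias added to the attention logits, so it can produce sensible values for distances it never saw in training. a minimal sketch (simplified compared to what is in the repo):

```python
import torch
from torch import nn

class DynamicPositionBias(nn.Module):
    # small MLP from relative distance -> per-head bias added to the attention logits
    def __init__(self, dim, heads):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(1, dim),
            nn.SiLU(),
            nn.Linear(dim, heads)
        )

    def forward(self, i, j):
        device = next(self.parameters()).device
        rel_dist = torch.arange(j, device=device)[None, :] - torch.arange(i, device=device)[:, None]
        bias = self.mlp(rel_dist.float().unsqueeze(-1))  # (i, j, heads)
        return bias.permute(2, 0, 1)                     # (heads, i, j), added to q @ k^T
```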

Jingyu-Fan commented 1 year ago

@lucidrains Thanks for the suggestions!! I tried dynamic positional bias before and did indeed see some improvement. Since I expect over 10x the training length, incorporating a longer context into model training seems more promising...

To continue the discussion, Jianlin Su just posted a new blog on this today (link). By setting a window size w and clipping any relative distance in RoPE larger than w down to w, he was able to almost eliminate the performance loss in length extrapolation. However, that seems to be incompatible with flash attention due to the operations on the attention matrix... His code is on GitHub (link). I will take a closer look later.
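
As I understand the blog, the clipping means the attention logits have to be assembled from two passes and merged per pair, which is why flash attention does not apply directly. A naive sketch of the idea (here `apply_rotary` is a hypothetical helper that rotates a tensor by the given positions; the real implementation is in his repo):

```python
import torch

def rerope_logits(q, k, apply_rotary, w=1024):
    # naive causal ReRoPE: pairs closer than w apart use ordinary RoPE, pairs further
    # apart are treated as if the relative distance were exactly w
    seq_len = q.shape[-2]
    pos = torch.arange(seq_len, device=q.device)

    # pass 1: standard RoPE, relative position i - j
    logits_rope = apply_rotary(q, pos) @ apply_rotary(k, pos).transpose(-1, -2)

    # pass 2: query rotated by the constant w, key left at position 0,
    # so every pair sees a relative position of exactly w
    logits_clipped = apply_rotary(q, torch.full_like(pos, w)) @ k.transpose(-1, -2)

    rel_dist = pos[:, None] - pos[None, :]
    return torch.where(rel_dist > w, logits_clipped, logits_rope)
```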

lucidrains commented 1 year ago

@Jingyu-Fan yes indeed the sequences you are working on are very long

i would recommend taking a look at this new work out of stanford: https://arxiv.org/abs/2306.15794. tl;dr: there is room to capture long-distance interactions with subquadratic methods, even if they are less expressive. supplemented with some sparse flash attention, it may go a long way

lucidrains commented 1 year ago

> @lucidrains Thanks for the suggestions!! I tried dynamic positional bias before and did indeed see some improvement. Since I expect over 10x the training length, incorporating a longer context into model training seems more promising...
>
> To continue the discussion, Jianlin Su just posted a new blog on this today (link). By setting a window size w and clipping any relative distance in RoPE larger than w down to w, he was able to almost eliminate the performance loss in length extrapolation. However, that seems to be incompatible with flash attention due to the operations on the attention matrix... His code is on GitHub (link). I will take a closer look later.

yea, i can take a look at rerope later

but i can safely assure you that one probably cannot hope to extrapolate to unbounded lengths without some amount of fine tuning. certainly not to the lengths you are looking for, imo

Jingyu-Fan commented 1 year ago

yeah, I agree. thanks for the advice!

lucidrains commented 1 year ago

@Jingyu-Fan also, tell your wife "go blue" :smile:

Jingyu-Fan commented 1 year ago

Thx! We actually have a 'go blue' stuffed pillow in the car lol

bojone commented 1 year ago

> so i have been following Jianlin Su's blog off and on. the log n scaling was notably used in Omegafold (an LLM approach to protein folding). however, i gave it a try in some contrived experiments and did not see much of an improvement. willing to change my mind if you have experiments to share (and i can replicate)
>
> 1. i haven't seen this yet and will read it later today! can you tell me the gist of it?

Log n scaling will hardly change training behavior, but it does alter length extrapolation behavior; it has an impact on both NTK-RoPE and ReRoPE.

bojone commented 1 year ago

> but i can safely assure you that one probably cannot hope to extrapolate to unbounded lengths without some amount of fine tuning. certainly not to the lengths you are looking for, imo

Based on the current results, I agree with this. But out of an anthropomorphic belief, I think length extrapolation is a capability an LLM should have, so I have been actively exploring it.

lucidrains commented 1 year ago

hey Jianlin! yes, i did evaluate for length extrapolation, up to 4x the length. i will try again with all your new findings soon when i set aside some time

lucidrains commented 1 year ago

> but i can safely assure you that one probably cannot hope to extrapolate to unbounded lengths without some amount of fine tuning. certainly not to the lengths you are looking for, imo
>
> Based on the current results, I agree with this. But out of an anthropomorphic belief, I think length extrapolation is a capability an LLM should have, so I have been actively exploring it.

yes indeed - i just know the problem Jingyu faces is a bit different from language, with lengths at a different order of magnitude

Jingyu-Fan commented 1 year ago

> but i can safely assure you that one probably cannot hope to extrapolate to unbounded lengths without some amount of fine tuning. certainly not to the lengths you are looking for, imo
>
> Based on the current results, I agree with this. But out of an anthropomorphic belief, I think length extrapolation is a capability an LLM should have, so I have been actively exploring it.

@bojone It is amazing to see you show up here lol. I truly enjoy reading your blogs! I am very intrigued by your explorations, please keep them up! :)

Here, Phil was actually giving me suggestions based on the nature of the problem I am working on, which is not an LLM. But I hope your solutions can also generalize to other settings, and I will give them a shot.

bojone commented 1 year ago

@everyone:

I found some English corpora and computed the loss with llama2-13b; the results are quite encouraging. ReRoPE's performance at the training length (4k) has hardly decreased, and it has the ideal property of "longer context, lower loss".

Here are my results:

| Method | Context length | Loss |
| --- | --- | --- |
| RoPE (original llama2) | 4k | 1.4967308044433594 |
| RoPE (original llama2) | 8k | 8.861469268798828 |
| NTK-RoPE (not dynamic) | 4k | 1.608105926513672 |
| NTK-RoPE (not dynamic) | 8k | 1.5417398071289063 |
| NTK-RoPE (not dynamic) | 16k | 1.5162832641601562 |
| ReRoPE (w=1024) | 4k | 1.4995654296875 |
| ReRoPE (w=1024) | 8k | 1.42667236328125 |
| ReRoPE (w=1024) | 16k | 1.4001029095246424 |

lucidrains commented 1 year ago

i really need to check in my experimental code for length extrapolation, so i can quickly evaluate this

set aside tomorrow morning to give rerope a try

lucidrains commented 1 year ago

reopening as a reminder

lucidrains commented 1 year ago

@Jingyu-Fan is your model bidirectional or autoregressive? i'm assuming it is bidirectional (Encoder instead of Decoder)

Jingyu-Fan commented 1 year ago

> @Jingyu-Fan is your model bidirectional or autoregressive? i'm assuming it is bidirectional (Encoder instead of Decoder)

yeah, it is an encoder-only model.

bojone commented 1 year ago

> i really need to check in my experimental code for length extrapolation, so i can quickly evaluate this
>
> set aside tomorrow morning to give rerope a try

My test data and scripts have all been synced to https://github.com/bojone/rerope/. You're welcome to refer to them, and I look forward to your guidance.

lucidrains commented 1 year ago

@Jingyu-Fan i just realized i won't have access to my deep learning machine tomorrow

maybe i'll get rerope checked into a branch this afternoon and you can quickly test it out on your experimental setup, with the caveat that if you hit a negative result, maybe the technique is decoder specific

lucidrains commented 1 year ago

@bojone should rerope also work for bidirectional attention? (haven't read it yet)

bojone commented 1 year ago

> @bojone should rerope also work for bidirectional attention? (haven't read it yet)

I believe it should be applicable to a bidirectional encoder, but the current implementation needs modification (i.e., change the clipping to clip(relative_pos, -w, w), and have all tokens share the same log L)
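
Concretely, the only change to the relative positions would be something like the following (a sketch for illustration, not code from my repo):

```python
import torch

def bidirectional_clipped_rel_pos(seq_len, w=1024, device=None):
    # in an encoder the relative distance runs in both directions, so clip to [-w, w]
    pos = torch.arange(seq_len, device=device)
    rel_pos = pos[:, None] - pos[None, :]
    return rel_pos.clamp(min=-w, max=w)
```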

lucidrains commented 1 year ago

@bojone i see

does the technique rely on a scale that is a function of the sequence length? flash attention is not designed to take a different scale per row. Jingyu is dealing with sequences of potentially up to a million tokens, so i'm afraid flash attention is a necessity

do you still see the effects even with log(L) scaling ablated?

bojone commented 1 year ago

@lucidrains In a causal LM, log n scaling can be multiplied directly into the queries, which does not affect the use of flash attention.

Unfortunately, however, ReRoPE is currently not compatible with flash attention. But that is a limitation of ReRoPE itself and has nothing to do with log n scaling.
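
For clarity, the per-query form of log n scaling I mean is roughly the following (a sketch with an assumed training length, not the exact code):

```python
import math
import torch

def logn_scale_queries(q, train_len=4096):
    # scale the query at position n by log_{train_len}(n), clipped to at least 1;
    # since this is a per-row factor folded into q, flash attention is unaffected
    seq_len = q.shape[-2]
    pos = torch.arange(1, seq_len + 1, device=q.device).float()
    scale = (pos.log() / math.log(train_len)).clamp(min=1.).to(q.dtype)
    return q * scale[:, None]
```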

lucidrains commented 1 year ago

@bojone oh yes, you are right that logn scaling can be applied on queries :pray: forgot about that

regardless, i should check in a simple length extrapolation experiment script, as it has become a hot research topic

lucidrains commented 1 year ago

@bojone congratulations on rotary embeddings btw :smile:

what a success the technique has become

bojone commented 1 year ago

@lucidrains Thank you. Indeed, I did not expect that what began as just an interesting idea would become one of the mainstream position encodings. I am proud of it, but also puzzled: why does RoPE work better, and why can ReRoPE extrapolate successfully? I still have no real answers to these questions.

lucidrains commented 1 year ago

(screenshot: Screen Shot 2023-08-24 at 9 35 44 AM)

haha, this trick got used in today's open-sourced Code Llama, and got a nod from the Meta researchers!

lucidrains commented 1 year ago

somebody tell /u/bloc97!

lucidrains commented 1 year ago

@Jingyu-Fan btw, someone over at the enformer-pytorch repo is also interested in what you are working on (or at least headed down a similar path). maybe room to collaborate

edit: and another researcher but he's going for another efficient attention approach

lucidrains commented 1 year ago

i'm going to close this, as it is done