lbcb-sci / RiNALMo

RiboNucleic Acid (RNA) Language Model
https://sikic-lab.github.io/
Apache License 2.0

No support for flash-attn transformer model #12

Closed: pkubgmlixs closed this issue 2 days ago

pkubgmlixs commented 4 days ago

Hi there, have you ever pre-trained the model built from scratch without flash-attn, as provided in your code? In my case, my CUDA version is 11.4 and does not support flash-attn, so the model weights in "rinalmo_giga_pretrained.pt", which were trained in flash-attn mode, are not compatible with the model built from scratch. I am trying to modify the plain model structure to match the flash-attn model so that I can load the flash-attn weights onto my plain model, but this is nontrivial for me and I cannot be sure it will work in the end. Any suggestions?
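
For context, a quick way to see exactly where the two definitions diverge is to diff the checkpoint's state-dict keys against the plain model's keys. This is a minimal sketch; the checkpoint layout and the `build_plain_rinalmo` constructor are assumptions standing in for however you build the non-flash model locally.

```python
import torch

# Paths and constructors here are assumptions; adjust to your local setup.
ckpt = torch.load("rinalmo_giga_pretrained.pt", map_location="cpu")
state_dict = ckpt["model"] if isinstance(ckpt, dict) and "model" in ckpt else ckpt
ckpt_keys = set(state_dict.keys())

plain_model = build_plain_rinalmo()  # hypothetical constructor for the non-flash model
plain_keys = set(plain_model.state_dict().keys())

# Keys present in only one of the two sets point at the modules whose
# parameter layout differs between the flash-attn and plain definitions.
print("only in checkpoint:", sorted(ckpt_keys - plain_keys))
print("only in plain model:", sorted(plain_keys - ckpt_keys))
```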

pkubgmlixs commented 4 days ago

It finally works after slightly modifying the plain model structure (the MultiHeadSelfAttention and RotaryPositionEmbedding classes). The remaining parts seem to be identical between the plain and flash-attn models, so I expect the modified plain model loaded with the flash-attn weights to work fine. Any comments? Thanks!
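
In case it helps anyone else, the core of the change was a state-dict remapping. Below is a rough sketch of the idea, assuming the flash-attn checkpoint stores a fused QKV projection while the plain attention module uses separate query/key/value projections; the key names (Wqkv, q_proj, ...) are illustrative only and should be checked against the actual checkpoint keys.

```python
def split_fused_qkv(state_dict, prefix):
    """Split a fused QKV projection into separate q/k/v entries.

    Key names below are illustrative, not the actual RiNALMo names.
    """
    remapped = dict(state_dict)
    for suffix in ("weight", "bias"):
        fused_key = f"{prefix}.Wqkv.{suffix}"  # assumed fused-projection key
        if fused_key not in remapped:
            continue
        fused = remapped.pop(fused_key)
        q, k, v = fused.chunk(3, dim=0)        # assumes rows are stacked as [Q; K; V]
        remapped[f"{prefix}.q_proj.{suffix}"] = q
        remapped[f"{prefix}.k_proj.{suffix}"] = k
        remapped[f"{prefix}.v_proj.{suffix}"] = v
    return remapped
```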

retiro commented 3 days ago

Hi,

we used FlashAttention2 because it enormously speeds up RiNALMo pre-training, so unfortunately we have neither trained nor tested the model without FlashAttention2.

However, when we wanted to extract the attention maps, we used the _dot_product_attention function (https://github.com/lbcb-sci/RiNALMo/blob/da6572dc22617f95f7c657e79b2ea126dc7cf979/rinalmo/model/attention.py#L195), so you could adjust that code for your purpose, along with the RotaryPositionEmbedding class (https://github.com/lbcb-sci/RiNALMo/blob/da6572dc22617f95f7c657e79b2ea126dc7cf979/rinalmo/model/rope.py#L16). It should be possible to adapt the code to use the published weights without FlashAttention2, and we might do it in the future, but currently it is not high on our priority list. I hope this helps.
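
For reference, a non-flash fallback built around those two pieces could look roughly like the sketch below. It uses PyTorch's built-in scaled dot-product attention; apply_rope stands in for the repository's rotary embedding, and all names and shapes here are illustrative rather than the actual RiNALMo interface.

```python
import torch.nn.functional as F

def plain_self_attention(q, k, v, apply_rope, attn_mask=None):
    """Non-flash attention fallback (illustrative, not the repo's API).

    q, k, v: (batch, num_heads, seq_len, head_dim).
    apply_rope: stand-in for the repository's RotaryPositionEmbedding.
    """
    # Rotate queries and keys before the dot product, as rotary embeddings require.
    q, k = apply_rope(q), apply_rope(k)
    # Plain scaled dot-product attention: same math as FlashAttention2,
    # but the full attention matrix is materialized in memory.
    return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
```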

pkubgmlixs commented 2 days ago

Thank you for the suggestions!