lbcb-sci / RiNALMo

RiboNucleic Acid (RNA) Language Model
https://sikic-lab.github.io/
Apache License 2.0

How to understand the meaning of RNA representation #7

Closed ylzdmm closed 1 month ago

ylzdmm commented 1 month ago

Hello, thank you for your reply. That issue is solved, but now I have another question. I modified my test.py file as follows:

import torch
from rinalmo.pretrained import get_pretrained_model

DEVICE = "cuda:0"

model, alphabet = get_pretrained_model(model_name="rinalmo_giga_pretrained")
model.eval()
model = model.to(device=DEVICE)

seqs = ["CCCGGU"]
tokens = torch.tensor(alphabet.batch_tokenize(seqs), dtype=torch.int64, device=DEVICE)
with torch.no_grad(), torch.cuda.amp.autocast():
    outputs = model(tokens)

for rep in outputs["representation"]:
    print(rep.shape)

output: torch.Size([8, 1280])

if seqs = ["ACUUUGGCCA"] output: torch.Size([12, 1280])

if seqs = ["ACUUUGGCCA","CCCGGU"] output: torch.Size([12, 1280]) torch.Size([12, 1280])

It seems that the output dimension is determined by the maximum sequence length in the input batch (every sequence begins with a [CLS] token and ends with an [EOS] token), and the excess positions are padded according to your rules. Can I assume that each 1280-dimensional vector represents a base? However, according to your paper, an RNA sequence is tokenized and turned into a 1280-dimensional vector using a learned input embedding model. How should I interpret this output, and how can I fix the sequence dimension to facilitate my downstream tasks, such as predicting interactions between RNAs?
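One common (model-agnostic) way to get a fixed-size vector per sequence, regardless of length, is to mask out the padded positions and mean-pool over the real tokens. The sketch below uses random tensors as stand-ins for the model output; the shapes (batch of 2, padded length 12, hidden size 1280) are taken from the examples above, and the per-sequence lengths are assumed values, not RiNALMo API calls.

```python
import torch

# Stand-in for a padded batch of representations: 2 sequences,
# padded to length 12, hidden size 1280 (shapes from the thread).
reps = torch.randn(2, 12, 1280)

# Number of real tokens per sequence, including [CLS] and [EOS]
# ("ACUUUGGCCA" -> 12, "CCCGGU" -> 8). Hypothetical values here.
lengths = torch.tensor([12, 8])

# Build a mask that keeps only real (non-padding) positions.
positions = torch.arange(reps.size(1)).unsqueeze(0)  # shape (1, 12)
mask = (positions < lengths.unsqueeze(1)).float()    # shape (2, 12)

# Mean-pool over the real positions -> one fixed-size vector per sequence.
pooled = (reps * mask.unsqueeze(-1)).sum(dim=1) / lengths.unsqueeze(1)
print(pooled.shape)  # torch.Size([2, 1280])
```

With this, every sequence maps to a single 1280-dimensional vector, which is convenient as input to a downstream interaction-prediction head.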

RJPenic commented 1 month ago

For an RNA sequence of N nucleotides, you will get N + 2 output representations. The first output (index 0) represents the [CLS] token and is often used as the sequence representation. The remaining outputs are nucleotide representations ordered by position (e.g. index 5 represents the fifth nucleotide), with the [EOS] token at index N + 1.
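The layout described above can be illustrated with a dummy tensor in place of the model output (the (N + 2, 1280) shape is taken from the thread; the variable names are illustrative):

```python
import torch

N = 6                             # e.g. "CCCGGU" has 6 nucleotides
rep = torch.randn(N + 2, 1280)    # stand-in for one entry of outputs["representation"]

cls_embedding = rep[0]            # [CLS] vector, usable as the sequence representation
nucleotide_embeddings = rep[1:N + 1]  # one 1280-dim vector per nucleotide
first_nt = rep[1]                 # representation of the first nucleotide
eos_embedding = rep[N + 1]        # [EOS] vector

print(cls_embedding.shape)        # torch.Size([1280])
print(nucleotide_embeddings.shape)  # torch.Size([6, 1280])
```

So for per-nucleotide tasks you would slice out indices 1..N, and for sequence-level tasks you would take index 0 (or pool the nucleotide vectors).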