MAGICS-LAB / DNABERT_2

[ICLR 2024] DNABERT-2: Efficient Foundation Model and Benchmark for Multi-Species Genome

Wrong tokens at the end of the sequence #63

Open pabloacera opened 5 months ago

pabloacera commented 5 months ago

Hi,

Thanks for making the model available. I have been playing with the model and noticed that, when predicting tokens for a DNA sequence, the last token is usually not the one in the original sequence: the predicted sequence ends up with some extra nucleotides at the end.

Am I missing something? Is this the expected behavior? Is there an expected input length that avoids it?
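
For reference, the tokenizer brackets every input with special tokens, so the last input position is always [SEP] (id 2 in the dumps below) rather than a nucleotide k-mer. A quick check (the checkpoint name here is simply the public DNABERT-2 release, adjust if you use different weights):

from transformers import AutoTokenizer

# Assumed checkpoint: the public DNABERT-2 weights on the Hugging Face Hub.
tokenizer = AutoTokenizer.from_pretrained("zhihan1996/DNABERT-2-117M", trust_remote_code=True)

ids = tokenizer("TGGAAAGTTGGG", return_tensors="pt")["input_ids"]
print(ids)                                                # tensor([[1, ..., 2]])
print(tokenizer.convert_ids_to_tokens(ids[0].tolist()))   # ['[CLS]', ..., '[SEP]']

The full snippet I am running is below.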


import torch
from torch.nn.functional import softmax
from transformers import AutoTokenizer, AutoModelForMaskedLM

# Assumed setup (not shown in the original snippet): the public DNABERT-2
# checkpoint loaded with its masked-LM head.
tokenizer = AutoTokenizer.from_pretrained("zhihan1996/DNABERT-2-117M", trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained("zhihan1996/DNABERT-2-117M", trust_remote_code=True)

sequences = [
    'TGGAAAGTTGGGGACAAATGTTCTGCCATTTGGTCAGAAGACGGTTGCATTTACCCAGCTACCATTGCTTCAATTGATTTTAAGAGAGAAACCTGTGTTGTGGTTTACACTGGATATGGAAATAGAGAGGAGCAAAATCTGTCCGATCTACTTTCCCCAATCTGTGAAGTAGCTAATAATATAGAACAAAATGCTCAAGAG',
    'ATAATTCCCCCACCACCTCCCATATGTCCAGATTCTCTTGATGATGCTGATGCTTTGGGAAGTATGTTAATTTCATGGTACATGAGTGGCTATCATACTGGCTATTATATG',
    'GACAAGGCCGTGGCGGAGCCTGTCAGCCGCCTGCTGGAGAGCACGCTCAGGAGCAGCCACCTGCCCAGCAGGGTTGGAGCCCTGCACGGCGTCCTCTATGTGCTGGAGTGCGACCTGCTGGACGACACTGCCAAGCAGCTCATCCCGGTCATCAGCGACTATCTCCTCTCCAACCTGAAAGGGATCGCCCA',
]

for dna in sequences:
    dna = dna[:128]  # truncate to the first 128 nucleotides

    inputs = tokenizer(dna, return_tensors='pt')["input_ids"]
    outputs = model(inputs)

    logits = outputs.logits

    # Apply softmax to convert logits to probabilities
    probabilities = softmax(logits, dim=-1)

    # Choose the most likely token for each position
    predicted_token_ids = torch.argmax(probabilities, dim=-1)

    print('original tokens', inputs)
    print('predicted tokens', predicted_token_ids)
    print()

    # Convert these token ids back to nucleotides (dropping the first, [CLS], position)
    predicted_sequences = [tokenizer.decode(token_ids) for token_ids in predicted_token_ids[:, 1:]]
    original = [tokenizer.decode(token_ids) for token_ids in inputs]

    print('Original', dna)
    print('Predicted', ' '.join(predicted_sequences).replace(' ', ''))
    print()

original tokens tensor([[   1,   11,   45,  316, 1823,   48,  776,   86,   67,  330,  583, 1867,
           95,  105,  173,   60,  162, 3713, 2030,   13,  306,  922,   80,  438,
           70,  609,   50,    2]])
predicted tokens tensor([[ 371,   11,   45,  316, 1823,   48,  776,   86,   67,  330,  583, 1867,
           95,  105,  173,   60,  162, 3713, 2030,   13,  306,  922,   80,  438,
           70,  609,   50,  198]])

Original TGGAAAGTTGGGGACAAATGTTCTGCCATTTGGTCAGAAGACGGTTGCATTTACCCAGCTACCATTGCTTCAATTGATTTTAAGAGAGAAACCTGTGTTGTGGTTTACACTGGATATGGAAATAGAGA
Predicted TGGAAAGTTGGGGACAAATGTTCTGCCATTTGGTCAGAAGACGGTTGCATTTACCCAGCTACCATTGCTTCAATTGATTTTAAGAGAGAAACCTGTGTTGTGGTTTACACTGGATATGGAAATAGAGACAGTT

original tokens tensor([[   1,    5,  109,  906,   34,  209,  902,   16,  410,  149, 2659,  590,
          157,   57,   35, 2368,  224,   35,  246,   30,  105,   22,  236, 2463,
           70,    2]])
predicted tokens tensor([[   5,    5,  109,  906,   34,  209,  902,   16,  410,  149, 2659,  590,
          157,   57,   35, 2368,  224,   35,  246,   30,  105,   22,  236, 2463,
           70,   82]])

Original ATAATTCCCCCACCACCTCCCATATGTCCAGATTCTCTTGATGATGCTGATGCTTTGGGAAGTATGTTAATTTCATGGTACATGAGTGGCTATCATACTGGCTATTATATG
Predicted ATAATTCCCCCACCACCTCCCATATGTCCAGATTCTCTTGATGATGCTGATGCTTTGGGAAGTATGTTAATTTCATGGTACATGAGTGGCTATCATACTGGCTATTATATGTCCA

original tokens tensor([[   1,  225,  136,   30,  708,  192, 1066,  192, 1717,   32,  118,  591,
         2310,   74,   95,  253,  793,   36,  335,   72,  578,   88, 2621,  215,
           93,   74,  438,   93,   12,    6,    2]])
predicted tokens tensor([[  13,  225,  136,   30,  708,  192, 1066,  192, 1717,   32,  118,  591,
         2310,   74,   95,  253,  793,   36,  335,   72,  578,   88, 2621,  215,
           93,   74,  438,   93,   12,    6,   92]])

Original GACAAGGCCGTGGCGGAGCCTGTCAGCCGCCTGCTGGAGAGCACGCTCAGGAGCAGCCACCTGCCCAGCAGGGTTGGAGCCCTGCACGGCGTCCTCTATGTGCTGGAGTGCGACCTGCTGGACGACAC
Predicted GACAAGGCCGTGGCGGAGCCTGTCAGCCGCCTGCTGGAGAGCACGCTCAGGAGCAGCCACCTGCCCAGCAGGGTTGGAGCCCTGCACGGCGTCCTCTATGTGCTGGAGTGCGACCTGCTGGACGACACGGGG
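
As a follow-up check, this is a small sketch I tried (the masking via tokenizer.all_special_ids is my own addition, not part of the reproduction above): it decodes the predictions only at positions whose input id is a real k-mer token, to see whether the trailing nucleotides come from the prediction made at the [CLS]/[SEP] slots. It reuses the variables from the last loop iteration.

# Sketch: decode predictions only where the *input* holds a real k-mer token,
# i.e. skip the [CLS]/[SEP] positions instead of just the first column.
# Assumes `tokenizer`, `inputs` and `predicted_token_ids` from the loop above.
special_ids = set(tokenizer.all_special_ids)
keep = [i for i, tok_id in enumerate(inputs[0].tolist()) if tok_id not in special_ids]

predicted_wo_specials = tokenizer.decode(predicted_token_ids[0, keep]).replace(' ', '')
print('Predicted (special positions dropped)', predicted_wo_specials)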