instadeepai / nucleotide-transformer

🧬 Nucleotide Transformer: Building and Evaluating Robust Foundation Models for Human Genomics
https://www.biorxiv.org/content/10.1101/2023.01.11.523679v2

error while executing the Google Colab inference_nt notebook #66

Closed amitpande74 closed 3 months ago

amitpande74 commented 4 months ago

I am trying to tailor the script for the gene of my interest, but I keep getting this error:

ValueError: Input length must be divisible by the 2 to the power of number of poolign layers.

I read your article as well, but nowhere could I find the sequence length accepted by your model. Could you kindly help?

dallatt commented 3 months ago

Hello,

Because of the convolutions in the model, the input sequences are expected to respect this criterion. A given sequence needs to have a length that is a multiple of 6 in order to be correctly tokenized by the 6-mer tokenizer (you can find more information about this in the README.md), with each token corresponding to 6 nucleotides. The tokenized sequence length must then be divisible by 4 because of the convolutional layers.

All in all, the sequence length input to the model needs to be divisible by 24 (6*4).
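For reference, here is a minimal sketch of this constraint (the helper below is only for illustration, it is not part of the repository):

def prepare_sequence(sequence: str, num_pooling_layers: int = 2) -> str:
    """Pad a DNA sequence with 'N' so its length is a multiple of 6 * 2**num_pooling_layers."""
    divisor = 6 * 2 ** num_pooling_layers  # 6 nucleotides per token, then divisible by 2**2 = 4
    remainder = len(sequence) % divisor
    if remainder:
        sequence += "N" * (divisor - remainder)
    return sequence

Note also that the tokenizer prepends a CLS token to the tokenized sequence, so the divisible-by-4 criterion applies to the number of DNA tokens, not to the full tokenized length including that extra token.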

I hope this helps,
Hugo

amitpande74 commented 2 months ago

Hello Hugo,

This is what I wrote

# -*- coding: utf-8 -*-
"""inference_segment_nt.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/#fileId=https%3A//huggingface.co/InstaDeepAI/segment_nt/blob/main/inference_segment_nt.ipynb

# Inference with Segment-NT models

## Installation and imports
"""

!pip install biopython

from Bio import SeqIO
import gzip
import numpy as np
import transformers
from transformers import AutoTokenizer, AutoModel
import torch
import seaborn as sns
from typing import List
import matplotlib.pyplot as plt

"""## Download the model
The following cell allows you to download the config and the model of one of the Segment-NT models.
"""

# Load the tokenizer and the model directly from the Hugging Face Hub

tokenizer = AutoTokenizer.from_pretrained("InstaDeepAI/segment_nt", trust_remote_code=True)
model = AutoModel.from_pretrained("InstaDeepAI/segment_nt", trust_remote_code=True)

"""# Define function that plots the probabilities"""

# seaborn settings
sns.set_style("whitegrid")
sns.set_context(
    "notebook",
    font_scale=1,
    rc={
        "font.size": 14,
        "axes.titlesize": 18,
        "axes.labelsize": 18,
        "xtick.labelsize": 16,
        "ytick.labelsize": 16,
        "legend.fontsize": 16,
        }
)

plt.rcParams['xtick.bottom'] = True
plt.rcParams['ytick.left'] = True

# set colors
colors = sns.color_palette("Set2").as_hex()
colors2 = sns.color_palette("husl").as_hex()

# Rearrange order of the features to match Fig.3 from the paper
features_rearranged = [
 'protein_coding_gene',
 'lncRNA',
 '5UTR',
 '3UTR',
 'exon',
 'intron',
 'splice_donor',
 'splice_acceptor',
 'promoter_Tissue_specific',
 'promoter_Tissue_invariant',
 'enhancer_Tissue_specific',
 'enhancer_Tissue_invariant',
 'CTCF-bound',
 'polyA_signal',
]

def plot_features(
    predicted_probabilities_all,
    seq_length: int,
    features: List[str],
    order_to_plot: List[str],
    fig_width=8,
):
    """
    Function to plot labels and predicted probabilities.

    Args:
        predicted_probabilities_all: Probabilities per genomic feature for each
            nucleotide in the DNA sequence.
        seq_length: DNA sequence length.
        features: Genomic features to plot.
        order_to_plot: Order in which to plot the genomic features. This needs to be
            specified in order to match the order presented in Fig.3 of the paper.
        fig_width: Width of the figure.
    """

    sc = 1.8
    n_panels = 7

    _, axes = plt.subplots(n_panels, 1, figsize=(fig_width * sc, (n_panels + 4) * sc))

    for n, feat in enumerate(order_to_plot):
        feat_id = features.index(feat)
        prob_dist = predicted_probabilities_all[:, feat_id]

        # Use the appropriate subplot
        ax = axes[n // 2]

        # The Set2 palette has 8 colors; fall back to the husl palette beyond that
        try:
            id_color = colors[feat_id]
        except IndexError:
            id_color = colors2[feat_id - 8]
        ax.plot(
            prob_dist,
            color=id_color,
            label=feat,
            linestyle="-",
            linewidth=1.5,
        )
        ax.set_xlim(0, seq_length)
        ax.grid(False)
        ax.spines['bottom'].set_color('black')
        ax.spines['top'].set_color('black')
        ax.spines['right'].set_color('black')
        ax.spines['left'].set_color('black')

    for a in range(n_panels):
        axes[a].set_ylim(0, 1.05)
        axes[a].set_ylabel("Prob.")
        axes[a].legend(loc="upper left", bbox_to_anchor=(1, 1), borderaxespad=0)
        if a != (n_panels-1):
            axes[a].tick_params(axis='x', which='both', bottom=True, top=False, labelbottom=False)

    # Set common x-axis label
    axes[-1].set_xlabel("Nucleotides")
    axes[n_panels-1].grid(False)
    axes[n_panels-1].tick_params(axis='y', which='both', left=True, right=False, labelleft=True, labelright=False)

    axes[0].set_title("Probabilities predicted over all genomic features", fontweight="bold")

    plt.show()

"""## Combine exon and intron sequences of ESRG gene"""

# Your sequences
exon_sequence = """GCTGACTCTCTTTTCGGACTCAGCCCGCCTGCACCCAGGTGAAATAAACAGCCTCGTTGCTCACACAAAGCCTGTTTGGTGGTCTCTTCACACGGACGCGCATGAAATTTGGTGCCGTGACTCGGATCGGGGGACCTCCCTTGGGAGATCAATCCCCTGTCCTCCTGCTCTTTGCTCCGTGAGAAAGATCCACCTACGACCTCAGGTCCTCAGACCAACCAGCCCAAGAAACATCTCACCAATTTCAAATCCGGTAAGCGGCCTCTTTTTACTCTGTTCTCCAACCTCCCTCACTATCCCTCAACCTCTTTCTCCTTTCAATCTTGGCGCCACACTTCAATCTCTCCCTTCTCTTAATTTCAATTCCTTTCATTCTCTGGTAGAGACAAAAGAGACATGTTTTATCCGTGAACCCAAAACTCCGGCGCCGGTCACGGACTGGGAAGGCAGTCTTCCCTTGGTGTTTAATCATTGCAGGGACGCCTCTCTGATTTCACGTTTCAGACCACGCAGGGATGCCTGCCTTGGTCCTTCACCCTTAGCGGCAAGTCCCGCTTTCCTGGGGCAGGGGCAAGTACCCCTCAACCCCTTCTCCTTCACCCTTAGCGGCAAGTCCCGCTTTTCTGGGGCAGGGGCAAGTACCCCTCAACCCCTTCTCCTTCACCCTTAGCAGCAAGTCCCGCTTTCCTAGGGGGCAAGAACCCCCCAATCGCTTATTTTCACGCCCCAACCTCTTATCTCTGTGCCCCAATCCCTTATTTCCACGCCCCAATCTCTTATCTCTGCGCCCCAATCCCTTATTTCCGCGCCCCAACCCTTTCTCTGCTTTTCTGGAGGGGAAGAAACCCCCACCCCTTCTCCGTGTCTCTACTCTTTTCTCTGGGCTTGCCTCCTTCACTATGGGCAAGCTTCCACCTTCCATTCCTTTCTTCTCCCTTAGCATGTATTCTTAAGAACTTAAAATCTCTTCAATTCTCACCTGACCTAAAATCTAAGCGTCTTATTTTCTTCTGCAATGCCACTTGACCCCAATACAAACTCAACAGTAGTTCCAAATAGCCAGAAAATGGCACTTTCAATTTTTCCACCCTACAAGATCTAAATAATTCTTGGCGTAAAATGGGCAAATGGTGTGAGGTGCCTGACGTCCAGGCATTCTTTTACACATCAGTCCCTTCCTAGTCTCTGTGCCCAGTGCAACTCGTCCCAAATCTTCCTTCTTTCCCTCCCGCCTGTCCCCTCAGTACCAACCCCAAGCGTCACTGAGTCTTTCTAATCTTCCTTTTCTACAGACCCATCTGACCTCTCCCTTCCTCCCCAGGCTGCTCCTTGCCAGGCCGAGCTAGGTCCCAATTCTTCCTCAGCCTCTGCTCCTCCACCCTATAATCTTTTTATCACCTCCCCTCCTCACACCTGCTCCGGCTTACAGTTTCATTCCGTGACTAGCCCTCCCCGACCTGCCCAGCAATTTATTCTTAAAAAGGTGGCTGGAGCTAAACGCATAGTCAAGGTTAATGCTCCTTTTTCTTTATCCCAAATCAGATAGTGTTTAGGCTCTTTTTCATCAAATATAAAAATCTAGCCCAGTTCATGGCTCGTTTGGCAGCAACCCTAAGACACTTTACAGCCCTAGCCCCTAAAAGGTCAAAAGGCCATCTTATTCTCAATATACATTTTATTACCCAATCTGCTCCCGACATTAAATAAAACTCCAAAAACTGGAATCTGGCCCTCAAACCCCACAACAGGACTTAATTAACCTCACCTTCAAGGTGTGAAATAACAGAAAAAAGTTGCAATTCCTTGCCTCCACTGTGAGACAAACCCCAGCCACATCTCCAGCACACAAGAACTTCCAAACGCCTGAACTGTAGCAGCCAGACGTTTCTCCAGAACCTCCTCCCCCAGGAACTTGCTACACATGCCGGAAATCTGGCCACTGGGCCAAGGAACGCCCGCAGCCCGGGATTCCTCCTAAGCCGCGTCCCATCTGTGTGGGACCCCACTGAAAATCGGACTGTTCAACTCACCTGGCAGCCACTCCCAGAGCTCCTGGAACTCTGGCCCAAGGTTCTCTGACTGACTCCTTCTTGGCTTACTGGCTGAAGACTGACGCTGCCTGATCGCCTCAGAAGCCCCGCAGACCATCATGGACGCCGAGCTTTAG"""
intron_sequence = """GTAACTCACAGTGGAGGGTAAGTCCGTCCCCTTCTTAATCAATACGGAGGCTACCCACTCCACATTACCTTCTTTTCAAGGGTCTGTTTCCCTTGCCTCCATAACTGTTGTGAGTATTGACAGCCAGGCTTCTAAACCTCTTAAAACTCCCCAACTCTGGTGCCAACTTAGACAATACTCTTTAAAGCACTCCTTTTTAGTTATCCCCACCTGCCCAGTTCCCTTATTAGGCTGAGACACTTTAACTAAATTGTCTGCTTCCCTGACTATTCCTGGACTACAGCTATATCTCATTGCCGCCCTTCTTCCCAATCAAAAGCCTCCTTTGCGTCCTCCTCTTGTATCCCCCCACCTTAACCCACAAGTATAAGATACGTCTACTCCCTCCTTGGTGACCGATCATGCACCCCTTACCATCTCATTAAAACCTAATCACCCTTACCCT"""

# Combine the exon and intron sequences into a single input sequence
combined_sequence = exon_sequence + intron_sequence

"""## Ensure sequence length is divisible by 24"""

def pad_sequence(sequence, divisor=24):
    padding_needed = (divisor - len(sequence) % divisor) % divisor
    if padding_needed > 0:
        padded_sequence = sequence + 'N' * padding_needed
    else:
        padded_sequence = sequence
    return padded_sequence

padded_sequence = pad_sequence(combined_sequence)

"""# Tokenize the DNA sequence"""

# Adjust the max_num_dna_tokens to ensure the length is valid
max_num_dna_tokens = len(padded_sequence) // 6  # 6-mer tokenization

# Tokenize and pad the sequence
tokens = tokenizer.batch_encode_plus(
    [padded_sequence], 
    return_tensors="pt", 
    padding="max_length", 
    max_length=max_num_dna_tokens
)["input_ids"]

# Ensure the tokenized length is divisible by 4
def ensure_divisible_by(tokens, divisor):
    padding_needed = (divisor - tokens.shape[1] % divisor) % divisor
    if padding_needed > 0:
        tokens = torch.nn.functional.pad(tokens, (0, padding_needed), value=tokenizer.pad_token_id)
    return tokens

tokens = ensure_divisible_by(tokens, 4)
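# Note: the tokenizer prepends a CLS token, so tokens.shape[1] includes it;
# if the divisible-by-4 criterion applies to the DNA tokens only, padding the
# total length (CLS included) to a multiple of 4 will not satisfy it.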

"""## Infer on the resulting batch"""

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

# Infer
tokens = tokens.to(device)
attention_mask = tokens != tokenizer.pad_token_id
with torch.no_grad():
    outs = model(
        tokens,
        attention_mask=attention_mask,
    )

# Obtain the logits over the genomic features
logits = outs["logits"]
# Transform them into probabilities
probabilities = np.asarray(torch.nn.functional.softmax(logits, dim=-1).cpu())[...,-1]

del outs
del tokens, attention_mask

"""## Plot the probabilities for 14 genomic features along this DNA sequence"""

plot_features(
    probabilities[0],
    probabilities.shape[-2],
    fig_width=20,
    features=model.config.features,
    order_to_plot=features_rearranged
)

Yet the error remains the same:

ValueError: Input length must be divisible by the 2 to the power of number of poolign layers.

Kindly look into it.