tkzeng / Pangolin

Pangolin is a deep-learning method for predicting splice site strengths.
GNU General Public License v3.0

Inconsistency of predictions depending on batch size #22

Open neverov-am opened 3 months ago

neverov-am commented 3 months ago

Dear Authors,

Thank you for the great tool.

I want to implement an option to predict scores with a batch size larger than 1. During my first tests I noticed that the predictions differ depending on the batch size. Could you check what might cause this behaviour of the model? Below I provide an example variant (chr12-110435045-G-A) whose score differs between a single-variant prediction and the provided batch of size 4: 0.5400000214576721 in the original version versus 0.5299999713897705 on the batch. I also provide my code to reproduce the issue. To keep the question compact, the example shows the prediction mismatch for just one of the models.
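One factor that may be worth ruling out first (an assumption on my part, not something I have verified): cuDNN can select different convolution algorithms for different input shapes, so batched and single-sample inference may legitimately disagree at float32 precision. PyTorch's standard determinism switches can be set before the models are loaded to test whether this explains the mismatch:

import torch

# assumes a CUDA backend; whether these settings remove the mismatch is unverified
torch.backends.cudnn.benchmark = False      # no per-shape autotuning of conv kernels
torch.backends.cudnn.deterministic = True   # restrict cuDNN to deterministic kernels

The reproduction script follows.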

import torch
import numpy as np
import pyfastx
from pkg_resources import resource_filename
from pangolin.model import *

###############################################################################################
test_variants = [
    'chr12-110435044-T-C',
    'chr12-110435044-T-G',
    'chr12-110435045-G-A',
    'chr12-110435045-G-C',
]

atol = 0.000001 # tolerance value to be used in np.allclose()
d = 50
reference_fasta_path = 'GRCh38.primary_assembly.genome.fa'
###############################################################################################
# the same as in the original version; IN_MAP rows encode N, A, C, G, T,
# with N (index 0) mapping to the all-zero row

IN_MAP = np.asarray([[0, 0, 0, 0],
                     [1, 0, 0, 0],
                     [0, 1, 0, 0],
                     [0, 0, 1, 0],
                     [0, 0, 0, 1]])

def one_hot_encode(seq, strand):
    seq = seq.upper().replace('A', '1').replace('C', '2')
    seq = seq.replace('G', '3').replace('T', '4').replace('N', '0')
    if strand == '+':
        seq = np.asarray(list(map(int, list(seq))))
    elif strand == '-':
        seq = np.asarray(list(map(int, list(seq[::-1]))))
        seq = (5 - seq) % 5  # Reverse complement
    return IN_MAP[seq.astype('int8')]
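
# sanity check of the encoding: minus-strand encoding should equal the
# plus-strand encoding of the reverse complement; rc('AACG') == 'CGTT'
assert np.array_equal(one_hot_encode('AACG', '-'), one_hot_encode('CGTT', '+'))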

models = []
for i in [0,2,4,6]:
    for j in range(1,4):
        model = Pangolin(L, W, AR)
        if torch.cuda.is_available():
            model.cuda()
            # pass the package name explicitly: outside the pangolin package,
            # __name__ is "__main__" and would not resolve the bundled models
            weights = torch.load(resource_filename("pangolin", "models/final.%s.%s.3.v2" % (j, i)))
        else:
            weights = torch.load(resource_filename("pangolin", "models/final.%s.%s.3.v2" % (j, i)),
                                 map_location=torch.device('cpu'))
        model.load_state_dict(weights)
        model.eval()
        models.append(model)
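
# 4 output groups x 3 replicates = 12 models; a trivial count check
# that catches a wrong models/ path early
assert len(models) == 12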

###############################################################################################
# process variants

def prepare_variant_for_batch(chrom, pos, ref, alt, fasta, d):

    # 5000+d bases of flank on each side of the ref allele
    # (pos is 1-based; the slice below is 0-based)
    seq = fasta[chrom][pos-5001-d:pos+len(ref)+4999+d].seq

    ref_seq = seq
    alt_seq = seq[:5000+d] + alt + seq[5000+d+len(ref):]

    return ref_seq, alt_seq

fasta = pyfastx.Fasta(reference_fasta_path)
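
# each sequence carries 5000+d flanking bases on both sides of the ref allele,
# so its length is 10000 + 2*d + len(ref); checked on the mismatching variant
ref_seq, alt_seq = prepare_variant_for_batch('chr12', 110435045, 'G', 'A', fasta, d)
assert len(ref_seq) == 10000 + 2*d + 1  # SNV: len(ref) == 1
assert len(alt_seq) == len(ref_seq)     # a substitution preserves the length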

batch_chroms = []
batch_positions = []
batch_refs = []
batch_alts = []

for test_variant in test_variants:
    chrom, pos, ref, alt = test_variant.split('-')
    pos = int(pos)

    ref_seq, alt_seq = prepare_variant_for_batch(chrom, pos, ref, alt, fasta, d)

    batch_chroms.append(chrom)
    batch_positions.append(pos)
    batch_refs.append(ref_seq)
    batch_alts.append(alt_seq)

model = models[0]
j = 3  # selects output channel [1,4,7,10][3] == 10 in the indexing below

strand = '-'

# predict on batch

encoded_refs = [] # store encoded reference sequences in a list
encoded_alts = [] # store encoded alternative sequences in a list

for i in range(len(batch_refs)):
    ref_seq = torch.from_numpy(one_hot_encode(batch_refs[i], strand).T).float()
    alt_seq = torch.from_numpy(one_hot_encode(batch_alts[i], strand).T).float()
    encoded_refs.append(ref_seq)
    encoded_alts.append(alt_seq)

batch_ref = torch.stack(encoded_refs) # create a tensor with multiple ref sequences
batch_alt = torch.stack(encoded_alts) # create a tensor with multiple alt sequences
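# batch_ref and batch_alt have shape (num_variants, 4, seq_len)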

if torch.cuda.is_available():
    batch_ref = batch_ref.to(torch.device("cuda"))
    batch_alt = batch_alt.to(torch.device("cuda"))

with torch.no_grad():
    # batched output is indexed [:, channel, :] instead of [0][channel, :]
    pred_ref = model(batch_ref)[:,[1,4,7,10][j],:].cpu().numpy()
    pred_alt = model(batch_alt)[:,[1,4,7,10][j],:].cpu().numpy()
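
# pred_ref and pred_alt have shape (num_variants, L_out): one score track per variant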

# predict single

i = 2  # index of chr12-110435045-G-A, the variant with the mismatching score

ref_seq = one_hot_encode(batch_refs[i], strand).T
ref_seq = torch.from_numpy(np.expand_dims(ref_seq, axis=0)).float()
alt_seq = one_hot_encode(batch_alts[i], strand).T
alt_seq = torch.from_numpy(np.expand_dims(alt_seq, axis=0)).float()
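# ref_seq and alt_seq have shape (1, 4, seq_len)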

if torch.cuda.is_available():
    ref_seq = ref_seq.to(torch.device("cuda"))
    alt_seq = alt_seq.to(torch.device("cuda"))

with torch.no_grad():
    pred_ref_single = model(ref_seq)[0][[1,4,7,10][j],:].cpu().numpy()
    pred_alt_single = model(alt_seq)[0][[1,4,7,10][j],:].cpu().numpy()

# compare

print(np.allclose(pred_ref_single, pred_ref[i], atol=atol)) # Switches from True to False between atol=0.00001 and atol=0.000001
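
# magnitude of the batch-vs-single discrepancy, for reference
print(np.abs(pred_ref_single - pred_ref[i]).max())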