vkola-lab / peds2019

Quantifying the nativeness of antibody sequences using long short-term memory networks
MIT License

Classification model in figure 2b #7

Open wjs20 opened 3 years ago

wjs20 commented 3 years ago

Hi

In Figure 2b of 'quantifying antibody nativeness', where you plot the ROC-AUC scores for the human vs. non-human classifier, do you modify the LSTM architecture with a new FC layer on top, or do you just train a scikit-learn logistic regression model on the output scores?

Thanks!

SenyorDrew commented 3 years ago

Hi, no classifier was used for Figure 2B. Here's how Figure 2B was generated for mouse (for example):

  1. Combine the test set of mouse sequences and the test set of human sequences, keeping track of the source organism for each sequence.
  2. Calculate a score (Y_pred) for each sequence using the LSTM model (trained on a training set of human sequences).
  3. Generate an ROC plot from Y_pred above and Y_true (whether or not the sequence comes from a human).

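Pseudo-code:

import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve

# Y_true: 0/1 source-organism labels; Y_pred: LSTM scores for the same sequences
fpr, tpr, thresholds = roc_curve(Y_true, Y_pred, pos_label=1)
plt.plot(fpr, tpr)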

Let me know if that helps
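A minimal, dependency-free sketch of the three steps above, for illustration only (function names are hypothetical). It computes the AUC through its rank interpretation, the fraction of (human, non-human) pairs in which the human sequence is ranked higher, which equals the area under the ROC curve; `sklearn.metrics.roc_auc_score` gives the same number. It also assumes the LSTM score is a mean negative log-likelihood, so lower means more human-like: the positive scores posted later in this thread and the `-sum(pos_scores) / seq_len[i]` line in `model.eval()` point that way if the network outputs log-probabilities. Scores are therefore negated before ranking, with human labelled 1.

```python
def roc_auc(y_true, y_score):
    """Area under the ROC curve via its rank interpretation: the fraction of
    (positive, negative) pairs in which the positive gets the higher score,
    counting ties as one half."""
    pos = [s for t, s in zip(y_true, y_score) if t == 1]
    neg = [s for t, s in zip(y_true, y_score) if t == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative example")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


def human_vs_other_auc(human_scores, other_scores):
    """Steps 1-3: pool the two test sets, label human as 1, and compute the AUC.

    Assumes lower LSTM score = more human-like, so scores are negated to get
    'higher = more likely human' before ranking.
    """
    y_true = [1] * len(human_scores) + [0] * len(other_scores)
    y_pred = [-s for s in list(human_scores) + list(other_scores)]
    return roc_auc(y_true, y_pred)


# Toy scores, made up for illustration only
print(human_vs_other_auc([0.9, 1.0, 1.1], [1.2, 1.5, 1.05]))
```

The pairwise loop is quadratic, which is fine for a toy example; for test sets of thousands of sequences the sklearn functions are the practical choice. If the sign assumption is wrong for a given model, the AUC comes out as 1 minus the true value, so an AUC well below 0.5 is a hint to flip it.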

wjs20 commented 3 years ago

Hi

I've been having a go at replicating your results using some camel, mouse and macaque sequences from the OAS database. I'm getting good ROC-AUC scores (0.98, 0.97 and 0.80), but I can't reproduce the relationship between human germline identity and LSTM score that you report in your paper. I used the ANARCI Python package to align the animal sequences to human germlines, but I must be doing something wrong. Can you tell me how you did it?

Thanks

SenyorDrew commented 3 years ago

Hi wjs20, I'm assuming you're referring to Figure 3. For the germline identity, we use the value that is reported by ANARCI. ANARCI reports a % sequence identity when it performs its annotations (I believe they label it "v_germ_identity"). You can restrict its search to a specific species, so in this case for any input sequence, such as a mouse antibody sequence, we do the annotation with ANARCI and tell it to use "human" for the species.

If you're still having trouble reproducing Figure 3, could you post your distribution of germline identities for mouse and your distribution of LSTM scores for mouse? For reference, the median germline identity for mouse in our plot is 0.673.

SenyorDrew commented 3 years ago

Just to follow up with my previous comment and to help address this issue, could you let me know what germline_sequence_id you get for the following mouse sequence:

QVQLKESGPGLVAPSQSLSITCTVSGFSLTGYGVNWVRQPPGKGLEWLGMIWGDGSTDYNSALKSRLSISKDNSKSQVFLKMNSLQTDDTARYYCARDGEDYDATFYWYFDVWGAGTTVTVSS

(This is the first mouse sequence in our testing set).

I get v_germ_identity=0.643 from ANARCI.

wjs20 commented 3 years ago

This is my distribution of LSTM scores across all the species I looked at:

[image]

This is a plot of mouse LSTM scores against germline identity:

[image]

I got this output for the sequence you provided: 'v_gene': [('human', 'IGHV4-38-2*02'), 0.6428571428571429]

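This is the code I used:

from pathlib import Path

import pandas as pd
from anarci import anarci

DATA_PATH = Path('.')  # placeholder: directory containing lstm_scores.csv

lstm_scores = pd.read_csv(DATA_PATH/'lstm_scores.csv')
# ANARCI takes a list of (name, sequence) tuples
seqs = [('seq'+str(i), seq) for i, seq in enumerate(lstm_scores.seqs.values)]

# test seq
seq = [('seq1', 'QVQLKESGPGLVAPSQSLSITCTVSGFSLTGYGVNWVRQPPGKGLEWLGMIWGDGSTDYNSALKSRLSISKDNSKSQVFLKMNSLQTDDTARYYCARDGEDYDATFYWYFDVWGAGTTVTVSS')]

# AHo numbering, with germline assignment restricted to human genes
results = anarci(seq, scheme="aho", output=False, assign_germline=True, allowed_species=['human'])
numbering, alignment_details, hit_tables = results

alignment_details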

Thanks for having a look
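A small helper for pulling the germline assignment out of ANARCI's `alignment_details`, sketched from the structure printed in this thread (`'v_gene': [('human', 'IGHV4-38-2*02'), 0.6428571428571429]`) and from the `alignment[0]['germlines']['v_gene'][1]` lookup used further down. The helper name is hypothetical. It assumes a sequence with no recognised domain shows up as `None` or an empty entry, and returns `None` in that case instead of raising.

```python
def v_gene_identity(details):
    """Return (species, gene, identity) for the first domain, or None.

    `details` is one entry of ANARCI's alignment_details: a list of per-domain
    dicts whose 'germlines' key maps 'v_gene' to [(species, gene), identity].
    """
    if not details:  # None or empty: no domain was numbered
        return None
    germlines = details[0].get('germlines') or {}
    v_gene = germlines.get('v_gene')
    if not v_gene:
        return None
    (species, gene), identity = v_gene
    return species, gene, identity


# Mock entry shaped like the output quoted in this thread
example = [{'germlines': {'v_gene': [('human', 'IGHV4-38-2*02'), 0.6428571428571429]}}]
print(v_gene_identity(example))
```

Applied to real output this would be `[v_gene_identity(d) for d in alignment_details]`. The identity is a fraction, so 0.6428571428571429 is the same value SenyorDrew quotes as 0.643.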

SenyorDrew commented 3 years ago

Thanks. That's interesting, it looks like your distribution of LSTM scores is roughly in line with the publication, and your distribution of mouse germline identities is also in line with the publication. I wonder if the discrepancy comes from a bug that I think you reported where the ordering of output LSTM scores did not match the order of input sequences.

wjs20 commented 3 years ago

I thought the ordering might have been muddled at some point, so I ran some checks on 10,000 camel sequences I've been looking at.

The ANARCI alignment function appears to preserve sequence order:

import random

from anarci import anarci

with open(DATA_PATH/'li'/'valid.txt') as f:
    seqs = [s.strip() for s in f.readlines()]

seqs = [('seq'+str(i), seq) for i, seq in enumerate(seqs)]

numbering, alignment_details, hit_tables = anarci(seqs, scheme="aho", output=False, assign_germline=True, allowed_species=['human'])

def get_seq(num):
    # residues of the first numbered domain, with gap positions ('-') removed
    return ''.join([o[1] for o in num[0][0] if o[1] != '-'])

def get_alignment_score(alignment):
    # V-gene identity to the closest human germline
    return alignment[0]['germlines']['v_gene'][1]

# spot-check 100 sequences: does each ANARCI result match its input sequence?
random.sample([seq[1]==get_seq(o) for seq, o in zip(seqs, numbering)], 100)

## order maintained ##
[True, True, True, True, True, True, True, True, True, True, ...]
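The spot check above samples 100 positions. A slightly stricter sketch (helper names hypothetical) checks every sequence and returns the indices that fail. It unpacks the same nested structure that `get_seq` relies on, uses a substring test because ANARCI can leave residues outside the variable domain unnumbered, and treats a `None` entry (no domain found) as a failure.

```python
def numbered_residues(num):
    """Residues of the first numbered domain with '-' gaps removed, or None.

    `num` is one entry of ANARCI's numbering output: a list of domains, each
    (numbering_list, start, end), where numbering_list holds
    ((position, insertion_code), residue) pairs.
    """
    if num is None:
        return None
    return ''.join(res for _, res in num[0][0] if res != '-')


def order_mismatches(seqs, numbering):
    """Indices i where the numbered residues do not occur in input sequence i."""
    if len(seqs) != len(numbering):
        raise ValueError("inputs and ANARCI results have different lengths")
    bad = []
    for i, ((_, seq), num) in enumerate(zip(seqs, numbering)):
        residues = numbered_residues(num)
        if residues is None or residues not in seq:
            bad.append(i)
    return bad


# Mock data shaped like ANARCI output
toy_seqs = [('s0', 'xABy'), ('s1', 'CD')]
toy_numbering = [
    [([((1, ' '), 'A'), ((2, ' '), '-'), ((3, ' '), 'B')], 1, 2)],
    [([((1, ' '), 'C'), ((2, ' '), 'D')], 0, 1)],
]
print(order_mismatches(toy_seqs, toy_numbering))        # []
print(order_mismatches(toy_seqs, toy_numbering[::-1]))  # [0, 1]
```

An empty list over all 10,000 sequences is a stronger statement than a sample of `True` values, though two identical neighbouring sequences would still pass if swapped.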

alignment_scores = [(get_seq(s), get_alignment_score(a)) for s, a in zip(numbering, alignment_details)]
alignment_scores[:10]

[('HVQLVESGGGSVQAGGSLRLSCAASGYTISSNCMVWFRQAPGKEREGVASIYTGGGSPYYADSVKGRFTISQDNAKNTVYLHMDNVKAEDTAMYYCAADGNGGGCAGPILDYWGQGTQVTVS',
  0.7422680412371134),
 ('VQLVESGGGSVQDGGSLRLSCAASGFIFSNYWMHWARQAPGKGLEWVSSTNSRGVSYAVQAVKGRFTISRDNAKNTTSLQMNSLKAEDTATYYCAADGLSGSNYDSVPYFAYSGQGTQVTVS',
  0.7551020408163265),
 ('DVQLVESGGGSVQAGGSLRLSCVASGYSYSKYCMAWFRQGPGKERDRIASIHSDGATSYSDSVKGRFTISKDNPKSTLDLQMNELNPEDTGKYYCAASVLEGDVRCGTPTWRGYFNHFAYWGRGTQVTVS',
  0.6530612244897959),
 ('EVQLVESGGGSVQAGGSLTLSCTASGYSYSYSSYCMGWFRQAPGKEREVVARIESDSTTDYADSVKGRFTISRDSAKNTVYLQMNNLQPEDTATYYCAEGRGSRGEHCYSLNYWGQGTQVTVS',
  0.7346938775510204) ...]

import json

with open(DATA_PATH/'camel_alignment.json', 'w') as f:
    json.dump(alignment_scores, f)

I ran model.eval() on the same data to generate LSTM scores and checked for a correlation, but did not find any:

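from pathlib import Path

from model import ModelLSTM  # peds2019/model.py

MODEL_PATH = Path('.')  # placeholder: directory containing the saved model

def load_peds_model():
    MODEL = 'model_tmp.npy'
    model_loaded = ModelLSTM()
    model_loaded.load(fn=MODEL_PATH/MODEL)
    return model_loaded

model = load_peds_model()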

cam_lstm = model.eval(DATA_PATH/'li'/'valid.txt')

cam_lstm[:5]
array([1.245738 , 1.6254249, 0.9844174, 1.4797609, 1.0029281],
      dtype=float32)

with open(DATA_PATH/'camel_alignment.json') as f:
    cam_align = json.load(f)

cam_align[:5]
[['HVQLVESGGGSVQAGGSLRLSCAASGYTISSNCMVWFRQAPGKEREGVASIYTGGGSPYYADSVKGRFTISQDNAKNTVYLHMDNVKAEDTAMYYCAADGNGGGCAGPILDYWGQGTQVTVS',
  0.7422680412371134],
 ['VQLVESGGGSVQDGGSLRLSCAASGFIFSNYWMHWARQAPGKGLEWVSSTNSRGVSYAVQAVKGRFTISRDNAKNTTSLQMNSLKAEDTATYYCAADGLSGSNYDSVPYFAYSGQGTQVTVS',
  0.7551020408163265],
 ['DVQLVESGGGSVQAGGSLRLSCVASGYSYSKYCMAWFRQGPGKERDRIASIHSDGATSYSDSVKGRFTISKDNPKSTLDLQMNELNPEDTGKYYCAASVLEGDVRCGTPTWRGYFNHFAYWGRGTQVTVS',
  0.6530612244897959]...]

import numpy as np

cam_align_scores = np.array([s[1] for s in cam_align])

cam_lstm.shape, cam_align_scores.shape
((10000,), (10000,))

np.corrcoef(cam_lstm, cam_align_scores)
array([[ 1.        , -0.00233673],
       [-0.00233673,  1.        ]])

As you can see, I found no meaningful correlation between LSTM scores and germline identity (r ≈ -0.002).

I had a look in model.py to see whether the data was being shuffled by the .eval() method, and it doesn't seem to be. The method doesn't return the sequences alongside the LSTM scores, so I can't do an equality check like I did with the ANARCI function. Do you have any idea what might be going on?

Thanks!

SenyorDrew commented 3 years ago

Nothing immediately comes to mind - not sure if any of the other authors have ideas. I will try to test out a couple things on my end and will report here if I find anything.

tanggis commented 3 years ago

@wjs20, can you double-check that you are working with the latest version of the code? The bug that @SenyorDrew mentioned above was reported in #4, where model.eval() used to shuffle the sequences and return the scores in a random order every time. The latest commit from December should have fixed it.
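A short synthetic illustration of why a shuffled output order would produce exactly the near-zero correlation reported above. The data are made up: an artificial linear relationship between "germline identity" and "score" plus noise, not values from the paper or from this thread. Permuting one array, as a shuffling data loader effectively does, breaks the pairing between each score and its sequence.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 10_000

# Synthetic stand-ins (not real data): identity, and a score that depends
# linearly on it plus noise.
germline_identity = rng.uniform(0.6, 0.9, size=n)
score = 3.0 - 2.0 * germline_identity + rng.normal(0.0, 0.05, size=n)

r_ordered = np.corrcoef(score, germline_identity)[0, 1]

# Mimic eval() with shuffle=True: the scores come back in a random order,
# so score i no longer belongs to sequence i.
r_shuffled = np.corrcoef(rng.permutation(score), germline_identity)[0, 1]

print(f"aligned:  r = {r_ordered:.3f}")
print(f"shuffled: r = {r_shuffled:.3f}")
```

With this setup the aligned correlation is strongly negative, while the shuffled one lands near zero, much like the -0.0023 observed on the camel sequences.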

wjs20 commented 3 years ago

Looks like shuffle was set to True in the DataLoader inside the eval() method:

def eval(self, fn, batch_size=512):        
        # dataset and dataset loader
        data = ProteinSeqDataset(fn, self.gapped)
        if batch_size == -1: batch_size = len(data)
        # the third positional argument is shuffle=True, so batches come out in random order
        dataloader = torch.utils.data.DataLoader(data, batch_size, True, collate_fn=collate_fn)

        self.nn.eval()
        scores = np.zeros(len(data), dtype=np.float32)
        sys.stdout.flush()
        with torch.set_grad_enabled(False):
            with tqdm(total=len(data), ascii=True, unit='seq', bar_format='{l_bar}{r_bar}') as pbar:
                for n, (batch, batch_flatten) in enumerate(dataloader):
                    actual_batch_size = len(batch)  # last iteration may contain less sequences
                    seq_len = [len(seq) for seq in batch]
                    seq_len_cumsum = np.cumsum(seq_len)
                    out = self.nn(batch, aa2id_i[self.gapped]).data.cpu().numpy()
                    out = np.split(out, seq_len_cumsum)[:-1]
                    batch_scores = []
                    for i in range(actual_batch_size):
                        pos_scores = []
                        for j in range(seq_len[i]):
                            pos_scores.append(out[i][j, batch[i][j]])
                        batch_scores.append(-sum(pos_scores) / seq_len[i])    
                    scores[n*batch_size:(n+1)*batch_size] = batch_scores
                    pbar.update(len(batch))
        return scores

I must have cloned an old version of the repo and not pulled the latest changes. I'll do that now and see if it helps.

Thanks for looking into it for me.
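The line `scores[n*batch_size:(n+1)*batch_size] = batch_scores` is only correct when the loader yields sequences in file order. A pure-Python sketch (no torch; the loader is simulated and every name is hypothetical) contrasting that positional write with writing each score to its sequence's original index, which stays correct even if shuffling is left on. In PyTorch this would mean having the dataset also return each item's index through `collate_fn`; that is a suggested design, not something the repo is known to do.

```python
import random


def simulated_batches(n_items, batch_size, shuffle, seed=0):
    """Yield lists of item indices, like a DataLoader over range(n_items)."""
    order = list(range(n_items))
    if shuffle:
        random.Random(seed).shuffle(order)
    for start in range(0, n_items, batch_size):
        yield order[start:start + batch_size]


def score_positional(seqs, score_fn, batch_size, shuffle):
    """Mirrors the pattern in eval(): batch n fills slots n*batch_size onward."""
    scores = [None] * len(seqs)
    for n, idx in enumerate(simulated_batches(len(seqs), batch_size, shuffle)):
        batch_scores = [score_fn(seqs[i]) for i in idx]
        scores[n * batch_size:(n + 1) * batch_size] = batch_scores
    return scores


def score_by_index(seqs, score_fn, batch_size, shuffle):
    """Writes each score to the original index of its sequence."""
    scores = [None] * len(seqs)
    for idx in simulated_batches(len(seqs), batch_size, shuffle):
        for i in idx:
            scores[i] = score_fn(seqs[i])
    return scores


seqs = ['A' * k for k in range(1, 11)]  # toy "sequences"; score = length
print(score_positional(seqs, len, 3, shuffle=True))  # same values, shuffled slots
print(score_by_index(seqs, len, 3, shuffle=True))    # [1, 2, ..., 10]
```

Carrying the index costs almost nothing and makes the output order independent of the loader settings, so a future change to `shuffle` cannot silently reintroduce this bug.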

SenyorDrew commented 3 years ago

Hi @wjs20 - I went back and did an independent verification. My code is wrapped in higher-level functions, so it's a bit hard to share, but I've shared the plot and the data table used to make it. Here's what I did:

  1. Took the first 1,000 mouse sequences in the testing set.
  2. In a for loop, calculated the LSTM score and the ANARCI germline identity for each sequence, one at a time. This should rule out any issues caused by score shuffling.

Here's my plot of v_germ_id (identity to the closest human germline, from ANARCI) against LSTM score (model trained on a set of human sequences):

[image: mouse_lstm_id_corr]

I've attached a table, mouse_lstm_scores.xlsx, with the 1K mouse sequences, the lstm_score I get with my model, and the v_germ_id from ANARCI. I'd be curious to see what values you get for each individual sequence, to find where the discrepancy comes from.
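The one-sequence-at-a-time approach can be sketched generically: score each sequence on its own and build the table row by row, so no sequence can be paired with another sequence's score. `score_one` and `identity_one` are hypothetical callables; in practice they might wrap a single-sequence call to the LSTM model and an ANARCI germline lookup. Column names follow the attached table; writing CSV with the standard library is an assumption, since the original was an .xlsx file.

```python
import csv
import io

FIELDS = ['sequence', 'lstm_score', 'v_germ_id']


def paired_table(seqs, score_one, identity_one):
    """One row per sequence, each value computed from that sequence alone."""
    rows = []
    for seq in seqs:
        rows.append({'sequence': seq,
                     'lstm_score': score_one(seq),
                     'v_germ_id': identity_one(seq)})
    return rows


def to_csv(rows):
    """Serialise rows to CSV text with a header line."""
    buf = io.StringIO()
    writer = csv.DictWriter(buf, fieldnames=FIELDS)
    writer.writeheader()
    writer.writerows(rows)
    return buf.getvalue()


# Dummy scorers standing in for the LSTM model and ANARCI
rows = paired_table(['QVQ', 'EVQL'],
                    score_one=len,
                    identity_one=lambda s: s.count('Q') / len(s))
print(to_csv(rows))
```

This is slower than batched scoring, but it removes output ordering as a possible failure mode, which is the point of the check.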

wjs20 commented 3 years ago

Sequence scores without shuffling:

[image]

Problem solved!

I think you may need to update your repo, though. I cloned the current version of the repo into my Colab notebook to check whether the bug was still in the model.eval() method, and it was:

Signature: model.eval(fn, batch_size=512)
Source:   
    def eval(self, fn, batch_size=512):        
        # dataset and dataset loader
        data = ProteinSeqDataset(fn, self.gapped)
        if batch_size == -1: batch_size = len(data)
        dataloader = torch.utils.data.DataLoader(data, batch_size, True, collate_fn=collate_fn)

        self.nn.eval()
        scores = np.zeros(len(data), dtype=np.float32)
        sys.stdout.flush()
        with torch.set_grad_enabled(False):
            with tqdm(total=len(data), ascii=True, unit='seq', bar_format='{l_bar}{r_bar}') as pbar:
                for n, (batch, batch_flatten) in enumerate(dataloader):
                    actual_batch_size = len(batch)  # last iteration may contain less sequences
                    seq_len = [len(seq) for seq in batch]
                    seq_len_cumsum = np.cumsum(seq_len)
                    out = self.nn(batch, aa2id_i[self.gapped]).data.cpu().numpy()
                    out = np.split(out, seq_len_cumsum)[:-1]
                    batch_scores = []
                    for i in range(actual_batch_size):
                        pos_scores = []
                        for j in range(seq_len[i]):
                            pos_scores.append(out[i][j, batch[i][j]])
                        batch_scores.append(-sum(pos_scores) / seq_len[i])    
                    scores[n*batch_size:(n+1)*batch_size] = batch_scores
                    pbar.update(len(batch))
        return scores
File:      /content/drive/MyDrive/peds2019/model.py
Type:      method

There's no parameter in the .eval() method to turn off shuffling, so I extracted the code into a standalone function:

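import sys

import numpy as np
import torch
from tqdm import tqdm
# ProteinSeqDataset, collate_fn and aa2id_i are module-level names in peds2019/model.py
from model import ProteinSeqDataset, collate_fn, aa2id_i

def eval(model, fn, batch_size=512):
    # dataset and dataset loader
    data = ProteinSeqDataset(fn, model.gapped)
    if batch_size == -1: batch_size = len(data)
    # shuffle=False keeps the scores in the same order as the sequences in fn
    dataloader = torch.utils.data.DataLoader(data, batch_size, shuffle=False, collate_fn=collate_fn)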

    model.nn.eval()
    scores = np.zeros(len(data), dtype=np.float32)
    sys.stdout.flush()
    with torch.set_grad_enabled(False):
        with tqdm(total=len(data), ascii=True, unit='seq', bar_format='{l_bar}{r_bar}') as pbar:
            for n, (batch, batch_flatten) in enumerate(dataloader):
                actual_batch_size = len(batch)  # last iteration may contain less sequences
                seq_len = [len(seq) for seq in batch]
                seq_len_cumsum = np.cumsum(seq_len)
                out = model.nn(batch, aa2id_i[model.gapped]).data.cpu().numpy()
                out = np.split(out, seq_len_cumsum)[:-1]
                batch_scores = []
                for i in range(actual_batch_size):
                    pos_scores = []
                    for j in range(seq_len[i]):
                        pos_scores.append(out[i][j, batch[i][j]])
                    batch_scores.append(-sum(pos_scores) / seq_len[i])    
                scores[n*batch_size:(n+1)*batch_size] = batch_scores
                pbar.update(len(batch))
    return scores

Thanks for the help