tttianhao / CLEAN

CLEAN: a contrastive learning model for high-quality functional prediction of proteins
MIT License

Cannot reproduce results in README #41

Closed by pgmikhael 9 months ago

pgmikhael commented 10 months ago

Hi,

Thanks so much for this work, and for making the repo super nice and straightforward!

Before evaluating the model on a separate use case I have, I wanted to make sure I didn't do anything wrong when setting up the project, so I've been trying to evaluate the model per the README to ensure I get consistent results. However, I obtain very poor performance when calling inference.py on the test sets provided (price, new), so it would be great to understand what I've done wrong.

#### results on new

The embedding sizes for train and test: torch.Size([241025, 128]) torch.Size([392, 128])
100%|███████████████████████████████████████████████| 5242/5242 [00:00<00:00, 7713.09it/s]
Calculating eval distance map, between 392 test ids and 5242 train EC cluster centers
392it [00:04, 93.33it/s] 
############ EC calling results using maximum separation ############
---------------------------------------------------------------------------
>>> total samples: 392 | total ec: 177 
>>> precision: 0.0135 | recall: 0.0139| F1: 0.0136 | AUC: 0.507 
---------------------------------------------------------------------------

#### results on price

The embedding sizes for train and test: torch.Size([241025, 128]) torch.Size([149, 128])
100%|███████████████████████████████████████████████| 5242/5242 [00:00<00:00, 8758.51it/s]
Calculating eval distance map, between 149 test ids and 5242 train EC cluster centers
149it [00:00, 155.21it/s]
############ EC calling results using maximum separation ############
---------------------------------------------------------------------------
>>> total samples: 149 | total ec: 56 
>>> precision: 0.0 | recall: 0.0| F1: 0.0 | AUC: 0.5 
---------------------------------------------------------------------------

From my attempts to debug, this doesn't seem to be an issue of models/data processing. For instance, I compared the embedding of the first cluster (EC 2.7.10.2) from data/pretrained/100.pt with ones I manually recomputed, and obtained the same values (up to some numerical error). Specifically, I made a FASTA file using the sequences that are in EC 2.7.10.2, extracted their embeddings, then passed them through the pretrained model (data/pretrained/split100.pth). I compared these with what we get from calling get_cluster_center on the precomputed tensor. These appeared to be consistent. So, if the embeddings are calculated in a consistent manner, I'm not sure why the predictions are turning out to be wrong.

Python 3.10.4 
#### recalculate the first EC cluster embeddings

>>> import os, torch, pandas
>>> from CLEAN.utils import * 
>>> from CLEAN.distance_map import *
>>> from CLEAN.model import LayerNormNet 

>>> train_data = "split100"
>>> train_csv = pandas.read_csv('data/split100.csv', delimiter='\t')
>>> id_ec_train, ec_id_dict_train = get_ec_id_dict('data/split100.csv')
>>> list(ec_id_dict_train.keys())[0]
'2.7.10.2'

#### make fasta of sequences in 2.7.10.2
>>> with open("data/ec_2.7.10.2.fasta", "w") as f:
...     for u in ec_id_dict_train['2.7.10.2']:
...         sequence = train_csv[train_csv['Entry'] == u].iloc[0].Sequence
...         f.write(f">{u}\n")
...         f.write(f"{sequence}\n")

#### calculate ESM embeddings
>>> retrive_esm1b_embedding('ec_2.7.10.2')                 

#### load the split100 model weights
>>> device = torch.device("cpu")
>>> dtype = torch.float32 
>>> model = LayerNormNet(512, 128, device, dtype)
>>> checkpoint = torch.load('./data/pretrained/' + train_data + '.pth', map_location="cpu")
>>> model.load_state_dict(checkpoint) 
<All keys matched successfully>
>>> model.eval()   

#### calculate model embeddings
>>> esm_to_cat = [load_esm(id) for id in ec_id_dict_train['2.7.10.2']]
>>> esm_emb = torch.cat(esm_to_cat)
>>> model_emb = model(esm_emb)
>>> model_emb.mean(0)
tensor([ 0.5685, -0.2730,  1.3413, -0.0456,  0.5519, -0.5602,  0.4451,  0.3555,
        -0.3991,  0.8149, -0.7487,  0.8769, -0.0774, -1.2195, -0.3510,  0.3407,
         0.6934, -0.4897, -0.6785,  0.4822, -0.4403,  0.1503,  0.6215, -0.2650,
         1.0949,  0.4402,  0.4229,  1.4833,  0.2911, -2.0526, -1.0108,  0.8270,
         0.0103, -0.4964,  0.4265,  0.6308, -0.5499, -1.2762, -0.9738,  0.3144,
        -0.9146,  0.4415,  0.2395,  0.2096,  0.0948, -0.6719,  0.1269, -0.6432,
         1.3322,  0.8958,  0.2907,  1.5833,  1.6047,  0.0428, -0.1019, -0.1428,
         0.6814, -0.9868,  0.4500,  0.1788, -0.3415,  1.0227,  0.2723,  0.2320,
         0.5672, -0.8140, -0.4842,  0.3829, -1.4036, -0.3750, -2.0640, -0.9057,
        -1.1886,  0.3434, -1.0756, -1.4245,  1.1374, -0.1440, -0.1107, -2.4469,
         0.1129, -0.2940,  0.3541,  0.9514, -0.1509, -1.1097, -0.3776,  0.0645,
         0.1615, -0.3648,  0.8489, -0.1049,  0.1044, -0.9301,  0.1868,  0.8924,
         0.1700, -1.5468,  0.9586, -1.1084,  1.4576,  1.4288,  0.3229,  0.3504,
        -0.1556, -0.0749,  0.1157,  0.2287, -0.2752,  1.2659,  0.7747,  0.2845,
         0.5852, -0.9135,  1.0046,  1.1457, -0.8711,  0.5439,  0.4540,  0.0190,
        -0.2778,  1.8937, -1.7569, -1.3366, -0.5689, -1.9689,  0.2271, -0.3354],
       device='cuda:0', grad_fn=<MeanBackward1>)

#### compare with precomputed embeddings
>>> emb_train = torch.load('./data/pretrained/100.pt', map_location=device)
>>> cluster_center_model = get_cluster_center(emb_train, ec_id_dict_train)
>>> cluster_center_model['2.7.10.2']
tensor([ 0.5684, -0.2730,  1.3413, -0.0455,  0.5519, -0.5602,  0.4452,  0.3555,
        -0.3990,  0.8149, -0.7487,  0.8769, -0.0774, -1.2195, -0.3510,  0.3407,
         0.6935, -0.4897, -0.6786,  0.4822, -0.4402,  0.1503,  0.6215, -0.2650,
         1.0949,  0.4402,  0.4229,  1.4833,  0.2911, -2.0526, -1.0108,  0.8270,
         0.0103, -0.4963,  0.4265,  0.6308, -0.5499, -1.2763, -0.9737,  0.3145,
        -0.9146,  0.4415,  0.2394,  0.2096,  0.0948, -0.6719,  0.1269, -0.6432,
         1.3322,  0.8958,  0.2907,  1.5833,  1.6047,  0.0427, -0.1019, -0.1428,
         0.6814, -0.9868,  0.4500,  0.1787, -0.3415,  1.0227,  0.2722,  0.2320,
         0.5672, -0.8140, -0.4843,  0.3830, -1.4036, -0.3750, -2.0639, -0.9057,
        -1.1886,  0.3434, -1.0756, -1.4245,  1.1373, -0.1440, -0.1108, -2.4469,
         0.1129, -0.2940,  0.3541,  0.9514, -0.1508, -1.1097, -0.3776,  0.0645,
         0.1616, -0.3648,  0.8489, -0.1049,  0.1043, -0.9301,  0.1868,  0.8923,
         0.1699, -1.5468,  0.9585, -1.1083,  1.4576,  1.4288,  0.3229,  0.3504,
        -0.1556, -0.0749,  0.1157,  0.2288, -0.2752,  1.2659,  0.7747,  0.2845,
         0.5853, -0.9134,  1.0046,  1.1458, -0.8711,  0.5439,  0.4540,  0.0190,
        -0.2778,  1.8937, -1.7569, -1.3366, -0.5689, -1.9689,  0.2270, -0.3354])
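
The "same values up to some numerical error" comparison can be made explicit with an element-wise tolerance check. Below is a minimal sketch that uses only the first eight components of each tensor printed above. The tolerance of 5e-4 and the helper names `max_abs_diff` and `all_close` are assumptions chosen for illustration (the printouts are rounded to four decimals). With the full tensors, something like `torch.allclose(model_emb.mean(0).detach().cpu(), cluster_center_model['2.7.10.2'], atol=5e-4)` would play the same role.

```python
import math

# First eight components of each tensor, copied from the printouts above.
recomputed = [0.5685, -0.2730, 1.3413, -0.0456, 0.5519, -0.5602, 0.4451, 0.3555]
precomputed = [0.5684, -0.2730, 1.3413, -0.0455, 0.5519, -0.5602, 0.4452, 0.3555]

ATOL = 5e-4  # assumed tolerance, chosen for illustration only


def max_abs_diff(a, b):
    """Largest element-wise absolute difference between two equal-length vectors."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    return max(abs(x - y) for x, y in zip(a, b))


def all_close(a, b, atol=ATOL):
    """True if every pair of elements differs by at most atol."""
    return all(math.isclose(x, y, rel_tol=0.0, abs_tol=atol) for x, y in zip(a, b))


diff = max_abs_diff(recomputed, precomputed)
print(f"max abs diff: {diff:.1e}, all close: {all_close(recomputed, precomputed)}")
```

A check like this only shows that the two cluster centers agree; it says nothing about whether the EC labels attached to them are the right ones.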

Would greatly appreciate any help wherever you think I made a mistake.

Thank you!!

pgmikhael commented 9 months ago

Hi,

As an update to the above, I trained a model on split50 using the triplet loss and obtained reasonable performance (chose the triplet loss since it's faster). This seems to indicate that the trained model weights may be off. Please let me know.

tttianhao commented 9 months ago

Hi Peter, thanks for pointing this out. We realized that we made a mistake in the split CSV files in a previous commit. We have uploaded the original CSV files for now, so the issue should be resolved.
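
A mismatch between a split CSV and the precomputed embeddings can sometimes be caught before inference with a cheap count check. Below is a minimal sketch, assuming the tab-delimited layout used in the session above (`Entry` and `Sequence` columns) plus an EC column named `EC number` whose multiple labels are separated by `;`. The column name, the separator, and the helper names `count_ec_pairs` and `check_split` are assumptions, not confirmed by this thread. It counts (EC, entry) pairs and distinct ECs. If the embeddings are stored one row per (EC, entry) pair, those counts should match the embedding rows and cluster centers reported in the log (241025 and 5242 for split100).

```python
import csv
import io


def count_ec_pairs(handle, ec_column="EC number", sep=";"):
    """Return (number of (EC, entry) pairs, number of distinct ECs) in a TSV."""
    reader = csv.DictReader(handle, delimiter="\t")
    n_pairs = 0
    ecs = set()
    for row in reader:
        # One entry may carry several EC labels; count each pairing once.
        labels = [ec.strip() for ec in row[ec_column].split(sep) if ec.strip()]
        n_pairs += len(labels)
        ecs.update(labels)
    return n_pairs, len(ecs)


def check_split(handle, expected_rows, expected_ecs):
    """Compare the CSV counts against the embedding and cluster-center counts."""
    n_pairs, n_ecs = count_ec_pairs(handle)
    return n_pairs == expected_rows and n_ecs == expected_ecs


# Toy data for illustration only; the IDs and sequences are made up.
toy = io.StringIO(
    "Entry\tEC number\tSequence\n"
    "P00001\t2.7.10.2\tMKT\n"
    "P00002\t2.7.10.2;3.1.1.1\tMAA\n"
)
print(count_ec_pairs(toy))
```

For the real file, the call would look like `check_split(open('data/split100.csv', newline=''), 241025, 5242)`. Matching counts do not prove the rows are in the right order, so this is a necessary check rather than a sufficient one.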