facebookresearch / esm

Evolutionary Scale Modeling (esm): Pretrained language models for proteins

About trRosetta data used in training regression head #131

Closed YijiaXiao closed 3 years ago

YijiaXiao commented 3 years ago

Hi, I also used the MSA model you released for contact prediction. In your ICLR paper, you mention that you used 20 MSA samples to train the contact regression head. I also tried using 20 samples to train a regression head (sklearn's LogisticRegression, L1 penalty = 0.15). However, I found that the performance of the trained regression heads varies (±5%, absolute), so I wonder how to choose the best 20 samples for training the contact prediction head. I would appreciate it if you could provide the IDs of the 20 training and 20 validation MSA samples.

Thank you!
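
For context, a minimal sketch of the setup described above, assuming the 0.15 penalty is passed as sklearn's inverse-regularization strength C (an assumption; the paper does not pin down the parameterization), with placeholder arrays standing in for the real per-pair features and labels:

import numpy as np
from sklearn.linear_model import LogisticRegression

# Placeholder data: one row of attention-derived features per residue pair,
# with a binary contact label per pair (real features come from the model;
# the shapes below are illustrative only).
rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 144))   # (num_pairs, num_features) -- assumption
y = rng.integers(0, 2, size=1000)  # 0 = no contact, 1 = contact

# L1-regularized logistic regression; mapping 0.15 to C is an assumption.
clf = LogisticRegression(penalty="l1", C=0.15, solver="liblinear")
clf.fit(X, y)

# predict_proba returns [P(no contact), P(contact)]; column 1 is what the
# P@L/5 metric discussed later in this thread ranks by.
contact_prob = clf.predict_proba(X)[:, 1]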

tomsercu commented 3 years ago

Interesting observation with the MSA Transformer. In the paper you link, appendix Fig. 12 shows that performance is usually extremely robust to the choice of proteins used to train the logistic regression. On the other hand, with the MSA Transformer there is the additional factor of MSA quality: if you choose a protein with a bad (shallow) MSA, your representations will suffer.

tomsercu commented 3 years ago

Could you say more about the ±5% absolute variation?

liujas000 commented 3 years ago

We used the following 20 proteins for training. The remaining ~14K structures are used as the test set, and the scores are reported in Table 1 of the paper you linked!

trRosetta/2flh_1_A trRosetta/4lqk_1_A trRosetta/3fka_1_B trRosetta/5wxk_1_B trRosetta/2jgp_1_A
trRosetta/4xdd_1_A trRosetta/4gaf_1_B trRosetta/1np6_1_A trRosetta/3iag_1_A trRosetta/3oa8_1_B
trRosetta/2c2i_1_B trRosetta/2ahr_1_E trRosetta/2y3c_1_A trRosetta/1urq_1_C trRosetta/1rku_1_A
trRosetta/2qww_1_B trRosetta/2q3q_1_A trRosetta/3dou_1_A trRosetta/1qua_1_A trRosetta/4ge3_1_A
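
A minimal sketch of carving out this split; the split_dataset helper and its all_keys input are hypothetical, and only the 20 IDs are taken from the list above:

TRAIN_KEYS = {
    "trRosetta/2flh_1_A", "trRosetta/4lqk_1_A", "trRosetta/3fka_1_B", "trRosetta/5wxk_1_B",
    "trRosetta/2jgp_1_A", "trRosetta/4xdd_1_A", "trRosetta/4gaf_1_B", "trRosetta/1np6_1_A",
    "trRosetta/3iag_1_A", "trRosetta/3oa8_1_B", "trRosetta/2c2i_1_B", "trRosetta/2ahr_1_E",
    "trRosetta/2y3c_1_A", "trRosetta/1urq_1_C", "trRosetta/1rku_1_A", "trRosetta/2qww_1_B",
    "trRosetta/2q3q_1_A", "trRosetta/3dou_1_A", "trRosetta/1qua_1_A", "trRosetta/4ge3_1_A",
}

def split_dataset(all_keys):
    # all_keys: iterable of "trRosetta/<pdb>_<model>_<chain>" ids (assumed format);
    # everything outside the 20-protein train split goes to the ~14K test set
    train = [k for k in all_keys if k in TRAIN_KEYS]
    test = [k for k in all_keys if k not in TRAIN_KEYS]
    return train, test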

YijiaXiao commented 3 years ago

Hi @tomsercu, thank you for your timely reply and advice.

(a) "if you choose a protein with a bad (shallow) MSA your representations will suffer" — yes, I do observe a minor increase in performance when I increase the MSA depth from 128 to 256.

(b) "specify about the ±5% absolute value" — metric: I use LogisticRegression from sklearn and its predict_proba function to get the probability of contact (1) vs. no contact (0). I then sort by the predicted contact probability and take the top L/5 predictions (those with the highest predicted probability of '1'), and compute the precision of those L/5 predictions. Test set: the CAMEO dataset. (My metric code is provided below.)

And thank you @liujas000 for providing the protein IDs; I will try training on these 20 data points.

Besides, I found that the regression model I trained is not as sparse as the one you released, so there may be some problem in my training; I will look into it.
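
A quick way to quantify that, assuming clf is a fitted sklearn LogisticRegression as above:

import numpy as np
from sklearn.linear_model import LogisticRegression

def weight_sparsity(clf: LogisticRegression) -> float:
    # fraction of exactly-zero coefficients; an L1-trained head that matches
    # the released one should score high here
    return float(np.mean(clf.coef_ == 0))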

Thank you all again for your reply:)

--- metric code

import torch


def calculate_contact_precision(name, pred, label, local_range, local_frac=5, ignore_index=-1):
    """Compute contact precision (P@L/frac) for one protein.

    name: protein identifier (unused here, kept for logging)
    pred: (L*L, 2) tensor of [no-contact, contact] probabilities
    label: (L, L) tensor of 0/1 contact labels (modified in place)
    local_range: e.g. local_range=[12, 24] restricts to medium-range contacts
    local_frac: e.g. local_frac=5 computes P@L/5, local_frac=2 computes P@L/2
    """
    # Mask out pairs whose sequence separation falls outside local_range.
    for i in range(len(label)):
        for j in range(len(label)):
            if abs(i - j) < local_range[0] or abs(i - j) >= local_range[1]:
                label[i][j] = ignore_index

    labels = label.reshape(-1)

    # Zero out the contact probability at masked positions so topk never
    # selects them.
    valid_mask = (labels != ignore_index)
    confidence = pred[:, 1]
    valid_mask = valid_mask.type_as(confidence)
    masked_prob = confidence * valid_mask

    # Recover L from the flattened L*L labels, then take the L // local_frac
    # most confident predictions and count how many are true contacts.
    seq_len = int(len(labels) ** 0.5)
    most_likely = masked_prob.topk(seq_len // local_frac, sorted=False)
    selected = labels.gather(0, most_likely.indices)
    selected[selected < 0] = 0  # treat any masked pick as incorrect

    correct = selected.sum().long()
    total = selected.numel()
    return correct, total
liujas000 commented 3 years ago

One small comment: I'd recommend calculating precision over the upper triangle of the pred/label. That way you won't double count symmetric predictions :)
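
A sketch of that adjustment (assuming an L x L matrix of contact probabilities; pairs with i >= j are dropped before ranking):

import torch

def upper_triangle_mask(L, offset=1):
    # boolean (L, L) mask that is True strictly above the diagonal,
    # so each (i, j) pair is counted exactly once
    idx = torch.arange(L)
    return (idx[None, :] - idx[:, None]) >= offset

# usage sketch (pred_prob and label are assumed (L, L) tensors):
# mask = upper_triangle_mask(L)
# scores, targets = pred_prob[mask], label[mask]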

YijiaXiao commented 2 years ago

Hi, I also have a question regarding the calculation of the contact map. In your ICLR paper, section 4.3 TRANSFORMERS, you mention that "In this work, we demonstrate that the QK^T pairwise 'self-attention maps' indeed capture accurate contacts", and in the MSA Transformer paper, section 4.1 Unsupervised Contact Prediction, you mention that "We follow the methodology of Rao et al. ..." to make contact predictions, so it seems the methods used in the two papers are the same (QK^T). In this repo, however, it looks like the row attention used for the contact calculation is the row attention probability (i.e. post-softmax), with a reduction over the attention maps.

So I am wondering which one I should use for contact prediction (and what the possible differences are)? Thank you!

tomsercu commented 2 years ago

Yes, the tied row attention is used exactly like the single-sequence self-attention maps. You can find how those attention maps are used in ContactPredictionHead, specifically the symmetrization and APC (average product correction), here.
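
For reference, the two operations look roughly like this (a sketch following the definitions in the repo; x is a tensor of attention maps whose last two dimensions index residue positions):

import torch

def symmetrize(x):
    # make the map symmetric in the last two (residue) dimensions
    return x + x.transpose(-1, -2)

def apc(x):
    # average product correction: subtract the outer product of row and
    # column sums, normalized by the total sum, to remove background signal
    a1 = x.sum(-1, keepdim=True)
    a2 = x.sum(-2, keepdim=True)
    a12 = x.sum((-1, -2), keepdim=True)
    return x - (a1 * a2) / a12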