mheinzinger / ProstT5

Bilingual Language Model for Protein Sequence and Structure
MIT License

Probability cutoff? #24

Open mauriciolangleib opened 4 days ago

mauriciolangleib commented 4 days ago

Hello,

First of all, many thanks again for developing this great tool. I am opening this issue because I wanted to know which probability cutoff you would recommend as a minimum for trusting a 3Di prediction.

I haven't found such a rule of thumb in the documentation, nor a clear definition of what this probability is (as far as I understand, it is related to the NN's prediction; is this correct?).

Many thanks in advance, and sorry to bother you with this.

Cheers, Mauricio

mheinzinger commented 3 days ago

Hi, thanks again for your interest in our tools! :) I assume you are referring to the coverage/precision plot that we added to the most recent version of the paper (Fig. SOM S5, also attached here). You are right: the probability in this plot is the output probability of the CNN that sits on top of the ProstT5 embeddings and predicts 3Di directly from amino-acid (AA) input. It simply takes the output of this network (the softmax over its logits) and uses it as a proxy for reliability.

I am always a bit hesitant to give precise thresholds, because the right one depends on your specific use case. If you would rather have more hits at the cost of accepting some more false positives, you can accept a somewhat lower cutoff than if you want the opposite (fewer FPs, but also less coverage). As a general starting point, you might filter out sequences whose average probability is <=0.4. That being said, I think @gbouras13 looked into this in more detail for phages, using a cutoff of 0.5. Maybe he can give some more input?

[Attached screenshot: coverage/precision plot (Fig. SOM S5)]
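As a rough illustration of the filtering described above, the sketch below turns CNN logits into one score per protein and drops proteins at or below the suggested 0.4 starting point. It assumes the logits for one protein arrive as a NumPy array of shape `(n_states, seq_len)` (20 states for the 3Di alphabet) and that a residue's confidence is the softmax probability of its top-scoring state; neither detail is taken from the ProstT5 script, and the function names are hypothetical.

```python
import numpy as np


def softmax(logits: np.ndarray, axis: int = 0) -> np.ndarray:
    """Numerically stable softmax along `axis`."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


def mean_confidence(logits: np.ndarray) -> float:
    """Average per-residue confidence of one protein.

    ASSUMPTION: `logits` has shape (n_states, seq_len) and a residue's
    confidence is the softmax probability of its predicted (argmax) 3Di state.
    """
    probs = softmax(logits, axis=0)   # each column now sums to 1
    per_residue = probs.max(axis=0)   # probability of the predicted state
    return float(per_residue.mean())  # average over the whole sequence


def keep_reliable(logits_by_id: dict, cutoff: float = 0.4) -> list:
    """Keep proteins whose average confidence is above `cutoff`."""
    return [pid for pid, logits in logits_by_id.items()
            if mean_confidence(logits) > cutoff]


if __name__ == "__main__":
    confident = np.zeros((20, 50))
    confident[3, :] = 10.0            # one state clearly dominates everywhere
    uncertain = np.zeros((20, 50))    # all 20 states equally likely -> 0.05
    print(keep_reliable({"protA": confident, "protB": uncertain}))  # ['protA']
```

Passing `cutoff=0.5` gives the stricter setting mentioned for phages; lowering it trades precision for coverage, as described in the reply.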

rakeshr10 commented 2 days ago

Hi @mheinzinger, the output_probabilities.csv file contains FASTA headers and numbers ranging from 0 to 100. How are these calculated? Are they the same probabilities you described above?

How does the model handle non-standard characters such as 'X' in the input sequences, and how do they affect the probability values? Also, how does protein length affect the probability?

I have also noticed that for some FASTA files I cannot generate output_probabilities.csv, because the script fails with this error:

/opt/conda/lib/python3.10/site-packages/numpy/_core/fromnumeric.py:3596: RuntimeWarning: Mean of empty slice.
  return _methods._mean(a, axis=axis, dtype=dtype,
/opt/conda/lib/python3.10/site-packages/numpy/_core/_methods.py:136: RuntimeWarning: invalid value encountered in divide
  ret = arr.dtype.type(ret / rcount)
Traceback (most recent call last):
  File "/app/ProstT5/scripts/predict_3Di_encoderOnly.py", line 389, in <module>
    main()
  File "/app/ProstT5/scripts/predict_3Di_encoderOnly.py", line 377, in main
    get_embeddings(
  File "/app/ProstT5/scripts/predict_3Di_encoderOnly.py", line 287, in get_embeddings
    prob = int( 100* np.mean(probabilities[batch_idx, :, 0:s_len]))
ValueError: cannot convert float NaN to integer
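The traceback above can be reproduced in isolation: `np.mean` on an empty slice emits the "Mean of empty slice" warning and returns `nan`, and `int()` cannot convert `nan`, which raises exactly this `ValueError`. A plausible (unconfirmed) cause is that `probabilities[batch_idx, :, 0:s_len]` selects no residues for some record, e.g. when `s_len` is 0. The sketch below shows a guarded variant; `safe_percent` is a hypothetical helper written for illustration, not part of `predict_3Di_encoderOnly.py`, and returning `None` for such records is a design assumption.

```python
import math
import warnings

import numpy as np


def safe_percent(probabilities: np.ndarray, batch_idx: int, s_len: int):
    """Hypothetical guarded version of the failing line (not the ProstT5 code).

    Returns the mean probability scaled to 0-100, or None when there is
    nothing to average (empty slice) or the mean is NaN.
    """
    window = probabilities[batch_idx, :, 0:s_len]
    if window.size == 0:
        return None  # empty slice: np.mean would warn and return nan
    mean = float(np.mean(window))
    if math.isnan(mean):
        return None  # NaN already present in the probabilities themselves
    return int(100 * mean)


if __name__ == "__main__":
    probs = np.full((1, 1, 5), 0.5)  # toy array: batch=1, 1 row, 5 residues

    # Reproduce the original failure: an empty slice (s_len == 0).
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", RuntimeWarning)
        try:
            int(100 * np.mean(probs[0, :, 0:0]))
        except ValueError as err:
            print("reproduced:", err)  # cannot convert float NaN to integer

    print(safe_percent(probs, 0, 5))  # 50
    print(safe_percent(probs, 0, 0))  # None
```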

mheinzinger commented 2 days ago

Good point, sorry for the confusion. The numbers in the CSV are simply multiplied by 100, but they refer to the same probabilities (always the softmax of the logits output by the CNN that predicts 3Di, as defined here: https://github.com/mheinzinger/ProstT5/blob/main/scripts/predict_3Di_encoderOnly.py#L274 ).

X is handled as a wildcard for unknown/non-standard amino acids. How its presence affects the probability is not something I investigated in detail, because such cases are rather rare. That being said, given that the probability is an average over the full sequence, and that sequences rarely have >5% missing/non-standard AAs, I doubt that a small number of "X" in your input changes the predicted probability much.

As long as your protein is below 400 residues (close to our upper length limit during training), I do not think the predictions are affected much. Beyond 400, performance will decrease slightly, but I did not quantify by how much (as long as you stay below, e.g., 1k residues, I do not think it should hurt you much). Looking at some isolated examples indicated that the approach "predict 3Di via ProstT5 and search with Foldseek against some DB" still works fine. If you intend to use the predicted 3Di for something different, the situation might change.
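Since the CSV values are just the same probabilities multiplied by 100, dividing by 100 recovers them, and the 0.4 starting point suggested earlier in the thread becomes a cutoff of 40. The sketch below assumes each row of output_probabilities.csv holds a FASTA header followed by the number as the last comma-separated field; that layout and the helper names are assumptions, so check them against your actual file.

```python
import csv
import io


def load_probabilities(handle) -> dict:
    """Map FASTA header -> probability in [0, 1].

    ASSUMPTION: the last comma-separated field of each row is the 0-100
    value; any earlier fields are re-joined as the header. Rows whose last
    field is not numeric (e.g. a column-name row) are skipped.
    """
    probs = {}
    for row in csv.reader(handle):
        if len(row) < 2:
            continue  # blank or malformed line
        try:
            percent = float(row[-1])
        except ValueError:
            continue  # not a data row
        header = ",".join(row[:-1]).strip()
        probs[header] = percent / 100.0  # undo the x100 scaling
    return probs


def passing(probs: dict, cutoff: float = 0.4) -> list:
    """Headers whose average probability is above `cutoff` (sorted)."""
    return sorted(h for h, p in probs.items() if p > cutoff)


if __name__ == "__main__":
    example = io.StringIO("id,prob\nprotA,72\nprotB,35\nprotC,40\n")
    probs = load_probabilities(example)
    print(passing(probs))  # ['protA']
```

A protein at exactly 40 is dropped here, matching the "filter out <=0.4" wording.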

I think the most important point for all of your considerations above is this: if your use case is the one outlined above (using ProstT5 3Di predictions as Foldseek input), the approach still works remarkably well even if the 3Di predictions are not perfect. The reason is that mistakes usually happen between tokens that encode similar structural conformations (just imagine three tokens encoding alpha-helices, and our predictor failing to distinguish between them). As a consequence, Foldseek still returns the same or very similar hits even if the predicted 3Di does not perfectly match the ground truth.

rakeshr10 commented 1 day ago

Can you also comment on this error? I get it when I run the script on my FASTA file, and for many files I cannot obtain the output probabilities. If I switch off writing the probabilities, I can still output 3Di FASTA files for the same inputs. Is it because the model cannot predict probabilities for some sequences?

/opt/conda/lib/python3.10/site-packages/numpy/_core/fromnumeric.py:3596: RuntimeWarning: Mean of empty slice.
  return _methods._mean(a, axis=axis, dtype=dtype,
/opt/conda/lib/python3.10/site-packages/numpy/_core/_methods.py:136: RuntimeWarning: invalid value encountered in divide
  ret = arr.dtype.type(ret / rcount)
Traceback (most recent call last):
  File "/app/ProstT5/scripts/predict_3Di_encoderOnly.py", line 389, in <module>
    main()
  File "/app/ProstT5/scripts/predict_3Di_encoderOnly.py", line 377, in main
    get_embeddings(
  File "/app/ProstT5/scripts/predict_3Di_encoderOnly.py", line 287, in get_embeddings
    prob = int( 100* np.mean(probabilities[batch_idx, :, 0:s_len]))
ValueError: cannot convert float NaN to integer

We have 'X's in many of our sequences due to sequencing errors, and in some of them the fraction will also exceed 5%, so I will need to see how this can be handled. The other question is which other wildcard characters the model can handle. For example, in our sequences we tend to add to the last aa. Does the need to be removed before giving it to the model?
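The figures mentioned earlier in the thread (non-standard residues rarely exceeding 5% of a sequence, roughly 400 residues as the upper training length, and about 1k residues as a rough comfort limit) can be turned into a quick pre-screen, so that sequences outside that range are at least flagged for a closer look at their predictions. This sketch is hypothetical: the function name and report type are illustrative, it only counts 'X' (the thread does not say which other wildcard characters the model accepts), and it does not decide how flagged sequences should be handled.

```python
from dataclasses import dataclass, field


@dataclass
class SeqReport:
    length: int
    x_fraction: float
    notes: list = field(default_factory=list)


def check_sequence(seq: str, max_x_fraction: float = 0.05,
                   train_len: int = 400, soft_max_len: int = 1000) -> SeqReport:
    """Flag sequences outside the ranges discussed in the thread (hypothetical)."""
    seq = seq.strip().upper()
    length = len(seq)
    x_fraction = seq.count("X") / length if length else 0.0
    report = SeqReport(length, x_fraction)
    if length == 0:
        report.notes.append("empty sequence: nothing to average over")
    if x_fraction > max_x_fraction:
        report.notes.append(f"{x_fraction:.0%} X, above {max_x_fraction:.0%}")
    if length > soft_max_len:
        report.notes.append(f"longer than {soft_max_len} residues")
    elif length > train_len:
        report.notes.append(f"longer than ~{train_len} residues (training range)")
    return report


if __name__ == "__main__":
    for name, seq in [("ok", "ACDE" * 10), ("many_x", "MKV" + "X" * 2 + "A" * 15),
                      ("long", "A" * 450), ("empty", "")]:
        print(name, check_sequence(seq).notes)
```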