MattMcPartlon / AttnPacker

Code and Pre-Trained Models for "AttnPacker: An end-to-end deep learning method for protein side-chain packing"

how to generate results in the paper #5

Closed: linbc-73 closed this issue 1 year ago

linbc-73 commented 1 year ago

Hi, I came across a few problems with generating results mentioned in the paper.

  1. I can't find AttnPacker results for several proteins (T0955, T0980s2, T1016_A, T1019s1, T1062, T1073) in the data at https://zenodo.org/records/7713779/files/scp_results.zip, even though these protein IDs are mentioned in supporting information M.
  2. I don't know how to generate the RMSD of surface/core residues or the χ1-4 accuracy reported in the paper. They are not included in the example code in readme.md (https://github.com/MattMcPartlon/AttnPacker#compare-side-chain-prediction-with-native-structure), and I couldn't find any related code either.

Please help me out! Thanks!

MattMcPartlon commented 1 year ago

Thanks for reaching out!

The residue-level statistics are returned by the assess_sidechains function (shown in the example you reference):

import pprint

# assess_sidechains is imported as in the README example
res_level_stats = assess_sidechains(target_pdb, predicted_pdb, steric_tol_fracs=[0.8])
print(pprint.pformat(res_level_stats))

The centrality tensor (part of the res_level_stats dict) is used to determine whether a residue belongs to the surface or the core (as defined in the paper).
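As a rough illustration of the surface/core split, the sketch below buckets residues by a centrality value and averages a per-residue RMSD within each bucket. It uses plain Python lists in place of tensors (a tensor can be converted with `.tolist()`). The thresholds `CORE_MIN` and `SURFACE_MAX`, the function name `mean_by_region`, and the sample values are assumptions made for illustration only; check the paper for the exact definition and inspect the returned dict for the exact key names.

```python
CORE_MIN = 20      # assumed threshold: centrality >= 20 counts as core
SURFACE_MAX = 15   # assumed threshold: centrality <= 15 counts as surface


def mean_by_region(centrality, rmsd):
    """Return (core_mean, surface_mean) of rmsd; None for an empty region."""
    core = [r for c, r in zip(centrality, rmsd) if c >= CORE_MIN]
    surface = [r for c, r in zip(centrality, rmsd) if c <= SURFACE_MAX]

    def mean(values):
        return sum(values) / len(values) if values else None

    return mean(core), mean(surface)


# Made-up per-residue values, one entry per residue.
centrality = [25, 12, 30, 8, 18]
rmsd = [0.2, 0.9, 0.4, 1.1, 0.6]

core_mean, surface_mean = mean_by_region(centrality, rmsd)
print("core RMSD:", core_mean)
print("surface RMSD:", surface_mean)
```

Residues that fall between the two thresholds (the fifth residue here) are counted in neither bucket under these assumed cutoffs.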

\chi_1,...,\chi_4 accuracy is defined as the fraction of dihedral (\chi_i) predictions with MAE less than 20 degrees (see res_level_stats['dihedral']['mae']). \chi_{1-4} accuracy is the fraction of residues with all chi dihedrals having MAE < 20 degrees. For example, you can compute these accuracies from the per-residue dihedral MAEs as follows.

import math
import torch

# convert the per-residue dihedral MAE from radians to degrees
mae_degrees = torch.abs(res_level_stats["dihedral"]["mae"])*(180/math.pi)
mask = res_level_stats["dihedral"]["mask"]
# number of masked-in dihedrals per chi_i column (at least 1, to avoid dividing by zero)
num_dihedrals = torch.clamp_min(torch.sum(mask,dim=0),1)

dihedrals_less_than_20 = torch.logical_and(mae_degrees <= 20, mask)
# chi_{i} accuracy
print("Chi_{1}...Chi_{4}")
print(torch.sum(dihedrals_less_than_20, dim=0)/num_dihedrals)
# chi_{1-4} accuracy
# only residues with at least one masked-in dihedral are counted
has_dihedrals = torch.any(mask,dim=-1)
# a residue passes when every masked-in dihedral is within 20 degrees
all_less_than_20 = torch.all(dihedrals_less_than_20==mask,dim=-1)
all_less_than_20 = torch.logical_and(has_dihedrals,all_less_than_20)
print("Chi_{1-4}")
print(torch.sum(all_less_than_20)/torch.clamp_min(torch.sum(has_dihedrals),1))

Output:

Chi_{1}...Chi_{4}
tensor([0.653, 0.614, 0.167, 0.286])
Chi_{1-4}
tensor(0.520)

As for the missing PDB files, these targets should have been removed from the supplement: they were all too large, and we ran out of local memory running AF2 and other methods on them. Feel free to run inference on them and include them in the analysis. You can use chunked inference with AttnPacker if memory is an issue.

linbc-73 commented 1 year ago

Thanks for your help!

Yet I still can't reproduce the results in the paper. I used the PDB files from https://zenodo.org/records/7713779/files/scp_results.zip, altogether 78 PDBs for CASP13 regular and 62 for CASP14 regular. As mentioned above, I excluded the missing proteins (T0955, T0980s2, T1016_A, T1019s1, T1062, T1073). The results I got were:

Results for the regular targets (Native), one column per dataset and method:

| Metric | CASP13 AttnPacker_PP | CASP14 AttnPacker_PP | CASP13 AttnPacker_no_pp | CASP14 AttnPacker_no_pp |
|---|---|---|---|---|
| all rmsd | 0.651 | 0.810 | 0.608 | 0.779 |
| core rmsd | 0.324 | 0.381 | 0.306 | 0.361 |
| sfc rmsd | 0.702 | 0.881 | 0.652 | 0.849 |
| chi-1 mae | 18.038 | 24.275 | 16.841 | 23.588 |
| chi-2 mae | 24.142 | 28.947 | 22.793 | 28.264 |
| chi-3 mae | 46.132 | 48.576 | 43.823 | 47.669 |
| chi-4 mae | 58.052 | 54.628 | 53.283 | 50.688 |
| all chi acc | 0.615 | 0.524 | 0.627 | 0.522 |
| core chi acc | 0.785 | 0.723 | 0.800 | 0.728 |
| sfc chi acc | 0.583 | 0.494 | 0.593 | 0.493 |
| steric clash 100% | 88.5 | 88.1 | 74.9 | 75.4 |
| steric clash 90% | 23.6 | 25.5 | 16.6 | 19.5 |
| steric clash 80% | 8.4 | 10.6 | 5.3 | 6.6 |

The steric clash results differ significantly from those in the paper.

res_level_stats = assess_sidechains(target_pdb, predicted_pdb, steric_tol_fracs = [1,0.9,0.8])
target_level_stats = summarize(res_level_stats)

I directly used the average of target_level_stats["clash_info"]["100/90/80"]["num_clashes"] to generate my results. I am not sure where it went wrong.

MattMcPartlon commented 1 year ago

Well, at least the RMSD results you're getting are better than what I reported 😅. Are you including the RMSD of GLY or ALA? I computed the average without accounting for these residues.
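To see how much GLY and ALA can shift the average, here is a minimal sketch that computes a mean per-residue RMSD with and without those residue types. The helper name `mean_rmsd_excluding`, the list-based inputs, and the sample values are all hypothetical; how residue names are obtained from the actual stats is not shown here.

```python
EXCLUDED = {"GLY", "ALA"}  # residue types left out of the average


def mean_rmsd_excluding(res_names, rmsds, excluded=EXCLUDED):
    """Mean of rmsds over residues whose three-letter name is not excluded."""
    kept = [r for name, r in zip(res_names, rmsds) if name not in excluded]
    return sum(kept) / len(kept) if kept else float("nan")


# Made-up values: the GLY and ALA entries contribute near-zero RMSD.
names = ["GLY", "LEU", "ALA", "ARG"]
rmsds = [0.0, 0.5, 0.0, 1.5]

mean_all = sum(rmsds) / len(rmsds)
mean_filtered = mean_rmsd_excluding(names, rmsds)
print("including GLY/ALA:", mean_all)
print("excluding GLY/ALA:", mean_filtered)
```

With these sample values, including the two residues halves the average, which is one way a lower RMSD than the reported one could arise.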

There are two ways to compute the metrics from the paper: you can average over all residues in a test set, or you can average over all targets. It is most common to average over all residues in the test set, and I believe this is the definition I used in the paper (I wrote this 2 years ago, so my memory is a bit fuzzy). For example, if you have two targets A and B, with 100 \chi_1 dihedrals in A and 200 \chi_1 dihedrals in B, then the mean \chi_1 error would be

(100/300) × (mean \chi_1 error in A) + (200/300) × (mean \chi_1 error in B), i.e. it is weighted by the number of residues with the given dihedral in each chain.
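The two averaging conventions described above can be compared with a small sketch. Each target is represented as a (number of \chi_1 dihedrals, mean \chi_1 error) pair; the function names and the error values of 30.0 and 15.0 degrees are made up for illustration, while the counts of 100 and 200 follow the example.

```python
def per_residue_mean(targets):
    """Pool all dihedrals: each target is weighted by its dihedral count."""
    total_error = sum(count * mean_err for count, mean_err in targets)
    total_count = sum(count for count, _ in targets)
    return total_error / total_count


def per_target_mean(targets):
    """Average the per-target means, giving every target equal weight."""
    return sum(mean_err for _, mean_err in targets) / len(targets)


# (number of chi_1 dihedrals, mean chi_1 error in degrees) for targets A and B
targets = [(100, 30.0), (200, 15.0)]

pooled = per_residue_mean(targets)    # (100/300)*30.0 + (200/300)*15.0
unweighted = per_target_mean(targets)  # (30.0 + 15.0) / 2
print("averaged over residues:", pooled)
print("averaged over targets:", unweighted)
```

The two numbers differ whenever targets have different sizes and different errors, so it is worth checking which convention a reported value uses before comparing against it.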

The steric clash results reported in the paper were computed with different code from what is now in the code base. The paper's calculation did not penalize atoms in residues with side-chain symmetries (e.g. Tyr, Arg), so it did not sufficiently penalize clashes for certain side-chain atoms. The clash computation now in the code base is correct and should be used for all future work, and the number of clashes should be recomputed with it. I apologize for this. In addition, the post-processing of the AttnPacker PDBs should be redone with the updated clash loss. To do this, you can just run the post-process function on the AttnPacker_no_pp PDBs.

linbc-73 commented 1 year ago

Thanks for the instructions!