lucidrains / alphafold3-pytorch

Implementation of Alphafold 3 in Pytorch
MIT License

Incorrect tracking of each token's center atom? #81

Closed: amorehead closed this issue 3 months ago

amorehead commented 3 months ago

In the line of code linked below, it seems that the index of the atom each residue uses for distogram prediction is also being used as that residue's representative (i.e., token) center atom. However, the AF3 supplement distinguishes between the representative atoms used for distogram prediction and those used in the rest of the architecture. Specifically, it uses the following atoms for distogram prediction: Cβ for protein residues (Cα for glycine), C4 for purines, and C2 for pyrimidines.

For the general tokenization of atoms in the rest of the architecture, the AF3 supplement describes each residue's token center atom as follows: "For each token we also designate a token centre atom, used in various places below:
• Cα for standard amino acids
• C1′ for standard nucleotides
• For other cases take the first and only atom as they are tokenized per-atom."

https://github.com/lucidrains/alphafold3-pytorch/blob/84c73a1b782010ae43f1c543c77c87576acb2124/alphafold3_pytorch/inputs.py#L551
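The two per-token atom choices quoted above can be expressed as a small lookup. This is a minimal sketch, not code from alphafold3-pytorch: the helper names token_center_atom and distogram_atom are hypothetical, residue names are assumed to be PDB-style codes (three-letter amino acids, A/G/C/U for RNA, DA/DG/DC/DT for DNA), and non-standard tokens are assumed to be single-atom tokens that use that one atom for both roles.

```python
# Hypothetical helpers illustrating the two per-token atom choices described
# in the AF3 supplement (not the alphafold3-pytorch API).

STANDARD_AMINO_ACIDS = {
    "ALA", "ARG", "ASN", "ASP", "CYS", "GLN", "GLU", "GLY", "HIS", "ILE",
    "LEU", "LYS", "MET", "PHE", "PRO", "SER", "THR", "TRP", "TYR", "VAL",
}
PURINES = {"A", "G", "DA", "DG"}        # RNA and DNA purines
PYRIMIDINES = {"C", "U", "DC", "DT"}    # RNA and DNA pyrimidines


def token_center_atom(res_name, atom_names):
    """Token centre atom: CA for amino acids, C1' for nucleotides,
    otherwise the single atom of a per-atom token."""
    if res_name in STANDARD_AMINO_ACIDS:
        return "CA"
    if res_name in PURINES or res_name in PYRIMIDINES:
        return "C1'"
    # assumption: any other token is tokenized per atom, so it has one atom
    assert len(atom_names) == 1, "per-atom tokens should contain one atom"
    return atom_names[0]


def distogram_atom(res_name, atom_names):
    """Distogram atom: CB (CA for glycine), C4 for purines, C2 for pyrimidines."""
    if res_name == "GLY":
        return "CA"
    if res_name in STANDARD_AMINO_ACIDS:
        return "CB"
    if res_name in PURINES:
        return "C4"
    if res_name in PYRIMIDINES:
        return "C2"
    # assumption: per-atom tokens reuse their single atom
    assert len(atom_names) == 1, "per-atom tokens should contain one atom"
    return atom_names[0]
```

For example, a serine token gets "CA" as its center atom but "CB" as its distogram atom, while a glycine token gets "CA" for both.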

lucidrains commented 3 months ago

yeah I noticed that, but don't know how or where the "token centre atom" is used

lucidrains commented 3 months ago

if you can find and screenshot the section of the paper where they use this rep atom token, I'll get it in there

amorehead commented 3 months ago

For distogram prediction (and only for distogram prediction, I believe), here's the set of atoms they used for each type of residue: [screenshot not preserved]

amorehead commented 3 months ago

Otherwise, I believe AF3 uses the following atoms as token center atoms elsewhere in the model. For example, these are also the token center atoms the PDB filtering script uses: [screenshot not preserved]

lucidrains commented 3 months ago

yea sounds good, I can just start keeping track of distogram vs center atom separately, but afaict the center atom information is not used anywhere in the arch, so just renaming things to the distogram atom would suffice
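Keeping the distogram atom and the center atom separate could amount to storing two index arrays into the flattened atom tensor. The sketch below assumes atoms are flattened token by token, in order. atom_indices, center_pick and disto_pick are hypothetical names, and the picks are protein-only simplifications. The atom names and their order mirror the MET/ALA/SER residues of the mmCIF excerpt later in this thread, plus an added glycine.

```python
from itertools import accumulate


def atom_indices(tokens, pick):
    """For each token, index of the picked atom in the flattened atom array.

    tokens: list of (res_name, atom_names), in the order atoms are flattened.
    pick:   callable (res_name, atom_names) -> name of the atom to select.
    Raises ValueError if a token lacks the requested atom.
    """
    # start offset of each token in the flattened array: 0, n0, n0 + n1, ...
    starts = [0, *accumulate(len(names) for _, names in tokens)][:-1]
    return [start + names.index(pick(res, names))
            for start, (res, names) in zip(starts, tokens)]


# Protein-only picks for this example (hypothetical simplifications)
def center_pick(res, names):
    return "CA"


def disto_pick(res, names):
    return "CA" if res == "GLY" else "CB"


tokens = [
    ("MET", ["N", "CA", "C", "O", "CB", "CG", "SD", "CE"]),
    ("ALA", ["N", "CA", "C", "O", "CB"]),
    ("SER", ["N", "CA", "C", "O", "CB", "OG"]),
    ("GLY", ["N", "CA", "C", "O"]),
]
center_idx = atom_indices(tokens, center_pick)  # [1, 9, 14, 20]
disto_idx = atom_indices(tokens, disto_pick)    # [4, 12, 17, 20]
```

For glycine both arrays point at the same atom (Cα), while every other residue gets two different indices.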

amorehead commented 3 months ago

I see

lucidrains commented 3 months ago

could be wrong!

amorehead commented 3 months ago

Yeah, they might have used these token center atom definitions only for filtering. Perhaps someone else can chime in with a different take, in which case it might make sense to leave this issue open for the time being.

lucidrains commented 3 months ago

@amorehead i'll pass it into Alphafold3 just in case we missed something

yes, agreed with keeping the issue open for visibility

lucidrains commented 3 months ago

@amorehead could the token center atoms be used here? (plucking out the atom positions being used for the confidence head)
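Plucking out per-token positions reduces to a gather over the flattened coordinates, followed by a token-by-token distance matrix. This is a hedged sketch: gather_positions and pairwise_distances are hypothetical helpers, and which index set the confidence head should use is the open question in this thread. The coordinates are the MET 1 and ALA 2 atoms from the AF3 server mmCIF excerpt posted below.

```python
import math


def gather_positions(atom_pos, indices):
    """Pick one (x, y, z) per token from the flattened atom coordinates."""
    return [atom_pos[i] for i in indices]


def pairwise_distances(points):
    """Dense token-by-token Euclidean distance matrix (Angstroms)."""
    return [[math.dist(p, q) for q in points] for p in points]


# MET 1 (indices 0-7) and ALA 2 (indices 8-12) from the mmCIF excerpt
atom_pos = [
    (2.673, -27.151, -5.777), (3.790, -26.285, -6.177),   # N, CA
    (4.226, -26.537, -7.603), (4.130, -27.665, -8.091),   # C, O
    (4.982, -26.497, -5.234), (4.790, -25.949, -3.852),   # CB, CG
    (6.006, -26.553, -2.678), (5.416, -28.225, -2.406),   # SD, CE
    (4.677, -25.507, -8.270), (5.152, -25.595, -9.643),   # N, CA
    (6.178, -24.501, -9.904), (6.376, -23.609, -9.071),   # C, O
    (3.993, -25.487, -10.617),                            # CB
]
center_idx = [1, 9]  # CA of MET 1 and ALA 2 (token center atoms)
disto_idx = [4, 12]  # CB of MET 1 and ALA 2 (distogram atoms)

d_center = pairwise_distances(gather_positions(atom_pos, center_idx))
d_disto = pairwise_distances(gather_positions(atom_pos, disto_idx))
```

The Cα-Cα distance comes out at about 3.79 Å and the Cβ-Cβ distance at about 5.57 Å, so the choice of representative atom changes any distance features derived from these positions.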

lucidrains commented 3 months ago

@gitabtion maybe Heng has an opinion too, as i noticed he left a thumbs-up emoji (and is probably doing some training already)

amorehead commented 3 months ago

@lucidrains, perhaps. But on second thought, from my experience using the AF3 web server, I can say that (unlike AlphaFold 2) it now produces per-atom confidence values (e.g., pLDDT) instead of per-token/per-residue values. In the confidence head code you linked, it looks like only the token center atoms are being used.

lucidrains commented 3 months ago

@amorehead oh i see, is it only pLDDT? or everything?

lucidrains commented 3 months ago

yes, i assumed it was residue / token level, given it accepts single and pairwise rep

amorehead commented 3 months ago

Yeah, in the following coordinates from an AF3 server mmCIF file, you can see that the values in the 4th-to-last column (_atom_site.B_iso_or_equiv, in the range [0, 100]) vary per atom instead of per residue:

#
loop_
_atom_site.group_PDB
_atom_site.id
_atom_site.type_symbol
_atom_site.label_atom_id
_atom_site.label_alt_id
_atom_site.label_comp_id
_atom_site.label_asym_id
_atom_site.label_entity_id
_atom_site.label_seq_id
_atom_site.pdbx_PDB_ins_code
_atom_site.Cartn_x
_atom_site.Cartn_y
_atom_site.Cartn_z
_atom_site.occupancy
_atom_site.B_iso_or_equiv
_atom_site.auth_seq_id
_atom_site.auth_asym_id
_atom_site.pdbx_PDB_model_num
ATOM   1    N  N     . MET A 1 1   ? 2.673   -27.151 -5.777  1.00 39.54 1   A 1 
ATOM   2    C  CA    . MET A 1 1   ? 3.790   -26.285 -6.177  1.00 45.41 1   A 1 
ATOM   3    C  C     . MET A 1 1   ? 4.226   -26.537 -7.603  1.00 48.65 1   A 1 
ATOM   4    O  O     . MET A 1 1   ? 4.130   -27.665 -8.091  1.00 48.01 1   A 1 
ATOM   5    C  CB    . MET A 1 1   ? 4.982   -26.497 -5.234  1.00 44.59 1   A 1 
ATOM   6    C  CG    . MET A 1 1   ? 4.790   -25.949 -3.852  1.00 41.20 1   A 1 
ATOM   7    S  SD    . MET A 1 1   ? 6.006   -26.553 -2.678  1.00 35.10 1   A 1 
ATOM   8    C  CE    . MET A 1 1   ? 5.416   -28.225 -2.406  1.00 30.79 1   A 1 
ATOM   9    N  N     . ALA A 1 2   ? 4.677   -25.507 -8.270  1.00 42.67 2   A 1 
ATOM   10   C  CA    . ALA A 1 2   ? 5.152   -25.595 -9.643  1.00 44.96 2   A 1 
ATOM   11   C  C     . ALA A 1 2   ? 6.178   -24.501 -9.904  1.00 47.97 2   A 1 
ATOM   12   O  O     . ALA A 1 2   ? 6.376   -23.609 -9.071  1.00 45.38 2   A 1 
ATOM   13   C  CB    . ALA A 1 2   ? 3.993   -25.487 -10.617 1.00 42.21 2   A 1 
ATOM   14   N  N     . SER A 1 3   ? 6.828   -24.573 -11.064 1.00 43.81 3   A 1 
ATOM   15   C  CA    . SER A 1 3   ? 7.796   -23.549 -11.428 1.00 48.64 3   A 1 
ATOM   16   C  C     . SER A 1 3   ? 7.085   -22.208 -11.611 1.00 51.66 3   A 1 
ATOM   17   O  O     . SER A 1 3   ? 5.988   -22.143 -12.171 1.00 48.75 3   A 1 
ATOM   18   C  CB    . SER A 1 3   ? 8.513   -23.935 -12.716 1.00 44.42 3   A 1 
ATOM   19   O  OG    . SER A 1 3   ? 7.599   -24.072 -13.780 1.00 39.58 3   A 1 
ATOM   20   N  N     . SER A 1 4   ? 7.720   -21.146 -11.141 1.00 60.43 4   A 1 
ATOM   21   C  CA    . SER A 1 4   ? 7.145   -19.814 -11.181 1.00 62.17 4   A 1 
ATOM   22   C  C     . SER A 1 4   ? 7.623   -19.084 -12.426 1.00 63.89 4   A 1 
ATOM   23   O  O     . SER A 1 4   ? 8.803   -18.734 -12.532 1.00 59.08 4   A 1 
ATOM   24   C  CB    . SER A 1 4   ? 7.507   -19.032 -9.924  1.00 56.36 4   A 1 
ATOM   25   O  OG    . SER A 1 4   ? 6.834   -17.791 -9.881  1.00 51.09 4   A 1 
ATOM   26   N  N     . ARG A 1 5   ? 6.721   -18.846 -13.370 1.00 71.89 5   A 1 
ATOM   27   C  CA    . ARG A 1 5   ? 7.048   -18.153 -14.599 1.00 76.43 5   A 1 
ATOM   28   C  C     . ARG A 1 5   ? 6.761   -16.665 -14.462 1.00 79.44 5   A 1 
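To check the per-atom claim programmatically, one can group the B_iso_or_equiv column of the _atom_site loop by residue. The reader below is a deliberately minimal, hypothetical helper that splits rows on whitespace and would break on quoted values containing spaces; a real mmCIF parser such as gemmi or Biopython is the safer choice for arbitrary files. The embedded excerpt is the MET 1 and ALA 2 rows shown above.

```python
from collections import defaultdict

# MET 1 and ALA 2 rows from the AF3 server mmCIF excerpt above
MMCIF_EXCERPT = """\
loop_
_atom_site.group_PDB
_atom_site.id
_atom_site.type_symbol
_atom_site.label_atom_id
_atom_site.label_alt_id
_atom_site.label_comp_id
_atom_site.label_asym_id
_atom_site.label_entity_id
_atom_site.label_seq_id
_atom_site.pdbx_PDB_ins_code
_atom_site.Cartn_x
_atom_site.Cartn_y
_atom_site.Cartn_z
_atom_site.occupancy
_atom_site.B_iso_or_equiv
_atom_site.auth_seq_id
_atom_site.auth_asym_id
_atom_site.pdbx_PDB_model_num
ATOM   1    N  N     . MET A 1 1   ? 2.673   -27.151 -5.777  1.00 39.54 1   A 1
ATOM   2    C  CA    . MET A 1 1   ? 3.790   -26.285 -6.177  1.00 45.41 1   A 1
ATOM   3    C  C     . MET A 1 1   ? 4.226   -26.537 -7.603  1.00 48.65 1   A 1
ATOM   4    O  O     . MET A 1 1   ? 4.130   -27.665 -8.091  1.00 48.01 1   A 1
ATOM   5    C  CB    . MET A 1 1   ? 4.982   -26.497 -5.234  1.00 44.59 1   A 1
ATOM   6    C  CG    . MET A 1 1   ? 4.790   -25.949 -3.852  1.00 41.20 1   A 1
ATOM   7    S  SD    . MET A 1 1   ? 6.006   -26.553 -2.678  1.00 35.10 1   A 1
ATOM   8    C  CE    . MET A 1 1   ? 5.416   -28.225 -2.406  1.00 30.79 1   A 1
ATOM   9    N  N     . ALA A 1 2   ? 4.677   -25.507 -8.270  1.00 42.67 2   A 1
ATOM   10   C  CA    . ALA A 1 2   ? 5.152   -25.595 -9.643  1.00 44.96 2   A 1
ATOM   11   C  C     . ALA A 1 2   ? 6.178   -24.501 -9.904  1.00 47.97 2   A 1
ATOM   12   O  O     . ALA A 1 2   ? 6.376   -23.609 -9.071  1.00 45.38 2   A 1
ATOM   13   C  CB    . ALA A 1 2   ? 3.993   -25.487 -10.617 1.00 42.21 2   A 1
"""


def parse_atom_site(text):
    """Minimal _atom_site loop reader (whitespace split, no quoted values)."""
    columns, rows = [], []
    for line in text.splitlines():
        line = line.strip()
        if line.startswith("_atom_site."):
            columns.append(line.split(".", 1)[1])
        elif line.startswith(("ATOM", "HETATM")):
            rows.append(dict(zip(columns, line.split())))
    return rows


def values_per_residue(rows, field="B_iso_or_equiv"):
    """Collect one column's values for every atom, keyed by residue."""
    grouped = defaultdict(list)
    for row in rows:
        key = (row["label_asym_id"], int(row["label_seq_id"]), row["label_comp_id"])
        grouped[key].append(float(row[field]))
    return dict(grouped)


for (chain, seq_id, res), vals in values_per_residue(parse_atom_site(MMCIF_EXCERPT)).items():
    print(chain, seq_id, res, min(vals), max(vals))
# prints:
# A 1 MET 30.79 48.65
# A 2 ALA 42.21 47.97
```

A per-residue confidence would give min == max for each residue; here the spread within a single residue is up to about 18 points.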
lucidrains commented 3 months ago

@amorehead that's very strange because the confidence head contains a Pairformer stack, and that will blow up (attention-wise) if done at atom resolution

lucidrains commented 3 months ago

they could expand the output of that to atom resolution and append some per-atom / per-residue embedding before projecting into logits, though. that's one possibility

lucidrains commented 3 months ago

maybe i'm missing something

amorehead commented 3 months ago

I suspect that's what they do: attention over tokens, then appending atom embeddings before the projection. I don't know how else they would efficiently predict per-atom confidence.
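The idea of expanding token outputs and appending per-atom embeddings before projecting can be sketched without any attention at atom resolution. Token-level features are broadcast to atoms through an atom-to-token index, concatenated with per-atom features, and projected to pLDDT bin logits. This illustrates the idea under stated assumptions and is not AF3's published architecture. The helper names, feature sizes, bin count, and the expected-value readout over bin centres are all hypothetical, and plain Python lists stand in for tensors.

```python
import math


def expand_to_atoms(token_feats, atom_to_token):
    """Broadcast each token's feature vector to every atom it contains."""
    return [token_feats[t] for t in atom_to_token]


def linear(x, weight, bias):
    """Dense layer: weight has one row of len(x) values per output."""
    return [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weight, bias)]


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def per_atom_plddt(token_feats, atom_feats, atom_to_token, weight, bias):
    """Expand token features to atoms, concatenate atom features, project
    to bin logits, and return the expected pLDDT (0-100) for every atom."""
    num_bins = len(bias)
    centers = [(i + 0.5) * 100.0 / num_bins for i in range(num_bins)]
    plddt = []
    for tok, atom in zip(expand_to_atoms(token_feats, atom_to_token), atom_feats):
        probs = softmax(linear(tok + atom, weight, bias))  # list concat = feature concat
        plddt.append(sum(p * c for p, c in zip(probs, centers)))
    return plddt


token_feats = [[0.5], [-0.5]]          # 2 tokens, 1 feature each (toy values)
atom_feats = [[0.0], [1.0], [0.0]]     # 3 atoms, 1 feature each (toy values)
atom_to_token = [0, 0, 1]              # atoms 0 and 1 belong to token 0
weight = [[1.0, 1.0], [-1.0, -1.0]]    # 2 bins x (1 token + 1 atom feature)
bias = [0.0, 0.0]

plddt = per_atom_plddt(token_feats, atom_feats, atom_to_token, weight, bias)
print(plddt)
```

Atoms 0 and 1 share a token but receive different values because their atom features differ. Only the final projection runs per atom, so the extra cost beyond the token-level Pairformer is linear in the number of atoms.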

lucidrains commented 3 months ago

they could have made the confidence head have an atom decoder as well, and left that out of the paper

lucidrains commented 3 months ago

let me think about it

lucidrains commented 3 months ago

this is technically addressed; we can open a new issue for the atom-resolution confidence head