evolutionaryscale / esm

Cannot reproduce results on CAMEO (194 targets) and CASP14 (71 targets) #63

Open zhujianwei31415 opened 4 months ago

zhujianwei31415 commented 4 months ago

Hi Team,

Thank you for the excellent work on ESM3.

I am currently comparing the performance of ESM3 on the CAMEO and CASP14 datasets with ESMFold and AlphaFold2. The CAMEO dataset I am using is the same one used for ESMFold. The code I am using is listed below. I have calculated the "TM-score" and "LDDT". However, I have noticed that the LDDT scores do not match the results published in the ESM3 paper.

Could you please share the evaluation parameters and scripts used to reproduce the results in the ESM3 paper?

Thank you for your assistance.

from esm.sdk.api import ESMProtein, GenerationConfig

# Assumes `client` (an ESM3 model or API client), `pdbs` (output paths)
# and `seqs` (input sequences) are defined elsewhere.
for outpdb, sequence in zip(pdbs, seqs):
    # Folding: start from the sequence only
    protein = ESMProtein(sequence=sequence)
    sequence_length = len(protein.sequence)  # type: ignore
    # num_steps = sequence_length  # Iterative
    # num_steps = int(sequence_length / 16)
    num_steps = 1  # Argmax
    protein.coordinates = None
    protein.function_annotations = None
    protein.sasa = None
    folded_protein = client.generate(
        protein,
        GenerationConfig(track="structure", schedule="cosine", num_steps=num_steps),
    )
    assert isinstance(folded_protein, ESMProtein)
    folded_protein.to_pdb(outpdb)

[Screenshot 2024-07-25 151359]

santiag0m commented 4 months ago

We released a recent update improving iterative decoding for structure prediction: https://github.com/evolutionaryscale/esm/releases/tag/v3.0.0.post2

Make sure you are using the latest version. It is also important to set temperature to 0 to achieve the best results, since temperature is an important source of variability.
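
As a generic illustration of why temperature is a source of variability (this is not ESM3's internal sampler, and `sample_token` is a hypothetical helper), temperature 0 reduces sampling to a deterministic argmax, while any positive temperature draws from a softmax and can change from run to run:

```python
import math
import random


def sample_token(logits: list[float], temperature: float, rng: random.Random) -> int:
    """Pick a token index from `logits` at the given sampling temperature."""
    if temperature == 0:
        # Temperature 0 degenerates to greedy (argmax) decoding: fully deterministic.
        return max(range(len(logits)), key=lambda i: logits[i])
    # Otherwise sample from softmax(logits / temperature); subtract the max for stability.
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    weights = [math.exp(x - m) for x in scaled]
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


logits = [2.0, 1.5, 0.1]
rng = random.Random(0)
greedy = {sample_token(logits, 0.0, rng) for _ in range(100)}
sampled = {sample_token(logits, 1.0, rng) for _ in range(100)}
print(greedy)   # {0}: the same index every time
print(sampled)  # typically several different indices
```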

ebetica commented 3 months ago

The set of CAMEO targets we use is the set released from May 1, 2020 through Aug 1, 2023. That said, I think you'll find much better results with the updated version of esm3. We currently have not released the all atom structure decoder used to benchmark our models with either (will be available via the API initially).
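
A minimal sketch for selecting targets by release date, assuming each target's release date is already known and that both ends of the window are inclusive; the target IDs and dates below are hypothetical:

```python
from datetime import date

# Window quoted above; treating both ends as inclusive is an assumption.
CAMEO_START = date(2020, 5, 1)
CAMEO_END = date(2023, 8, 1)


def in_cameo_window(release: date, start: date = CAMEO_START, end: date = CAMEO_END) -> bool:
    """Return True if a target's release date falls inside [start, end]."""
    return start <= release <= end


# Hypothetical target IDs and release dates, for illustration only.
targets = {
    "target_a": date(2020, 4, 30),
    "target_b": date(2021, 9, 15),
    "target_c": date(2023, 8, 1),
}
selected = sorted(t for t, d in targets.items() if in_cameo_window(d))
print(selected)  # ['target_b', 'target_c']
```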

zhujianwei31415 commented 3 months ago

Thanks @santiag0m and @ebetica.

I have used the same CAMEO dataset and the latest v3.0.0.post2.

Regarding "We currently have not released the all atom structure decoder used to benchmark our models with either (will be available via the API initially).": do you mean that the model you released on Hugging Face isn't the model you used to evaluate CAMEO, CASP14, and CASP15 in the paper? The results in the paper were predicted with the "all atom structure decoder", right?

ebetica commented 3 months ago

Yes. However, I would say that you should be able to get within 1-3 LDDT points of the paper results using the open-source model and structure decoder. The large LDDT gap you see indicates something is wrong. The fact that iterative and argmax decoding in your local run give the same numbers suggests to me that they were measured with the buggy release...

zhujianwei31415 commented 3 months ago

Thanks @ebetica.

Yes, I agree with you. I have tested two versions, 95e3c5b and f342784, and the results are very similar. The generation code I used was provided earlier. I used the default temperature instead of 0.0, but I believe temperature might not be the main factor.

Could you use the released code to perform the evaluation and share the parameters and scripts with me?

zhujianwei31415 commented 3 months ago

Additionally, I would like to clarify a point. For Argmax decoding, does setting num_steps=1 achieve this? And for Iterative decoding, should I set num_steps=sequence_length?
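
For intuition about the two settings, here is a sketch of how many positions a cosine unmasking schedule decodes per step. It assumes a MaskGIT-style rule in which a fraction cos(pi/2 * t/T) of the positions is still masked after step t of T; ESM3's actual schedule may differ in detail, and `tokens_per_step` is a hypothetical helper. With num_steps=1 everything is decoded in a single pass (argmax), while num_steps equal to the sequence length decodes at most a couple of positions per step:

```python
import math


def tokens_per_step(length: int, num_steps: int) -> list[int]:
    """Number of masked positions decoded at each step under a cosine schedule.

    Assumes a MaskGIT-style rule: after step t of T, floor(length * cos(pi/2 * t/T))
    positions are still masked.
    """
    def still_masked(t: int) -> int:
        return math.floor(length * math.cos(math.pi / 2 * t / num_steps))

    # Positions decoded at step t = (masked before step t) - (masked after step t).
    return [still_masked(t - 1) - still_masked(t) for t in range(1, num_steps + 1)]


length = 100
print(tokens_per_step(length, 1))      # [100]: argmax, everything in one pass
iterative = tokens_per_step(length, length)
print(len(iterative), sum(iterative))  # 100 100
```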

ebetica commented 3 months ago

Here's the way we do this internally:

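from esm.sdk.api import ESMProtein, GenerationConfig
from esm.utils.structure.protein_chain import ProteinChain

# `client`, `argmax` (bool) and `results` (list) are defined elsewhere.
pc = ProteinChain(...)  # native structure
target_id = "..."

gen = client.generate(
    ESMProtein(sequence=pc.sequence),
    GenerationConfig(
        track="structure",
        schedule="cosine",
        num_steps=1 if argmax else len(pc),  # 1 step = argmax, len(pc) steps = iterative
        temperature=0,
    ),
)
assert isinstance(gen, ESMProtein)  # Assert we didn't error out...
assert gen.coordinates is not None
lddt = pc.lddt_ca(gen.to_protein_chain(), per_residue=False)
results.append([target_id, lddt])
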
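
The snippet above appends one [id, lddt] row per target to `results`. A minimal sketch for turning those rows into a dataset-level number, assuming the reported figures (such as the 80.50 mentioned later in this thread) are mean CA-LDDT scaled to 0-100; `summarize` and the scores below are hypothetical:

```python
from statistics import mean


def summarize(results: list[list]) -> tuple[int, float]:
    """Return (number of targets, mean CA-LDDT on a 0-100 scale) for [target_id, lddt] rows."""
    scores = [float(lddt) for _, lddt in results]
    return len(scores), 100 * mean(scores)


# Hypothetical scores for three targets, for illustration only.
results = [["t1", 0.91], ["t2", 0.74], ["t3", 0.82]]
n, mean_lddt = summarize(results)
print(n, round(mean_lddt, 2))  # 3 82.33
```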
zhujianwei31415 commented 3 months ago

@ebetica Thank you very much!

I will generate the structures using this method and re-evaluate the results.

Additionally, could you please share the target lists for CAMEO, CASP14, and CASP15? I would like to double-check the consistency of the targets.

zhujianwei31415 commented 3 months ago

@ebetica @santiag0m Thanks for your help!

I used the latest code, esm-3.0.2, along with the generation code you provided. I successfully reproduced the results for the Argmax decoder (num_steps=1) and the Iterative decoder (num_steps=sequence_length). It is also very important to set temperature=0 to achieve the best results.

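from esm.sdk.api import ESMProtein, GenerationConfig
from esm.utils.structure.protein_chain import ProteinChain

# Assumes `client`, `targets`, `inpdir`, `outdir` and `argmax` (bool) are defined elsewhere.
results = []
for target in targets:
    inppdb = f"{inpdir}/{target}.pdb"
    outpdb = f"{outdir}/{target}.pdb"

    # Folding: predict the structure from the native sequence
    protein = ProteinChain.from_pdb(inppdb)
    gen = client.generate(
        ESMProtein(sequence=protein.sequence),
        GenerationConfig(
            track="structure",
            schedule="cosine",
            num_steps=1 if argmax else len(protein.sequence),
            temperature=0,
        ),
    )
    assert isinstance(gen, ESMProtein)  # Assert we didn't error out...
    assert gen.coordinates is not None
    gen.to_pdb(outpdb)

    # Calculate CA-LDDT (native chain as `self`, prediction as `target`)
    lddt = protein.lddt_ca(gen.to_protein_chain(), per_residue=False)
    results.append([target, lddt])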
[Screenshot 2024-08-11 170108]

However, when I evaluate the generated "*.pdb" files separately, the results from the Iterative decoder are different (CAMEO: 80.50, CASP14: 59.95) and resemble those of the Argmax decoder. I will double-check the evaluation scripts for calculating "LDDT" and the generated structures.

Could you please share the target lists for CAMEO, CASP14, and CASP15? I would like to double-check the consistency of the targets.

zhujianwei31415 commented 3 months ago

@santiag0m @ebetica Thanks for your help!

I used the latest code, esm-3.0.2, along with the structure generation code you provided. I successfully generated protein structures for three datasets (CAMEO, CASP14, and CASP15) using both the Argmax and Iterative decoders.

During the evaluation of the generated structures, I encountered some issues. The performance of the Argmax decoder closely matches the results reported in the paper. However, there is a significant discrepancy with the Iterative decoder's performance. In my local evaluation, the results from the Iterative decoder are nearly identical to those of the Argmax decoder, with very similar TM-score and LDDT values for both algorithms. These results differ significantly from those published in the paper, prompting me to recheck the LDDT calculation using different methods.
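
One way to check whether two decoders really give "nearly identical" results is to compare them target by target rather than only by their means. A minimal sketch, assuming per-target LDDT scores are stored in dicts keyed by target ID (the helper name and numbers are hypothetical):

```python
def paired_gap(scores_a: dict[str, float], scores_b: dict[str, float]) -> tuple[float, float]:
    """Mean and max absolute per-target score difference over shared target IDs."""
    shared = sorted(scores_a.keys() & scores_b.keys())
    if not shared:
        raise ValueError("no targets in common")
    gaps = [abs(scores_b[t] - scores_a[t]) for t in shared]
    return sum(gaps) / len(gaps), max(gaps)


# Hypothetical per-target LDDT values; only t1 and t2 appear in both runs.
argmax_scores = {"t1": 0.80, "t2": 0.65, "t3": 0.90}
iterative_scores = {"t1": 0.81, "t2": 0.65, "t4": 0.70}
mean_gap, max_gap = paired_gap(argmax_scores, iterative_scores)
print(round(mean_gap, 3), round(max_gap, 3))  # 0.005 0.01
```

Restricting to shared IDs also surfaces target-list mismatches, which matters when the two runs may not cover exactly the same targets.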

Please note that the generated structures were identical across all methods and were preprocessed to align with the chain and residue numbers of the native PDB files. Given that LDDT calculation is asymmetric, the choice of reference structure is crucial. I evaluated the structures using three different methods:

I found that the LDDT values calculated with ESM3_LDDT_SYMMETRY matched those of the precompiled program, but they differed from the scores reported in the ESM3 paper. Interestingly, the LDDT values from ESM3_LDDT closely matched those in the paper.

After reviewing the source code for lddt_ca and compute_lddt_ca, I observed that the first parameter in compute_lddt_ca represents the predicted positions, while the second parameter is the ground truth. The lddt_ca function suggests that the target protein chain should be the native structure, which indicates that ESM3_LDDT_SYMMETRY is likely the correct evaluation method, while ESM3_LDDT might not be accurate, despite yielding results similar to those in the ESM3 paper.

def lddt_ca(
    self,
    target: ProteinChain,
    mobile_inds: list[int] | np.ndarray | None = None,
    target_inds: list[int] | np.ndarray | None = None,
    **kwargs,
) -> float | np.ndarray:
    """Compute the LDDT between this protein chain and another.

    Arguments:
        target (ProteinChain): The other protein chain to compare to.
        mobile_inds (list[int], np.ndarray, optional): The indices of the mobile atoms to align. These are NOT residue indices
        target_inds (list[int], np.ndarray, optional): The indices of the target atoms to align. These are NOT residue indices

    Returns:
        float | np.ndarray: The LDDT score between the two protein chains, either
            a single float or per-residue LDDT scores if `per_residue` is True.
    """

    lddt = compute_lddt_ca(
        torch.tensor(self.atom37_positions[mobile_inds]).unsqueeze(0),
        torch.tensor(target.atom37_positions[target_inds]).unsqueeze(0),
        torch.tensor(self.atom37_mask[mobile_inds]).unsqueeze(0),
        **kwargs,
    )
    return float(lddt) if lddt.numel() == 1 else lddt.numpy().flatten()

def compute_lddt_ca(
    all_atom_pred_pos: torch.Tensor,
    all_atom_positions: torch.Tensor,
    all_atom_mask: torch.Tensor,
    cutoff: float = 15.0,
    eps: float = 1e-10,
    per_residue: bool = True,
) -> torch.Tensor:
    ca_pos = RC.atom_order["CA"]
    if all_atom_pred_pos.dim() != 3:
        all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
    all_atom_positions = all_atom_positions[..., ca_pos, :]
    all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)]  # keep dim

    return compute_lddt(
        all_atom_pred_pos,
        all_atom_positions,
        all_atom_mask,
        cutoff=cutoff,
        eps=eps,
        per_residue=per_residue,
    )

def compute_lddt(
    all_atom_pred_pos: torch.Tensor,
    all_atom_positions: torch.Tensor,
    all_atom_mask: torch.Tensor,
    cutoff: float = 15.0,
    eps: float = 1e-10,
    per_residue: bool = True,
) -> torch.Tensor:
    """
    Computes LDDT for a protein. Tensor sizes below include some optional dimensions. Specifically:
        Nstates:
            all_atom_pred_pos can contain multiple states in the first dimension which corresponds to outputs from different layers of a model (e.g. each IPA block). The return size will be [Nstates x Batch size] if this is included.
        Natoms:
            LDDT can be computed for all atoms or some atoms. The second to last dimension should contain the *FLATTENED* representation of L x Natoms. If you want to calculate for atom37, e.g., this will be of size (L * 37). If you are only calculating CA LDDT, it will be of size L.

    Args:
        all_atom_pred_pos (Tensor[float], [(Nstates x) B x (L * Natoms x) 3]): Tensor of predicted positions
        all_atom_positions (Tensor[float], [B x (L * Natoms x) 3]): Tensor of true positions
        all_atom_mask (Tensor[float], [B x (L * Natoms)]): Tensor of masks, indicating whether an atom exists.
        cutoff (float): Max distance to score lddt over.
        per_residue (bool): Whether to return per-residue or full-protein lddt.

    Returns:
        LDDT Tensor:
            if per_residue:
                Tensor[float], [(Nstates x) B x (L * Natoms)]
            else:
                Tensor[float], [(Nstates x) B]
    """
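
To see why the argument order matters, here is a simplified CA-LDDT in NumPy that follows the standard LDDT definition (distance differences scored against 0.5, 1, 2 and 4 Å thresholds). It is not the ESM implementation: it omits the atom mask and the multi-state dimension, and `simple_lddt_ca` is a hypothetical name. Only pairs closer than the cutoff in the reference structure are scored, so swapping the predicted and reference coordinates changes which pairs count:

```python
import numpy as np


def simple_lddt_ca(pred: np.ndarray, true: np.ndarray, cutoff: float = 15.0, eps: float = 1e-10) -> float:
    """Global CA-LDDT of `pred` scored against reference `true`, both shaped (L, 3)."""
    n = true.shape[0]
    # Pairwise CA distance matrices, each (L, L).
    d_true = np.linalg.norm(true[:, None, :] - true[None, :, :], axis=-1)
    d_pred = np.linalg.norm(pred[:, None, :] - pred[None, :, :], axis=-1)
    # Score only pairs within `cutoff` in the REFERENCE structure; skip self-pairs.
    to_score = (d_true < cutoff) & ~np.eye(n, dtype=bool)
    diff = np.abs(d_true - d_pred)
    # Fraction of the four tolerance thresholds (Angstrom) each pair satisfies.
    score = np.mean([diff < t for t in (0.5, 1.0, 2.0, 4.0)], axis=0)
    return float((eps + np.sum(score * to_score)) / (eps + np.sum(to_score)))


# Three CA atoms on a line; the "prediction" moves the last atom 10 Angstrom further out.
true = np.array([[0.0, 0, 0], [10.0, 0, 0], [20.0, 0, 0]])
pred = np.array([[0.0, 0, 0], [10.0, 0, 0], [30.0, 0, 0]])
print(round(simple_lddt_ca(pred, true), 3))  # 0.5  (native as reference)
print(round(simple_lddt_ca(true, pred), 3))  # 1.0  (prediction as reference)
```

With the native structure as reference, the 1-2 pair (10 Å apart) is scored and fails; with the prediction as reference, that pair is 20 Å apart, falls outside the cutoff, and is silently dropped.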

Based on the TM-score and the correct LDDT results, the Iterative decoder appears to perform similarly to the Argmax decoder. Could you please check my evaluation pipeline and the results I've obtained? Could you also share the target lists for CAMEO, CASP14, and CASP15 so I can check their consistency?

heya5 commented 1 month ago

@zhujianwei31415 @ebetica Hi, where can I find the IDs of the 71 CASP14 targets you used? I didn't find them in the ESM3 paper. Thanks!

ebetica commented 1 month ago

@zhujianwei31415 just to follow up here - you were correct, and our numbers are a little lower than we first claimed. Iterative decoding for structure prediction doesn't seem to help. We're going to update our numbers and release a new paper version soon. Thanks so much for helping to discover this!

ebetica commented 1 month ago

@heya5 Please check this list: https://predictioncenter.org/casp14/targetlist.cgi

It's not comprehensive; we got some of the unreleased targets directly from the organizers.