sokrypton / ColabDesign

Making Protein Design accessible to all via Google Colab!

difference between plddt in the log and in the saved PDB! #102

Closed: shahmandi closed this issue 1 year ago

shahmandi commented 1 year ago

I noticed that there is a discrepancy between the pLDDT value reported in the log (i.e. af_model.aux["log"]["plddt"]) and the value obtained by averaging the B-factors of the saved PDB file (where the PDB string is obtained from af_model.save_pdb(get_best=False)). The pLDDT in the log is usually larger.

This is the case even when num_models=1 and num_recycles=0. Is this due to a slightly different normalisation, or am I missing something?

P.S. af_model = mk_af_model(use_multimer=True, use_templates=True, best_metric="dgram_cce")
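
For reference, the comparison can be reproduced roughly as follows (a sketch only: the B-factor parsing is hand-rolled rather than a ColabDesign helper, it assumes af_model has already been run, and if the B-factors are written on a 0-100 scale they should be divided by 100 before comparing):

pdb_str = af_model.save_pdb(get_best=False)  # PDB string, as described above

# average the per-atom B-factor column (columns 61-66 of ATOM/HETATM records)
b_factors = [float(line[60:66]) for line in pdb_str.splitlines()
             if line.startswith(("ATOM", "HETATM"))]
plddt_from_pdb = sum(b_factors) / len(b_factors)

plddt_from_log = af_model.aux["log"]["plddt"]
print(plddt_from_log, plddt_from_pdb)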

sokrypton commented 1 year ago

Thanks! There is a slight difference in how the pLDDT is computed for the loss function vs. the output, but they should be highly correlated.

plddt loss:

import jax
import jax.numpy as jnp

def get_plddt_loss(outputs, mask_1d=None):
  p = jax.nn.softmax(outputs["predicted_lddt"]["logits"])
  # reversed integer bin indices averaged over bins: high-confidence bins get
  # low weight, so minimizing this pushes probability toward high-pLDDT bins
  p = (p * jnp.arange(p.shape[-1])[::-1]).mean(-1)
  return mask_loss(p, mask_1d)  # masked average over residues (ColabDesign helper)

plddt output:

def get_plddt(outputs):
  logits = outputs["predicted_lddt"]["logits"]
  num_bins = logits.shape[-1]
  bin_width = 1.0 / num_bins
  # bin centers at 0.5/num_bins, 1.5/num_bins, ..., 1 - 0.5/num_bins
  bin_centers = jnp.arange(start=0.5 * bin_width, stop=1.0, step=bin_width)
  probs = jax.nn.softmax(logits, axis=-1)
  # expected lDDT per residue, in [0, 1]
  return jnp.sum(probs * bin_centers[None, :], axis=-1)
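
To see where the gap comes from, here is a quick numeric check of the two weightings applied to the same random logits (a sketch comparing the two formulas above, not the exact wiring of the logged value):

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
logits = jax.random.normal(key, (10, 50))  # 10 residues, 50 pLDDT bins
p = jax.nn.softmax(logits, axis=-1)
num_bins = logits.shape[-1]

# loss-style weighting: reversed integer bin indices, mean over bins
loss_style = (p * jnp.arange(num_bins)[::-1]).mean(-1)

# output-style weighting: bin centers in (0, 1), sum over bins
bin_centers = jnp.arange(0.5 / num_bins, 1.0, 1.0 / num_bins)
output_style = jnp.sum(p * bin_centers, axis=-1)

# the two differ only by a fixed affine shift
offset = (1.0 - loss_style) - output_style
print(offset)  # every entry is 0.5 / num_bins == 0.01, up to float error

So if the logged value corresponds to 1 minus this loss term, it would sit a constant 0.5/num_bins (0.01 for 50 bins) above the bin-center version, which would match the logged pLDDT being slightly larger.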

Should be an easy fix to make sure they are identical!

sokrypton commented 1 year ago

I've updated the code. pLDDT should now be consistent across outputs!