google-deepmind / alphafold

Open source code for AlphaFold.
Apache License 2.0

The definition of bins in the Predicted Aligned Error (PAE) head may be wrong #929

Status: Open. xiergo opened this issue 2 months ago.


I am confused about the definition of bins in the Predicted Aligned Error (PAE) head. The bin edges (`breaks`) are defined as [0, 0.5, 1, ..., 31]:

```python
# self.config.max_error_bin = 31, self.config.num_bins = 64
breaks = jnp.linspace(
    0., self.config.max_error_bin, self.config.num_bins - 1)
```

and the bin centers are [0.25, 0.75, ..., 30.75, 31.25, 31.75], according to:

```python
import numpy as np


def _calculate_bin_centers(breaks: np.ndarray):
  """Gets the bin centers from the bin edges.

  Args:
    breaks: [num_bins - 1] the error bin edges.

  Returns:
    bin_centers: [num_bins] the error bin centers.
  """
  step = (breaks[1] - breaks[0])

  # Add half-step to get the center
  bin_centers = breaks + step / 2
  # Add a catch-all bin at the end.
  bin_centers = np.concatenate([bin_centers, [bin_centers[-1] + step]],
                               axis=0)
  return bin_centers
```
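Both definitions can be checked numerically. The sketch below uses NumPy in place of `jax.numpy` (its `linspace` takes the same arguments) and re-implements the quoted helper under the hypothetical name `calculate_bin_centers`; the constants are the config values from the first snippet.

```python
import numpy as np

MAX_ERROR_BIN = 31.0  # self.config.max_error_bin
NUM_BINS = 64         # self.config.num_bins


def calculate_bin_centers(breaks: np.ndarray) -> np.ndarray:
  """Same arithmetic as _calculate_bin_centers above."""
  step = breaks[1] - breaks[0]
  bin_centers = breaks + step / 2  # half a step past each edge
  # Catch-all bin one step past the last center.
  return np.concatenate([bin_centers, [bin_centers[-1] + step]], axis=0)


# 63 evenly spaced edges from 0 to 31 inclusive -> spacing 31 / 62 = 0.5
breaks = np.linspace(0.0, MAX_ERROR_BIN, NUM_BINS - 1)
centers = calculate_bin_centers(breaks)

print(len(breaks), breaks[1] - breaks[0])  # 63 0.5
print(len(centers), centers[:2].tolist(), centers[-3:].tolist())
# 64 [0.25, 0.75] [30.75, 31.25, 31.75]
```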

So the 64 bins implied by these centers are [0, 0.5], [0.5, 1], ..., [31, 31.5], [31.5, +inf].
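Under this reading, a distance d belongs to bin k when k * 0.5 <= d < (k + 1) * 0.5 for k < 63, and to the catch-all bin 63 when d >= 31.5. A minimal sketch of that assignment follows; `center_bin_index` is a hypothetical helper, and making each bin half-open on the right is an assumption, since the issue does not say which side a boundary value belongs to.

```python
import numpy as np

breaks = np.linspace(0.0, 31.0, 63)  # [0, 0.5, ..., 31]
step = breaks[1] - breaks[0]         # 0.5
# Upper edges of the center-implied bins: 0.5, 1.0, ..., 31.5 (63 values).
upper_edges = breaks + step


def center_bin_index(error_dist: float) -> int:
  """Hypothetical helper: count the upper edges at or below error_dist.

  Gives k for k*step <= d < (k+1)*step, and 63 (catch-all) for d >= 31.5.
  """
  return int(np.sum(error_dist >= upper_edges))


for d in (0.0, 0.3, 0.75, 31.2, 40.0):
  print(d, center_bin_index(d))
# bins 0, 0, 1, 62, 63 for these distances
```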

But the bins used in the PAE loss are [-inf, 0], [0, 0.5], ..., [31, +inf], i.e. shifted left by one bin, based on the definition in alphafold/alphafold/model/ line 1200:

```python
sq_breaks = jnp.square(breaks)  # squares of [0, 0.5, ..., 31]
true_bins = jnp.sum((
    error_dist2[..., None] > sq_breaks).astype(jnp.int32), axis=-1)

errors = softmax_cross_entropy(
    labels=jax.nn.one_hot(true_bins, num_bins, axis=-1), logits=logits)
```
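The same labeling can be reproduced for single distances in NumPy. `loss_bin_index` is a hypothetical name for the quoted `true_bins` expression applied to one scalar; comparing squared distances against squared edges gives the same result as comparing the distances themselves, because both are non-negative.

```python
import numpy as np

breaks = np.linspace(0.0, 31.0, 63)  # [0, 0.5, ..., 31]
sq_breaks = np.square(breaks)


def loss_bin_index(error_dist: float) -> int:
  """Hypothetical helper: the quoted true_bins expression for one distance."""
  error_dist2 = error_dist ** 2
  # Number of edges strictly below the distance.
  return int(np.sum(error_dist2 > sq_breaks))


for d in (0.0, 0.3, 0.75, 31.2, 40.0):
  print(d, loss_bin_index(d))
# bins 0, 1, 2, 63, 63 for these distances
```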

For example, error_dist = 0.75 should fall into the second bin, [0.5, 1] (index 1). However, (0.75 > breaks).sum() is 2, so the one-hot label is [0, 0, 1, 0, ..., 0], with the third entry (index 2) set to 1, which is incorrect.
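One way to see what this offset would mean in practice: suppose, hypothetically, that a model reproduced the training label exactly and its PAE were read out as the expectation of the bin centers under the predicted distribution (an assumption about how the centers are used downstream, not something shown in this issue). For error_dist = 0.75 the loss label would then decode to 1.25, while the label implied by the centers decodes to 0.75. `decode` is a hypothetical helper.

```python
import numpy as np

NUM_BINS = 64
breaks = np.linspace(0.0, 31.0, NUM_BINS - 1)  # [0, 0.5, ..., 31]
step = breaks[1] - breaks[0]
# Same construction as _calculate_bin_centers: 64 centers, 0.25 ... 31.75.
centers = np.concatenate([breaks + step / 2, [breaks[-1] + 1.5 * step]])

error_dist = 0.75

# Label index as computed in the loss (edges strictly below the distance).
loss_idx = int(np.sum(error_dist ** 2 > np.square(breaks)))
# Label index implied by the centers (bin [k*step, (k+1)*step)).
center_idx = int(np.sum(error_dist >= breaks + step))


def decode(idx: int) -> float:
  """Hypothetical readout: expected error under a one-hot distribution at idx."""
  probs = np.eye(NUM_BINS)[idx]
  return float(np.sum(probs * centers))


print(loss_idx, decode(loss_idx))      # 2 1.25
print(center_idx, decode(center_idx))  # 1 0.75
```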