google-deepmind / optax

Optax is a gradient processing and optimization library for JAX.
https://optax.readthedocs.io
Apache License 2.0

ntxent fix #946

Open · GrantMcConachie opened 2 months ago

GrantMcConachie commented 2 months ago

I added an epsilon value to the cosine similarity function to avoid the NaNs that occurred when the label vector was [0, 0, 0] or when one of the embeddings was the zero vector.
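
For context, a minimal sketch (not the actual PR diff, and with an illustrative helper name) of how an epsilon in the denominator keeps the similarity finite when an embedding is the zero vector:

import jax.numpy as jnp

def cosine_similarity_with_eps(x, y, eps=1e-12):
  # With a zero embedding the product of norms is 0; adding eps keeps the
  # denominator nonzero, so the similarity becomes 0 instead of NaN.
  dot = jnp.sum(x * y)
  norms = jnp.linalg.norm(x) * jnp.linalg.norm(y)
  return dot / (norms + eps)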

vroulet commented 2 months ago

Thanks again @GrantMcConachie! Quick question: do you think it could be handled by a jnp.where? Imagine, for example, that you want to normalize a vector and handle the zero-vector case properly. You could then do

import jax.numpy as jnp

def normalize_vec(x):
  norm = jnp.sqrt(jnp.sum(x**2))
  # if the norm is zero, return x unchanged instead of dividing by zero
  return jnp.where(norm == 0., x, x/norm)

Would such logic be implementable here? Or would an epsilon actually be preferable (in Adam, for example, an epsilon is preferred)?
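
A note on the snippet above: both branches of jnp.where are evaluated (and differentiated), so a NaN in the untaken branch can still poison the gradient at a zero vector even though the forward value is fine. The workaround documented in the JAX FAQ also guards the quantity fed to the problematic operation; a rough sketch:

import jax
import jax.numpy as jnp

def normalize_vec_safe(x):
  sq = jnp.sum(x ** 2)
  # Guard the value that enters sqrt/division so that neither branch of the
  # outer where can produce NaN, even under differentiation.
  safe_sq = jnp.where(sq == 0., 1., sq)
  return jnp.where(sq == 0., x, x / jnp.sqrt(safe_sq))

print(jax.grad(lambda v: jnp.sum(normalize_vec_safe(v)))(jnp.zeros(3)))  # finite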

GrantMcConachie commented 2 months ago

Hi @vroulet! Yes I think this is possible! I will work on it and let you know.

GrantMcConachie commented 2 months ago

Hello again @vroulet, I tried the following

  norm_emb = jnp.linalg.norm(embeddings, axis=1, keepdims=True)
  # replace zero norms with 1 so the division below stays well defined
  norm_emb = jnp.where(norm_emb == 0.0, 1.0, norm_emb)
  embeddings = embeddings / norm_emb
  # pairwise cosine similarities scaled by the temperature
  xcs = jnp.matmul(embeddings, embeddings.T) / temperature

to compute the similarity matrix for the cross entropy, rather than calling the cosine_similarity function with the epsilon. This gives the same cosine similarity matrix; however, the gradient still contained NaNs.

I also tried this

xcs = jnp.where(jnp.isnan(xcs), 0.0, xcs)

while keeping the 0.0 epsilon value in the cosine_similarity calculation, and this also resulted in NaNs in the gradient.

I am out of ideas for other ways to implement jnp.where(), so I believe the best way to go about this is to add the epsilon in the cosine similarity. Let me know if you have any more suggestions for implementing jnp.where(), though!
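
As a side note on why masking NaNs after the fact does not clean up the gradient: the zero cotangent coming out of jnp.where still flows back through the operation that produced the NaN, and 0 times NaN is NaN. A toy example (unrelated to the ntxent code itself):

import jax
import jax.numpy as jnp

def f(x):
  y = x / x                             # 0/0 -> NaN in the forward pass at x == 0
  return jnp.where(jnp.isnan(y), 0.0, y)

print(f(0.0))             # 0.0 -- the forward value is repaired
print(jax.grad(f)(0.0))   # nan -- the backward pass through x / x is not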

vroulet commented 2 months ago

Hello @GrantMcConachie, thanks for trying! We may have had a misunderstanding. The issue you get arises when all labels are the same. That should not have anything to do with the embeddings, no? So changing the cosine similarity by adding an epsilon won't solve that issue. (It would help if an embedding is the zero vector, but that would be a different bug, no? And one that may be handled by cosine_similarity rather than ntxent_loss.) The issue here is that diffs would be filled with 0 if all labels are the same; xcs_diffs would then be filled with -jnp.inf, hence the bug.

So the first question is: what should we obtain mathematically in that case? If the answer is "you should get + or - infinity", then the NaNs are fine by me. If the answer is "you should get 0. or 1., because that is what the math says in the limit", then we should find a way to encode that.

Thanks again for this contribution; this issue made me look at it again and it's well done :)

Ah, and by the way, you may add a doctest to the loss while you are at it. Understanding what the proper shapes should be etc. is not necessarily evident for the user, and this would help. (Look at the docstring of Adam for example; you'll see an Examples section where you can format some code that will appear nicely in the docs.)
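
For illustration, such a section might look roughly like the sketch below; the optax.ntxent name, signature, input shapes, and scalar output shown here are assumptions about the function under review and would need to be adjusted to the actual code:

Examples:
  >>> import jax.numpy as jnp
  >>> import optax
  >>> # (batch_size, embedding_dim) embeddings and one integer label per row;
  >>> # rows sharing a label are treated as positive pairs.
  >>> embeddings = jnp.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]])
  >>> labels = jnp.array([0, 0, 1, 1])
  >>> loss = optax.ntxent(embeddings, labels)
  >>> loss.shape
  ()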

GrantMcConachie commented 2 months ago

Hi @vroulet! You are definitely right. Adding this epsilon term to the cosine similarity only fixes the zero-vector embedding issue.

For the issue where all labels are the same, I think the loss should be 0. The reason is that if there are no negative pairs (diffs is filled with 0s), the numerator and denominator inside the log should be the same. The loss $l_{i,j} = -\log \frac{\exp(\mathrm{sim}(z_i, z_j) / \tau)}{\sum_{k=1}^{N} \mathbb{1}_{k \neq i} \exp(\mathrm{sim}(z_i, z_k) / \tau)}$ for any given embedding then evaluates to $-\log \frac{\exp(\mathrm{sim}(z_i, z_j) / \tau)}{\exp(\mathrm{sim}(z_i, z_j) / \tau)} = -\log(1) = 0$.

I was confused at first because I thought the $\mathbb{1}_{k \neq i}$ term was 1 for each negative pair, but in reality that term is 1 for every pair except a self pair. This version of the loss assumes that every pair outside of $\{i, j\}$ is a negative pair, so the equation is a little misleading.

The equation from https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#ntxentloss, from which I took a lot of inspiration when building this function, is more general in that you don't need to have exactly one positive pair in your embeddings. Here you get the same evaluation: running this loss with all labels the same gives you 0, as in the toy check below.
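
A toy check of that claim, following the pytorch-metric-learning formulation linked above (per positive pair, the denominator is that pair's exponential plus the exponentials of the negatives); this sketch is illustrative and independent of the optax implementation:

import jax.numpy as jnp

emb = jnp.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])
labels = jnp.array([0, 0, 0])                    # all labels identical
emb = emb / jnp.linalg.norm(emb, axis=1, keepdims=True)
sim = emb @ emb.T / 0.07                         # cosine similarities / temperature

n = emb.shape[0]
for i in range(n):
  for p in range(n):
    if p == i or labels[p] != labels[i]:
      continue                                   # (i, p) must be a positive pair
    negatives = labels != labels[i]              # empty when all labels match
    denom = jnp.exp(sim[i, p]) + jnp.sum(jnp.where(negatives, jnp.exp(sim[i]), 0.0))
    print(-jnp.log(jnp.exp(sim[i, p]) / denom))  # each term is -log(1) = 0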

In conclusion, adding the epsilon term in the cosine similarity alleviates the zero-vector embedding problem, and the case in which all labels are the same should evaluate to a loss of 0. Let me know if you agree!

I will start working on the doctest soon!

GrantMcConachie commented 1 month ago

Hi @vroulet! Just wanted to let you know I added a doctest! Let me know what you think.

fabianp commented 1 month ago

Instead of a hard-coded 1e-12, could we perhaps replace it with np.finfo(embeddings.dtype).eps? This way the epsilon will depend on the dtype of the embeddings (which I believe is what one would want).
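
For reference, the machine epsilon np.finfo reports scales with precision, so a dtype-aware value stays proportionate whether the embeddings are double precision or a low-precision type:

import numpy as np

for dtype in (np.float64, np.float32, np.float16):
  # eps is the gap between 1.0 and the next representable value in that dtype.
  print(dtype.__name__, np.finfo(dtype).eps)
# float64 2.220446049250313e-16
# float32 1.1920929e-07
# float16 0.000977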

GrantMcConachie commented 1 month ago

Hi @fabianp! Yes I can add this in.