AdaptiveMotorControlLab / CEBRA

Learnable latent embeddings for joint behavioral and neural analysis - Official implementation of CEBRA

InfoNCE Implementation #48

Closed: mudphudwang closed this issue 11 months ago

mudphudwang commented 1 year ago

Bug description

Congratulations on this great project and publication!

I was browsing the code and noticed a potential issue with cebra.models.criterions.infonce. I assume that c in the function is only there for the numerical stability of logsumexp, and that the function is supposed to return $L = \mathbb{E}_x [-\phi(x_i, y^{+}_i) + \log \sum_{j=1}^{n} e^{\phi(x_i, y^{-}_{ij})}]$? If so, I think there is an error in how c is broadcast against neg_dist, which makes the function return incorrect values.
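The reasoning behind this can be written out. As a sketch, assume c is a per-sample shift $c_i$ (for example the row-wise maximum of the negatives; the symbol $c_i$ is introduced here for illustration only). Subtracting it from both terms leaves the per-sample loss unchanged, provided the same $c_i$ is applied to every entry of row $i$:

```latex
\begin{aligned}
-\bigl(\phi(x_i, y^{+}_i) - c_i\bigr) + \log \sum_{j=1}^{n} e^{\phi(x_i, y^{-}_{ij}) - c_i}
&= -\phi(x_i, y^{+}_i) + c_i + \log\Bigl(e^{-c_i} \sum_{j=1}^{n} e^{\phi(x_i, y^{-}_{ij})}\Bigr) \\
&= -\phi(x_i, y^{+}_i) + \log \sum_{j=1}^{n} e^{\phi(x_i, y^{-}_{ij})}
\end{aligned}
```

If the shift inside the sum depends on $j$ instead (that is, $c_j$ rather than $c_i$), the factor $e^{-c_j}$ cannot be pulled out of the sum, the cancellation fails, and the returned value differs from $L$.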

Operating System

Ubuntu 18.04

CEBRA version


Device type

Core i9 / RTX 3090

Steps To Reproduce

#! pip3 install cebra==0.2.0

from cebra.models.criterions import infonce
from torch import randn, allclose, logsumexp

# random positive and negative measures
n = 10
pos = randn(n)
neg = randn(n, n)

# InfoNCE Loss
# :math:`L = \mathbb{E}_x [-\phi(x_i, y^{+}_i) + \log \sum_{j=1}^{n} e^{\phi(x_i, y^{-}_{ij})}]`
L = neg.logsumexp(dim=1).mean() - pos.mean()

# CEBRA InfoNCE implementation
cebra_infonce, _, _ = infonce(pos, neg)
cebra_close = allclose(L, cebra_infonce)
print("Cebra InfoNCE --", cebra_close)

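# Corrected InfoNCE implementation
def corrected_infonce(pos, neg):
    # Per-row maximum of the negatives, kept as shape (n, 1) so that it
    # broadcasts across each row of neg; used only for numerical stability.
    c = neg.detach().max(dim=1, keepdim=True).values

    # Shift both terms by the same per-row constant so that the shift cancels.
    pos = pos - c.squeeze(dim=1)
    neg = neg - c

    align = -pos.mean()
    uniform = logsumexp(neg, dim=1).mean()

    return align + uniform, align, uniform

# Bind the result to a new name so that the function is not shadowed.
corrected_loss, _, _ = corrected_infonce(pos, neg)
corrected_close = allclose(L, corrected_loss)
print("Corrected InfoNCE --", corrected_close)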

Relevant log output

Cebra InfoNCE -- False
Corrected InfoNCE -- True
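
The kind of broadcasting mismatch suspected here can be reproduced in isolation. The following is a minimal sketch, not the CEBRA code: it assumes the shift is held as a 1-D array of shape (n,), which is an assumption about how such a bug could arise rather than a statement about the 0.2.0 implementation. NumPy is used instead of torch to keep the example dependency-light; both follow the same broadcasting rules. The names logsumexp_rows, c_flat and c_col are hypothetical.

```python
import numpy as np


def logsumexp_rows(a):
    """Log-sum-exp over axis 1, stabilised with a correctly shaped (n, 1) shift."""
    m = a.max(axis=1, keepdims=True)
    return (m + np.log(np.exp(a - m).sum(axis=1, keepdims=True))).squeeze(1)


rng = np.random.default_rng(0)
n = 4
pos = rng.standard_normal(n)        # one positive score per sample, shape (n,)
neg = rng.standard_normal((n, n))   # n negative scores per sample, shape (n, n)

# Unshifted reference value of the loss.
reference = logsumexp_rows(neg).mean() - pos.mean()

c_flat = neg.max(axis=1)                 # shape (n,)
c_col = neg.max(axis=1, keepdims=True)   # shape (n, 1)

# (n, n) - (n,): the vector is aligned with the LAST axis,
# so entry [i, j] is shifted by c_flat[j] (a per-column shift).
misaligned = neg - c_flat

# (n, n) - (n, 1): entry [i, j] is shifted by c_col[i] (a per-row shift).
aligned = neg - c_col

# The positive term is shifted per sample in both cases.
loss_misaligned = logsumexp_rows(misaligned).mean() - (pos - c_flat).mean()
loss_aligned = logsumexp_rows(aligned).mean() - (pos - c_flat).mean()

print("per-row shift matches reference:   ", np.isclose(loss_aligned, reference))
print("per-column shift matches reference:", np.isclose(loss_misaligned, reference))
```

Keeping the reduced dimension (keepdim=True in torch, keepdims=True in NumPy) is what makes the shift line up with the rows of the negatives; the positive term then needs the same shift with that extra dimension squeezed away.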


stes commented 1 year ago

Thanks for flagging, will look into this/try to repro and then get back!

stes commented 11 months ago

Thanks very much again for flagging this. We updated the loss implementation in PR, which also adds numerical tests against a reference implementation.

In short, we changed the broadcasting, but this has little effect on performance. Concretely, when comparing version 0.2.0 (old implementation) against 0.3.0rc2 (new implementation), we find no significant difference.

Here is the paired t-test for the synthetic benchmark:

scipy.stats.ttest_rel(new_scores, old_scores)

TtestResult(statistic=1.4435649958049843, pvalue=0.15201838297772313, df=99)

As shown here (note the Y-axis range):
