AdaptiveMotorControlLab / CEBRA

Learnable latent embeddings for joint behavioral and neural analysis - Official implementation of CEBRA
https://cebra.ai
Other

Fix broadcasting in InfoNCE loss #86

Closed: stes closed this 10 months ago

stes commented 10 months ago

This PR fixes a broadcasting bug (https://github.com/AdaptiveMotorControlLab/CEBRA/issues/48) in the numerically stabilized version of the InfoNCE loss function. It also adds numerical tests for the InfoNCE implementation.
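To illustrate what a broadcasting error in a max-shifted InfoNCE loss can look like, here is a minimal pure-Python sketch. It assumes that the stabilized loss subtracts the per-row maximum `c` of an n x n matrix of negative similarities, and that the bug amounted to `c` (shape `(n,)`) being broadcast against the column axis instead of the row axis. The helpers `infonce_reference` and `infonce_stabilized` and the `broadcast_bug` flag are hypothetical; this is not CEBRA's PyTorch implementation (there is no temperature or batching here).

```python
import math


def logsumexp(xs):
    """Numerically stable log(sum(exp(x))) over a list of floats."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def infonce_reference(pos, neg):
    """Direct InfoNCE: -mean(pos) + mean over rows of logsumexp(neg row)."""
    n = len(pos)
    return -sum(pos) / n + sum(logsumexp(row) for row in neg) / n


def infonce_stabilized(pos, neg, broadcast_bug=False):
    """Max-shifted InfoNCE (hypothetical helper for illustration).

    pos: list of n positive similarities.
    neg: n x n list of lists of negative similarities.
    """
    n = len(pos)
    c = [max(row) for row in neg]  # per-row maximum, shape (n,)
    pos_shift = [pos[i] - c[i] for i in range(n)]  # elementwise, shapes match
    if broadcast_bug:
        # Like `neg - c` in NumPy/PyTorch: c aligns with the LAST axis,
        # so entry (i, j) is shifted by c[j], the max of a different row.
        neg_shift = [[neg[i][j] - c[j] for j in range(n)] for i in range(n)]
    else:
        # Like `neg - c[:, None]`: entry (i, j) is shifted by c[i], so the
        # shift cancels between the positive and negative terms.
        neg_shift = [[neg[i][j] - c[i] for j in range(n)] for i in range(n)]
    pos_loss = -sum(pos_shift) / n
    neg_loss = sum(math.log(sum(math.exp(x) for x in row)) for row in neg_shift) / n
    return pos_loss + neg_loss


pos = [1.0, 0.5, -0.2]
neg = [[0.1, 2.0, -1.0], [0.3, 0.0, 0.7], [-0.5, 1.5, 0.2]]
ref = infonce_reference(pos, neg)
print(abs(infonce_stabilized(pos, neg) - ref) < 1e-9)  # True
print(abs(infonce_stabilized(pos, neg, broadcast_bug=True) - ref) > 1e-3)  # True
```

With the row-wise shift the constants cancel and the stabilized loss matches the direct formula; with the column-wise shift they do not, so the reported loss value changes. A numerical test in this spirit can compare the stabilized loss against a direct reference on small inputs, and also check that large similarities (around 1000, where a naive `math.exp` would overflow) still produce a finite, matching result.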

Note: Although we verified that this change does not meaningfully influence the algorithm outputs (e.g., in the demo notebooks), we advise against comparing models trained before this modification (up to 0.3.0rc1) with models trained after it (from 0.3.0rc2 onwards).

Thanks to @mudphudwang for flagging.

Fixes https://github.com/AdaptiveMotorControlLab/CEBRA/issues/48
Fixes https://github.com/AdaptiveMotorControlLab/CEBRA-dev/pull/658