AdaptiveMotorControlLab / CEBRA

Learnable latent embeddings for joint behavioral and neural analysis - Official implementation of CEBRA
https://cebra.ai

[WIP] Add improved goodness of fit implementation #190

Open · stes opened 1 month ago

stes commented 1 month ago

This adds a better goodness of fit measure. The old variant simply matched the InfoNCE value and therefore depends on the batch size; the proposed measure corrects for the batch size, so that 0 corresponds to chance level and larger values indicate a better fit.

The conversion is simply

GoF(model) = log(batch_size_per_session * num_sessions) - InfoNCE(model)
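
In code, the conversion is a one-liner; a minimal sketch with illustrative values (none of these names are part of the PR's API):

```python
import math

# Hypothetical values for illustration.
infonce_loss = 5.5            # InfoNCE value reported for the model
batch_size_per_session = 512  # batch size used per session during training
num_sessions = 1              # number of sessions

# GoF(model) = log(batch_size_per_session * num_sessions) - InfoNCE(model)
gof = math.log(batch_size_per_session * num_sessions) - infonce_loss
print(f"GoF: {gof:.3f}")  # 0 = chance level; larger is better
```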

This measure is also used in DeWolf et al., 2024, Eq. (43):

[Screenshot: Eq. (43) from DeWolf et al., 2024]

Application example (GoF improves from 0 to a larger value during training):

[Screenshot: training curves showing the GoF increasing from 0 over training]
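
Such a curve can be obtained by applying the same conversion to a logged InfoNCE history; a rough sketch with made-up loss values (not from the PR):

```python
import numpy as np
import matplotlib.pyplot as plt

batch_size_per_session, num_sessions = 512, 1

# Hypothetical InfoNCE values logged over training steps.
infonce_history = np.array([6.24, 5.9, 5.5, 5.2, 5.0, 4.9])

# Elementwise conversion: GoF starts near 0 (chance level) and grows as the loss drops.
gof_history = np.log(batch_size_per_session * num_sessions) - infonce_history

plt.plot(gof_history)
plt.xlabel("training step")
plt.ylabel("goodness of fit")
plt.show()
```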


Closes https://github.com/AdaptiveMotorControlLab/CEBRA-dev/pull/669

stes commented 1 month ago

TODO: Fix the case where batch_size is None.

CeliaBenquet commented 1 day ago

@stes, regarding the overlap with what I implemented in #202:

I think it would be good to have a really basic function where you provide the loss and the batch size, so that it is easily usable in the PyTorch implementation as well.

Also, it would be nice to test the default CEBRA.batch_size = None; I'm not sure it is handled here.
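
A sketch of the kind of helper being suggested, with the None case surfaced explicitly; the name, signature, and error handling are assumptions, not the PR's final API:

```python
import math
from typing import Optional


def infonce_to_goodness_of_fit(loss: float,
                               batch_size: Optional[int],
                               num_sessions: int = 1) -> float:
    """Convert an InfoNCE value into the proposed goodness-of-fit measure.

    Args:
        loss: The InfoNCE value, e.g. the recorded training loss.
        batch_size: Batch size per session. ``None`` (the default discussed
            above) is rejected here; it would need to be resolved to the
            effective number of samples per step first.
        num_sessions: Number of sessions used during training.

    Returns:
        ``log(batch_size * num_sessions) - loss``; 0 corresponds to chance level.
    """
    if batch_size is None:
        raise ValueError(
            "batch_size=None cannot be converted directly; pass the "
            "effective number of samples used per optimization step.")
    return math.log(batch_size * num_sessions) - loss
```

Because such a helper only needs the loss value and the batch size, it could be called from both the sklearn wrapper and a plain PyTorch training loop.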