AdaptiveMotorControlLab / CEBRA

Learnable latent embeddings for joint behavioral and neural analysis - Official implementation of CEBRA
https://cebra.ai
Other

Automated adjustment of `n_bins` when discrete labels are passed #128

Open stes opened 5 months ago

stes commented 5 months ago

As discussed in https://github.com/AdaptiveMotorControlLab/CEBRA/discussions/106 by @FrancescaGuo, passing a discrete index to the consistency calculation with the default `n_bins = 100` raises an (expected) error. The current workaround is to set `n_bins` to the number of distinct label values. However, this could be handled directly in the code: whenever discrete labels are passed, the binning step required for continuous data could be replaced or adapted.
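One way the binning step could be adapted is to group samples by each distinct label value instead of by equal-width bins. The sketch below is a minimal illustration, not CEBRA's code: `is_discrete` and `discrete_label_means` are hypothetical helpers, and treating integer labels (or floats that are all whole numbers) as discrete is an assumption.

```python
import numpy as np


def is_discrete(labels) -> bool:
    """Heuristic (assumption): integer dtype, or floats that are all whole numbers."""
    labels = np.asarray(labels)
    if np.issubdtype(labels.dtype, np.integer):
        return True
    return bool(np.all(np.mod(labels, 1) == 0))


def discrete_label_means(embedding, labels):
    """Hypothetical helper: average the embedding over samples sharing each label value.

    Unlike equal-width binning, only label values that actually occur get a group,
    so there are never empty bins.
    """
    embedding = np.asarray(embedding)
    labels = np.asarray(labels)
    if not is_discrete(labels):
        raise ValueError("labels look continuous; use binning instead")
    keys = np.unique(labels)
    means = np.stack([embedding[labels == k].mean(axis=0) for k in keys])
    return keys, means
```

With labels that take only the values 0 and 2, this yields exactly two groups, so no `n_bins` value has to be chosen.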

introspective-swallow commented 5 months ago

How should discrete labels be treated? If they are treated as qualitative (categorical) variables, a warning should be raised whenever values get merged, e.g. when some embedding has no occurrence of a given value and `n_bins` is therefore set to something like the minimum number of label values that appear in all cases.
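If discrete labels are treated as categorical, one option (an assumption, not current CEBRA behavior) is to keep only the label values that occur in every dataset and warn about the rest, rather than silently merging them. `shared_label_values` below is a hypothetical helper sketching that idea.

```python
import warnings

import numpy as np


def shared_label_values(*label_arrays):
    """Hypothetical helper: label values present in every array; warns about the others."""
    sets = [set(np.unique(np.asarray(labels)).tolist()) for labels in label_arrays]
    shared = set.intersection(*sets)
    dropped = set.union(*sets) - shared
    if dropped:
        # Make the loss of categories explicit instead of merging them silently.
        warnings.warn(
            f"Label values {sorted(dropped)} do not occur in every dataset and will be ignored.",
            UserWarning,
        )
    return np.array(sorted(shared))
```

For example, `shared_label_values([0, 1, 2], [0, 2])` returns `[0, 2]` and warns that value 1 is ignored.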

GarrettBlair commented 5 months ago

Adding my experience with this issue as well: if the labels do not cover every value between the minimum and the maximum (say the label values are 0 and 2, but never 1), `_coarse_to_fine()` throws an error.

The following example code:

```python
import cebra
import numpy as np

embedding1 = np.random.uniform(0, 1, (1000, 5))
embedding2 = np.random.uniform(0, 1, (1000, 8))
labels1 = np.random.uniform(0, 1, (1000,))
labels2 = np.random.uniform(0, 1, (1000,))

# Force the labels to take only the values 0 or 2 (1 is never sampled)
labels1 = np.round(labels1)
labels1[labels1 > 0] += 1
labels2 = np.round(labels2)
labels2[labels2 > 0] += 1

# Between-runs consistency
scores, pairs, ids_runs = cebra.sklearn.metrics.consistency_score(
    embeddings=[embedding1, embedding2],
    between="runs",
)

# Between-datasets consistency, by aligning on the labels
scores, pairs, ids_datasets = cebra.sklearn.metrics.consistency_score(
    embeddings=[embedding1, embedding2],
    labels=[labels1, labels2],
    dataset_ids=["achilles", "buddy"],
    between="datasets",
)
```

This raises `ValueError: Digitalized labels does not have elements close enough to bin index 4. The bin index should be in the range of the labels values.`
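To see why equal-width binning breaks here, one can count how many bins between the minimum and maximum label receive no samples. This sketch uses `np.histogram` as a stand-in; it is not necessarily how `_coarse_to_fine()` bins labels internally, and `empty_bins` is a hypothetical helper.

```python
import numpy as np


def empty_bins(labels, n_bins):
    """Indices of equal-width bins over [min, max] that contain no samples."""
    counts, _ = np.histogram(labels, bins=n_bins)
    return np.flatnonzero(counts == 0)


rng = np.random.default_rng(0)
labels = np.round(rng.uniform(0, 1, 1000))
labels[labels > 0] += 1  # only the values 0 and 2, as in the example above

print(empty_bins(labels, 5))         # -> [1 2 3]
print(len(empty_bins(labels, 100)))  # -> 98
```

With only two distinct values, every bin except the first and the last is empty, so any `n_bins` above 2 leaves gaps.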

I'm using an angular position label with avoidance learning, so not fully sampling the label range is common, and using fewer bins to avoid the error loses too much resolution and merges too many bins. Also, a value may be sampled in one set of labels but not in the other.

For clarity, a session with this problem has the following unique labels: [1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 30 31 32 33 34 35 36]. The issue therefore comes from bins 22-29 never being sampled (10-degree bins covering 0-360 degrees).
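The unsampled bins can be found directly from the unique labels. This assumes the 36 angular bins are numbered 1 to 36, as the listed values suggest.

```python
import numpy as np

# Unique labels reported for the affected session (assumed numbering: 1..36)
observed = np.array([*range(1, 22), *range(30, 37)])
expected = np.arange(1, 37)

missing = np.setdiff1d(expected, observed)
print(missing)  # -> [22 23 24 25 26 27 28 29]
```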

I think, as you suggested @stes, allowing the user to pass discrete labels would alleviate this. I'm not sure, though, how this would be handled when one session is sampled differently from another (e.g. session one has [0, 1, 2] and session two only has [0, 2]).