AdaptiveMotorControlLab / CEBRA

Learnable latent embeddings for joint behavioral and neural analysis - Official implementation of CEBRA
https://cebra.ai

Label sorting bug in consistency score #27

Closed. drsax93 closed this issue 11 months ago.

drsax93 commented 1 year ago


Bug description

Hey, I noticed that the function cebra.sklearn.metrics.consistency_score sorts the dataset_ids passed to it, which creates a mismatch between the pairs and the dataset labels (see snippet). These outputs are then passed on to the plot_consistency function. Does this make sense?

import cebra

# MICE, exps, cebra_w, track and lineariseTrack come from the reporter's own
# analysis code and are not shown here.
mice_ = MICE[:2]
lbl = ['a', 'd', 'b', 'c']  # dataset IDs, deliberately not sorted
embds = [cebra_w[e][m] for m in mice_ for e in exps]
labels = [lineariseTrack(track[e][m][:, 0], track[e][m][:, 1], binsize=30)
          for m in mice_ for e in exps]

scores, pairs, datasets = cebra.sklearn.metrics.consistency_score(
    embeddings=embds,
    labels=labels,
    between="datasets",
    dataset_ids=lbl,
    num_discretization_bins=20)
print(f'original labels: {lbl}')
print(f'consistency labels: {datasets}')
print(f'consistency pairs: {pairs}')

OUTPUT:

consistency labels: ['a' 'b' 'c' 'd']
consistency pairs: [['a' 'd']
 ['a' 'b']
 ['a' 'c']
 ['d' 'a']
 ['d' 'b']
 ['d' 'c']
 ['b' 'a']
 ['b' 'd']
 ['b' 'c']
 ['c' 'a']
 ['c' 'd']
 ['c' 'b']]
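The mismatch can be reproduced without CEBRA. This sketch assumes, without checking the CEBRA source, that the pairs are enumerated in input order while the returned labels come from something like `np.unique`, which always returns sorted values:

```python
import itertools

import numpy as np

# Dataset IDs in the (unsorted) order the reporter passed them.
dataset_ids = ["a", "d", "b", "c"]

# Ordered pairs enumerated in input order; this reproduces the
# order of `pairs` in the output above.
pairs = np.array(list(itertools.permutations(dataset_ids, 2)))

# Assumed source of the sorting: np.unique returns sorted unique values.
datasets = np.unique(dataset_ids)

print(datasets)  # ['a' 'b' 'c' 'd']
print(pairs[:3])
```

The first pairs start with 'a', 'd', while the labels run 'a', 'b', 'c', 'd', which is exactly the discrepancy reported.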

Operating System

Ubuntu

CEBRA version

0.2.0

https://github.com/AdaptiveMotorControlLab/CEBRA/pull/25/commits/59e397013c560f8f9d29edb223acef467c7dfe22

Device type

gpu

Steps To Reproduce

No response

Relevant log output

No response

Anything else?

No response


stes commented 1 year ago

Hi @drsax93 , given that you are passing non-sorted labels (lbl = ['a','d','b','c']), could you elaborate what the expected output should look like here?

drsax93 commented 1 year ago

The point is that the sorted labels don't reflect the order in which pairs were computed.


stes commented 1 year ago

So the order of consistency pairs does not reflect the order in the embedding? Or is this just about consistency labels, which are sorted, and the pairs and embeddings are correct?

Do you have a code snippet that demonstrates this issue?

drsax93 commented 1 year ago

It's the second: the sorting applied to the consistency labels is not applied to the consistency pairs (have a look at the output I posted). In the example, the labels in the pairs follow the order a, d, b, c, while the consistency labels follow a, b, c, d. Won't this create a mismatch when building the consistency matrix? A possible solution could be to pass the dataset_ids directly as the consistency labels, no?
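To make the concern concrete, here is a hypothetical sketch. It assumes, purely for illustration, that a plotting step fills the off-diagonal cells row by row in the order the scores arrive and labels the axes with the sorted IDs; it does not claim that `plot_consistency` works this way. Each placeholder score is simply the index of its pair, so the cell contents reveal which pair they came from:

```python
import itertools

import numpy as np

dataset_ids = ["a", "d", "b", "c"]
pairs = list(itertools.permutations(dataset_ids, 2))
# Placeholder scores: each score is the index of its pair.
scores = np.arange(len(pairs), dtype=float)

# Assumed naive fill: off-diagonal cells in row-major order,
# with rows and columns labeled by the sorted IDs.
labels = sorted(dataset_ids)
n = len(labels)
matrix = np.full((n, n), np.nan)
off_diag = [(i, j) for i in range(n) for j in range(n) if i != j]
for (i, j), score in zip(off_diag, scores):
    matrix[i, j] = score

# The cell labeled ("b", "c") actually holds the score of ("d", "b").
row, col = labels.index("b"), labels.index("c")
true_pair = pairs[int(matrix[row, col])]
print(("b", "c"), "->", true_pair)  # ('b', 'c') -> ('d', 'b')
```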

stes commented 1 year ago

Hi @drsax93 , sorry for the slow reply, I just looked into this again.

Computing consistency

I implemented the following test to confirm that the consistency implementation is invariant to the permutation of the inputs:

import cebra
import numpy as np

class Dataset():
    """A test dataset that can be indexed to obtain different permutations"""

    def __init__(self):
        # Store the data on the instance so that __getitem__ can access it
        self.dataset_ids = "a", "b", "c"
        self.embeddings = [np.random.normal(size=(1000, 3)) for _ in self.dataset_ids]
        self.labels = [np.random.uniform(0, 1, size=(1000,)) for _ in self.dataset_ids]

    def __getitem__(self, order):
        # `order` is a tuple of indices, e.g. dataset[0, 2, 1]
        return {
            "dataset_ids": [self.dataset_ids[i] for i in order],
            "embeddings": [self.embeddings[i] for i in order],
            "labels": [self.labels[i] for i in order],
        }

def compute_consistency(kwargs):
    scores, pairs, datasets = cebra.sklearn.metrics.consistency_score(between="datasets", **kwargs)
    # Key each score by its pair, so the comparison ignores the output order
    return {tuple(pair): score for pair, score in zip(pairs, scores)}

dataset = Dataset()

# quick check our function works
assert compute_consistency(dataset[0, 1, 2]) != compute_consistency(dataset[1, 1, 0])

# permutations give same results
assert compute_consistency(dataset[0, 1, 2]) == compute_consistency(dataset[0, 2, 1])
assert compute_consistency(dataset[0, 1, 2]) == compute_consistency(dataset[1, 2, 0])
assert compute_consistency(dataset[0, 1, 2]) == compute_consistency(dataset[2, 1, 0])

You are right that the dataset labels returned by consistency_score are sorted.

Plotting consistency

I think the issue you describe becomes visible once the results are plotted. Let's say we plot the data like this:

import cebra
import matplotlib.pyplot as plt

# Original order: the dataset IDs are already sorted here
scores, pairs, datasets = cebra.sklearn.metrics.consistency_score(between="datasets", **dataset[0,1,2])
cebra.plot_consistency(
    scores=scores,
    pairs=pairs,
    datasets=datasets,
    vmin=0
)
plt.show()

[Image: consistency plot for dataset[0,1,2]]

Then I can reproduce the issue you describe:

# This gives an incorrect output
scores, pairs, datasets = cebra.sklearn.metrics.consistency_score(between="datasets", **dataset[0,2,1])
cebra.plot_consistency(
    scores=scores,
    pairs=pairs,
    datasets=datasets,
    vmin=0
)
plt.show()

[Image: consistency plot for dataset[0,2,1], incorrect output]

While passing labels manually gives the correct, non-sorted output:

# This gives the correct output
scores, pairs, datasets = cebra.sklearn.metrics.consistency_score(between="datasets", **dataset[0,2,1])
cebra.plot_consistency(
    scores=scores,
    pairs=pairs,
    datasets=dataset[0,2,1]["dataset_ids"],
    vmin=0
)
plt.show()

[Image: consistency plot for dataset[0,2,1], correct output]

Summary

I'll post a fix for the use case you describe, i.e., computing the consistency score with cebra.sklearn.metrics.consistency_score and then passing its outputs directly to cebra.plot_consistency. The order of entries in the consistency matrix should be determined by the datasets passed to cebra.plot_consistency, so that function needs to be adapted.

The root issue seems to be the code at https://github.com/AdaptiveMotorControlLab/CEBRA/blob/a21ba0ec17a8c7cc67fc24187f119a52c4132cd0/cebra/integrations/matplotlib.py#L533, which does not accept the pairs as an additional input for sorting the values.
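A minimal sketch of the kind of fix described above, assuming the matrix is built by looking up each returned pair in `datasets` rather than by the position of each score. `consistency_matrix` is a hypothetical helper, not part of CEBRA's API, and using the first element of each pair as the row is an arbitrary choice:

```python
import numpy as np


def consistency_matrix(scores, pairs, datasets):
    """Hypothetical helper: place each score at the cell named by its pair.

    Rows and columns follow the order of `datasets`, so any ordering of
    the labels yields a correctly labeled matrix.
    """
    index = {name: k for k, name in enumerate(datasets)}
    n = len(datasets)
    matrix = np.full((n, n), np.nan)  # diagonal stays NaN
    for (source, target), score in zip(pairs, scores):
        # Assumption: first element of the pair -> row, second -> column.
        matrix[index[source], index[target]] = score
    return matrix


pairs = [("a", "d"), ("a", "b"), ("d", "a"), ("d", "b"), ("b", "a"), ("b", "d")]
scores = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]

# Same scores and pairs, two different label orders.
sorted_m = consistency_matrix(scores, pairs, ["a", "b", "d"])
input_m = consistency_matrix(scores, pairs, ["a", "d", "b"])
```

Because every score is placed by name, the cell labeled ("a", "d") holds 0.1 in both matrices, even though it sits at a different position in each.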

Thanks @drsax93 for flagging this! Until the next release, please use sorted labels if possible; that will give you the correct result.
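One way to follow the sorted-labels workaround is to reorder the embeddings, labels and dataset IDs together before calling `consistency_score`. `sort_by_dataset_id` is a hypothetical helper that only reorders plain Python lists; the string placeholders stand in for real arrays:

```python
def sort_by_dataset_id(embeddings, labels, dataset_ids):
    """Hypothetical helper: reorder all three lists by dataset ID."""
    order = sorted(range(len(dataset_ids)), key=lambda i: dataset_ids[i])
    return (
        [embeddings[i] for i in order],
        [labels[i] for i in order],
        [dataset_ids[i] for i in order],
    )


# Placeholders standing in for the real embeddings and labels.
emb, lab, ids = sort_by_dataset_id(
    ["E_a", "E_d", "E_b", "E_c"],
    ["L_a", "L_d", "L_b", "L_c"],
    ["a", "d", "b", "c"],
)
```

The reordered lists can then be passed as `embeddings=`, `labels=` and `dataset_ids=` to `cebra.sklearn.metrics.consistency_score`, so the input order already matches the sorted labels it returns.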

stes commented 1 year ago

@drsax93 further discussion on https://github.com/AdaptiveMotorControlLab/CEBRA/pull/54, including the reproduction I discussed above.