AdaptiveMotorControlLab / CEBRA

Learnable latent embeddings for joint behavioral and neural analysis - Official implementation of CEBRA
https://cebra.ai
Other

KeyError with `cebra.fit(adapt=True)` on multi-session embeddings (only) #108

FedeClaudi closed this issue 7 months ago

FedeClaudi commented 8 months ago

Bug description

Hey everyone, thanks for releasing this awesome tool.

I get an error when calling cebra_model.fit(..., adapt=True) to fine-tune a model on new data. This only happens if the model was trained on multiple sessions, not when adapting a model trained on a single session.

I use CEBRA within a larger code base, so I don't have an MWE (minimal working example), but these are the steps I take:

1: define a CEBRA model with hybrid=False.

        self.model = CEBRA(
            model_architecture=model_architecture,
            batch_size=batch_size,
            temperature=temperature,
            temperature_mode=temperature_mode,
            learning_rate=learning_rate,
            max_iterations=max_iterations,
            time_offsets=time_offsets,
            output_dimension=embedding_size,
            device="cuda",
            conditional="time_delta",  # use behavioral data
            distance=distance,
            verbose=verbose,
            hybrid=hybrid,
            max_adapt_iterations=500,
        )

2: fit self.model on multiple sessions with:

            self.model.fit(
                Xs,
                Ys,
                adapt=False,
            )

where Xs and Ys are lists of np.ndarray with neural and behavioral (continuous) data from multiple experimental sessions.

3: adapt the model to new data:

            self.model.fit(  # single session
                X_new,
                Y_new,
                adapt=True,
            )

with X_new, Y_new being arrays with the data for a single new session.

I get KeyError: '0.net.0.weight' (the CEBRA part of the error stack was attached as a screenshot).

I did some digging, and the problem is that the adapt_model created in `_adapt_model` has different keys in .state_dict() compared to self.model_. adapt_model keys:

odict_keys(['net.0.weight', 'net.0.bias', 'net.2.module.0.weight', 'net.2.module.0.bias',  
'net.3.module.0.weight', 'net.3.module.0.bias', 'net.4.module.0.weight', 'net.4.module.0.bias',        
'net.5.weight', 'net.5.bias'])

self.model_ keys:

['0.net.0.weight', '0.net.0.bias', '0.net.2.module.0.weight',
'0.net.2.module.0.bias', '0.net.3.module.0.weight', '0.net.3.module.0.bias', '0.net.4.module.0.weight',
'0.net.4.module.0.bias', '0.net.5.weight', '0.net.5.bias', '1.net.0.weight', '1.net.0.bias',
'1.net.2.module.0.weight', '1.net.2.module.0.bias', '1.net.3.module.0.weight', '1.net.3.module.0.bias',
'1.net.4.module.0.weight', '1.net.4.module.0.bias', '1.net.5.weight', '1.net.5.bias', '2.net.0.weight',
'2.net.0.bias', '2.net.2.module.0.weight', '2.net.2.module.0.bias', '2.net.3.module.0.weight',
'2.net.3.module.0.bias', '2.net.4.module.0.weight', '2.net.4.module.0.bias', '2.net.5.weight',
'2.net.5.bias', '3.net.0.weight', '3.net.0.bias', '3.net.2.module.0.weight', '3.net.2.module.0.bias',  
'3.net.3.module.0.weight', '3.net.3.module.0.bias', '3.net.4.module.0.weight', '3.net.4.module.0.bias',
'3.net.5.weight', '3.net.5.bias']
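The failing lookup can be mimicked with plain dictionaries. This is a minimal sketch, assuming (from the `adapted_dict[k] = adapt_model.state_dict()[k]` line in the traceback posted later in this thread) that adaptation copies weights key by key from the trained model into a freshly built single-session model. The variable and function names are hypothetical, and the key lists are abbreviated from the ones above:

```python
# Abbreviated key sets taken from the listings above.
base_keys = ["net.0.weight", "net.0.bias", "net.5.weight", "net.5.bias"]

# Single-session model used for adaptation: keys start with "net."
adapt_state = {key: None for key in base_keys}  # None stands in for tensors

# Multi-session model: one sub-model per session, keys prefixed "0.", "1.", ...
trained_state = {f"{s}.{key}": None for s in range(4) for key in base_keys}


def copy_by_key(trained, target):
    """Look up every trained key in the target dict (hypothetical helper)."""
    adapted = {}
    for k in trained:
        adapted[k] = target[k]  # fails as soon as k carries a session prefix
    return adapted


try:
    copy_by_key(trained_state, adapt_state)
    missing = None
except KeyError as err:
    missing = err.args[0]

print("KeyError:", repr(missing))  # KeyError: '0.net.0.weight'
```

Every multi-session key is just a single-session key with a session index in front, so a direct lookup fails on the very first key.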

When I first train CEBRA on a single session, self.model_'s keys match those of adapt_model. Is training on multiple sessions and then fine-tuning on a new one not allowed? Am I doing something wrong in how I use CEBRA?

Thanks, Federico

Operating System

Windows

CEBRA version

cebra version: '0.2.0'

Device type

gpu: NVIDIA GeForce RTX 3080

Steps To Reproduce

No response

Relevant log output

No response

Anything else?

No response

FedeClaudi commented 8 months ago

Minimal example:

import numpy as np
import cebra

timesteps = 10
neurons = 50
out_dim = 8
n_sessions = 3

neural_data = [np.random.normal(0,1,(timesteps, neurons)) for _ in range(n_sessions)]
continuous_label = [np.random.normal(0,1,(timesteps, 3)) for _ in range(n_sessions)]

multi_cebra_model = cebra.CEBRA(batch_size=512,
                                output_dimension=out_dim,
                                max_iterations=10,
                                max_adapt_iterations=10,
                                verbose=True
                                )

multi_cebra_model.fit(neural_data[:2], continuous_label[:2], adapt=False)
multi_cebra_model.fit(neural_data[2], continuous_label[2], adapt=True)

which gives

Traceback (most recent call last):
  File "c:\Users\feder\Documents\github\Omen.py\paper\dev.py", line 94, in <module>
    multi_cebra_model.fit(neural_data[2], continuous_label[2], adapt=True)
  File "C:\Users\feder\.conda\envs\omen\lib\site-packages\cebra\integrations\sklearn\cebra.py", line 1174, in fit
    self._adapt_fit(X,
  File "C:\Users\feder\.conda\envs\omen\lib\site-packages\cebra\integrations\sklearn\cebra.py", line 1116, in _adapt_fit
    self.state_ = self._adapt_model(X, *y)
  File "C:\Users\feder\.conda\envs\omen\lib\site-packages\cebra\integrations\sklearn\cebra.py", line 963, in _adapt_model
    adapted_dict[k] = adapt_model.state_dict()[k]
KeyError: '0.net.1.weight'

However, if I do:

multi_cebra_model.partial_fit(neural_data[:2], continuous_label[:2])
multi_cebra_model.partial_fit(neural_data[2], continuous_label[2])

that seems to work (i.e., I don't get an error). However, my understanding is that partial_fit updates all network parameters instead of just learning a new first layer as adapt=True does.
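The distinction drawn here between partial_fit and adapt=True can be illustrated in plain PyTorch. This is a minimal sketch, assuming a toy `nn.Sequential` encoder (hypothetical; it is not CEBRA's model or its actual `adapt=True` code path): a new input layer is swapped in for the new session's neuron count, every other parameter is frozen, and only the new layer receives optimizer updates.

```python
import torch
from torch import nn

# Toy encoder (hypothetical, not CEBRA's architecture).
hidden, out_dim = 32, 8
encoder = nn.Sequential(
    nn.Linear(50, hidden),  # input layer sized for the original sessions
    nn.GELU(),
    nn.Linear(hidden, out_dim),
)

# New session with a different number of neurons: swap in a fresh input layer.
new_neurons = 40
encoder[0] = nn.Linear(new_neurons, hidden)


def freeze_all_but_first(model: nn.Sequential) -> list:
    """Disable gradients for every layer except model[0]; return trainable params."""
    for param in model.parameters():
        param.requires_grad = False
    for param in model[0].parameters():
        param.requires_grad = True
    return [p for p in model.parameters() if p.requires_grad]


trainable = freeze_all_but_first(encoder)
optimizer = torch.optim.Adam(trainable, lr=3e-4)

# One illustrative step with a placeholder loss: only model[0] is updated.
first_before = encoder[0].weight.detach().clone()
last_before = encoder[2].weight.detach().clone()
x = torch.randn(16, new_neurons)
loss = encoder(x).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(torch.equal(last_before, encoder[2].weight))   # True: frozen
print(torch.equal(first_before, encoder[0].weight))  # False: updated
```

Skipping the freezing step and optimizing `encoder.parameters()` would update every layer instead, which is closer to the partial_fit behavior described above.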

This also errors:

multi_cebra_model.partial_fit(neural_data[:2], continuous_label[:2])
multi_cebra_model.fit(neural_data[2], continuous_label[2], adapt=True)

stes commented 8 months ago

Hi @FedeClaudi, thanks for the detailed report. The intended usage is to adapt the model by running

multi_cebra_model.fit(neural_data[:2], continuous_label[:2], adapt=False)
multi_cebra_model.fit(neural_data[2], continuous_label[2], adapt=True)

as you posted. I will investigate how we could fix this for multi-session training; it seems to be an issue with loading the model and replacing the first layer in a multi-session model.
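One possible direction, sketched only as an illustration and not as the fix the maintainers adopted: strip the session index from the multi-session keys so they line up with the single-session model, and skip the first layer, which (per the description in this thread) adaptation re-learns anyway. The helper name, the `session` argument, and the `"net.0."` first-layer prefix are assumptions based on the key listings in the report; which session's weights to reuse is left open here.

```python
def session_state_dict(multi_state, session=0, first_layer="net.0."):
    """Return one session's weights with the "<session>." prefix removed.

    Hypothetical helper. Keys of the first layer are dropped, since
    adaptation replaces that layer with a freshly initialized one.
    """
    prefix = f"{session}."
    selected = {}
    for key, value in multi_state.items():
        if not key.startswith(prefix):
            continue  # weights belonging to another session
        local_key = key[len(prefix):]
        if local_key.startswith(first_layer):
            continue  # re-learned during adaptation
        selected[local_key] = value
    return selected


# Abbreviated keys from the report; string values are placeholders for tensors.
multi_state = {
    f"{s}.{k}": f"w{s}"
    for s in range(4)
    for k in ["net.0.weight", "net.0.bias", "net.5.weight", "net.5.bias"]
}
print(session_state_dict(multi_state, session=1))
# {'net.5.weight': 'w1', 'net.5.bias': 'w1'}
```

With real tensors, the result could be passed to the single-session model's `load_state_dict(..., strict=False)`, since the first-layer entries are intentionally missing.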

MMathisLab commented 7 months ago

Hi @FedeClaudi, right now this is not implemented in the sklearn API, but we will prioritize it! Thanks for raising the issue.