AdaptiveMotorControlLab / CEBRA

Learnable latent embeddings for joint behavioral and neural analysis - Official implementation of CEBRA
https://cebra.ai

Adding example script with custom Loader for PyTorch API documentation #97

Closed (timonmerk closed 8 months ago)

timonmerk commented 9 months ago

The CEBRA documentation is very comprehensive and covers the parameterization in great detail. In its current form, however, it seems to focus on the scikit-learn API, and there is no example script for using the PyTorch API: https://cebra.ai/docs/usage.html

For many options, though, I am unsure how to set them through the scikit-learn API. For example, when using discrete behavioral data, it's currently not possible to specify empirical or discrete sampling: https://github.com/AdaptiveMotorControlLab/CEBRA/blob/0378db0b2431d0c50a1e9b80aa1b865869586851/cebra/data/single_session.py#L89
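
To make the sampling question concrete, here is a minimal NumPy-only sketch of what an "empirical" versus a "uniform" prior over discrete labels could mean. The function name `sample_reference_indices` and its arguments are hypothetical and written only for illustration; this is not CEBRA's implementation, just the general idea behind the `prior` argument used in the snippet further down.

```python
import numpy as np


def sample_reference_indices(labels, num_samples, prior="empirical", seed=0):
    """Draw sample indices under an 'empirical' or 'uniform' class prior.

    Hypothetical helper for illustration only, not part of CEBRA.
    """
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    if prior == "empirical":
        # Every sample is equally likely, so classes appear at their
        # frequency in the dataset.
        return rng.integers(0, len(labels), size=num_samples)
    if prior == "uniform":
        # First pick a class uniformly, then pick a sample within that class,
        # so rare classes are drawn as often as common ones.
        classes = np.unique(labels)
        per_class = {c: np.flatnonzero(labels == c) for c in classes}
        chosen = rng.choice(classes, size=num_samples)
        return np.array([rng.choice(per_class[c]) for c in chosen])
    raise ValueError(f"Unknown prior: {prior!r}")


# Imbalanced toy labels: 90% class 0, 10% class 1.
labels = np.array([0] * 900 + [1] * 100)
fractions = {}
for prior in ("empirical", "uniform"):
    idx = sample_reference_indices(labels, 10_000, prior=prior)
    fractions[prior] = np.bincount(labels[idx], minlength=2) / len(idx)
    print(prior, fractions[prior])
```

With imbalanced labels like these, the empirical prior mostly draws the majority class, while the uniform prior draws both classes about equally often.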

I think this is also intentional, so as not to overload the cebra.CEBRA initialization or the model.fit() function with too many parameters?

Therefore, I thought it might help to add a minimal example to usage.rst showing how a data loader with parameters that the scikit-learn API does not expose could be used with PyTorch directly:

import numpy as np
import cebra.datasets
from cebra import plot_embedding
import torch

neural_data = cebra.load_data(file="neural_data.npz", key="neural")
# continuous_label = cebra.load_data(
#    file="auxiliary_behavior_data.h5",
#    key="auxiliary_variables",
#    columns=["continuous1", "continuous2", "continuous3"],
# )

discrete_label = cebra.load_data(
    file="auxiliary_behavior_data.h5", key="auxiliary_variables", columns=["discrete"],
)

# 1. Define Cebra Dataset
InputData = cebra.data.TensorDataset(
    torch.from_numpy(neural_data).type(torch.FloatTensor),
    # continuous=torch.from_numpy(np.array(continuous_label)).type(torch.FloatTensor),
    discrete=torch.from_numpy(np.array(discrete_label[:, 0])).type(torch.LongTensor),
).to("cpu")

# 2. Define Cebra Model
neural_model = cebra.models.init(
    name="offset10-model",
    num_neurons=InputData.input_dimension,
    num_units=32,
    num_output=2,
).to("cpu")

InputData.configure_for(neural_model)

# 3. Define Loss Function Criterion and Optimizer
Crit = cebra.models.criterions.LearnableCosineInfoNCE(
    # temperature=0.001,
    # min_temperature=0.0001
).to("cpu")

Opt = torch.optim.Adam(
    list(neural_model.parameters()) + list(Crit.parameters()),
    # lr=0.001,
    weight_decay=0,
)

# 4. Initialize Cebra Model
cebra_model = cebra.solver.init(
    name="single-session",
    model=neural_model,
    criterion=Crit,
    optimizer=Opt,
    tqdm_on=True,
).to("cpu")

# 5. Define Data Loader
# loader = cebra.data.single_session.ContinuousDataLoader(
#    dataset=InputData, num_steps=1000, batch_size=200
# )
loader = cebra.data.single_session.DiscreteDataLoader(
    dataset=InputData, num_steps=1000, batch_size=200, prior="uniform"
)

# 6. Fit model
cebra_model.fit(loader=loader)

# 7. Transform embedding
TrainBatches = np.lib.stride_tricks.sliding_window_view(
    neural_data, len(neural_model.get_offset()), axis=0
)
X_train_emb = cebra_model.transform(
    torch.from_numpy(TrainBatches[:]).type(torch.FloatTensor).to("cpu")
).to("cpu")

# 8. Potentially plot embedding
plot_embedding(
    X_train_emb,
    discrete_label[len(neural_model.get_offset()) - 1 :, 0],
    markersize=10,
)
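
The shape bookkeeping in steps 7 and 8 is easy to get wrong, so here is a small NumPy-only sketch of it. The sizes are made up, and `window` stands in for the length of the model's offset; it only mirrors the slicing used in the snippet above. Whether the last or the center sample of each window is the right reference for the label depends on the model's offset and is not checked here.

```python
import numpy as np

# Toy sizes, chosen only for illustration.
num_timesteps, num_neurons = 100, 5
window = 10  # assumption: stands in for the length of the model's offset

neural = np.arange(num_timesteps * num_neurons, dtype=np.float32).reshape(
    num_timesteps, num_neurons
)
labels = np.arange(num_timesteps)

# Sliding over axis 0 appends the window dimension as the last axis:
# the result has shape (num_timesteps - window + 1, num_neurons, window).
windows = np.lib.stride_tricks.sliding_window_view(neural, window, axis=0)
print(windows.shape)

# Dropping the first window - 1 labels gives one label per window,
# namely the label at the last time step covered by that window.
aligned_labels = labels[window - 1 :]
print(len(aligned_labels))

# Window i ends at time step i + window - 1.
i = 3
last_step_matches = np.array_equal(windows[i, :, -1], neural[i + window - 1])
print(last_step_matches)
```

The number of windows and the number of remaining labels match, which is why the plot call slices the labels the same way.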

MMathisLab commented 9 months ago

Thanks @timonmerk, I think that's a great idea! We can also link it to the Allen demo, which uses the PyTorch API: https://cebra.ai/docs/demo_notebooks/Demo_Allen.html. Would you like to make a PR?

Edit to add: we DO have an example, just to be sure you saw it, https://cebra.ai/docs/usage.html#quick-start-torch-api-example, but I agree it's not nearly as expressive as your suggestion, which is a good way to match the sklearn Quick Start guide!

timonmerk commented 9 months ago

Thanks! Yes, that notebook was already super helpful, but it might be good to also link it in usage.rst. I also compiled the documentation and ran my linked example locally, but of course it's a bit difficult to test it since it's in an rst file only.

stes commented 9 months ago

> but of course it's a bit difficult to test it since it's in a rst file only..

Actually this is included in the unit tests!

Locally, you can run make test, or call pytest directly (adapt as needed):

python -m pytest --ff --doctest-modules -m "not requires_dataset" tests ./docs/source/usage.rst cebra
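
As a rough illustration of how examples that live in a documentation file can be executed as tests, here is a sketch using only Python's standard-library `doctest` module on an in-memory text. The text, the example name and the `usage.rst` filename label are made up for the demonstration; how CEBRA's own usage.rst examples are collected depends on the project's pytest configuration, which is not shown in this thread.

```python
import doctest

# A made-up documentation snippet with an interactive example embedded in it.
RST_TEXT = """
Usage
=====

A tiny example embedded in documentation text:

>>> values = [1, 2, 3]
>>> sum(values)
6
"""

# Parse the ">>>" examples out of the text and run them.
parser = doctest.DocTestParser()
test = parser.get_doctest(
    RST_TEXT, globs={}, name="usage_example", filename="usage.rst", lineno=0
)
runner = doctest.DocTestRunner(verbose=False)
results = runner.run(test)

# results.failed counts examples whose output differed from the text,
# results.attempted counts all examples that were run.
print(results.failed, results.attempted)
```

If the expected output in the text drifts from what the code actually returns, the failed count becomes non-zero, which is what makes documentation examples testable.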