AdaptiveMotorControlLab / CEBRA

Learnable latent embeddings for joint behavioral and neural analysis - Official implementation of CEBRA
https://cebra.ai

Issue with individually defined model #46

Closed: timonmerk closed this issue 1 year ago

timonmerk commented 1 year ago


Bug description

Dear CEBRA team,

First of all, amazing tool and brilliant paper! I've already used it in several human invasive recording applications, and it works really well!

I am now trying to use a custom model based on the tutorial at https://cebra.ai/docs/usage.html#model-architecture and basically copied the offset-10 architecture from https://github.com/AdaptiveMotorControlLab/CEBRA/blob/00601fb843d9618b44f2b174fe1a80195f8008d8/cebra/models/model.py#L249.

Defining this model works, but changing the offset from cebra.data.Offset(5,5) to any larger value, e.g. cebra.data.Offset(10,10), raises a dimension error.

The ref, pos, and neg torch tensors then suddenly become three-dimensional: with the example below they have shape [100, 3, 11], whereas with Offset(5,5) they are 2D with shape [100, 3].

I tried adapting the other parameters in case the problem is related to them, but that did not help, so I am currently unable to use my own model.
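A minimal pure-Python sketch of where the extra dimension comes from: each stride-1, unpadded Conv1d with kernel size k shortens the time axis by k - 1 samples. It assumes that each cebra_layers._Skip block crops its input so that, along the time axis, it behaves like a plain valid convolution, and that an Offset(left, right) window spans left + right samples. Under those assumptions the Offset10Model above turns a 10-sample window into 1 time step but a 20-sample window into 11, which matches the reported [100, 3, 11]; the two-subscript einsum "ni,ni->n" in dot_similarity only accepts 2D [batch, dim] tensors, hence the error.

```python
from typing import Sequence


def conv_output_length(input_length: int, kernel_sizes: Sequence[int]) -> int:
    """Time-axis length after stride-1, unpadded ("valid") 1D convolutions.

    Assumption: _Skip blocks crop their input, so they shorten the sequence
    exactly like a plain convolution with the same kernel size.
    """
    length = input_length
    for k in kernel_sizes:
        length -= k - 1  # a valid conv with kernel k removes k - 1 samples
        if length < 1:
            raise ValueError(
                f"Input of length {input_length} is too short for kernels {list(kernel_sizes)}."
            )
    return length


# Kernel sizes of the Offset10Model in the issue: first Conv1d, three _Skip
# blocks, and the final Conv1d.
ISSUE_KERNELS = [2, 3, 3, 3, 3]

# Assumption: the window for Offset(left, right) has left + right samples.
print(conv_output_length(5 + 5, ISSUE_KERNELS))    # Offset(5, 5)   -> 1
print(conv_output_length(10 + 10, ISSUE_KERNELS))  # Offset(10, 10) -> 11
```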

Operating System

Windows 11

CEBRA version

0.2.0

Device type

GPU GeForce RTX2070 SUPER

Steps To Reproduce

Here is a minimal example that reproduces the error:

from cebra import CEBRA
import cebra
import numpy as np

from torch import nn
import cebra.models
import cebra.data
import cebra.models.layers as cebra_layers
from cebra.models.model import _OffsetModel, ConvolutionalModelMixin

@cebra.models.register("my-model") # --> add that line to register the model!
class Offset10Model(_OffsetModel, ConvolutionalModelMixin):
    """CEBRA model with a 10 sample receptive field."""

    def __init__(self, num_neurons, num_units, num_output, normalize=True):
        if num_units < 1:
            raise ValueError(
                f"Hidden dimension needs to be at least 1, but got {num_units}."
            )
        super().__init__(
            nn.Conv1d(num_neurons, num_units, 2),
            nn.GELU(),
            cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()),
            cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()),
            cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()),
            nn.Conv1d(num_units, num_output, 3),
            num_input=num_neurons,
            num_output=num_output,
            normalize=normalize,
        )

    def get_offset(self) -> cebra.data.datatypes.Offset:
        """See :py:meth:`~.Model.get_offset`"""
        return cebra.data.Offset(10, 10)

cebra_model = CEBRA(
    model_architecture="my-model",
    batch_size=100,
    temperature_mode="auto",
    learning_rate=0.005,
    max_iterations=1000,
    output_dimension=3,
    device="cuda",
    conditional="time_delta",
    hybrid=True,
    verbose=True,
)

X = np.random.random([1000, 100])
y = np.random.random([1000])

cebra_model.fit(X, y)

Relevant log output

> Exception has occurred: RuntimeError       (note: full exception trace is shown but execution is paused at: _run_module_as_main)
> The following operation failed in the TorchScript interpreter.
> Traceback of TorchScript (most recent call last):
>   File "C:\Users\ICN_admin\Anaconda3\envs\pn_env\lib\site-packages\cebra\models\criterions.py", line 46, in dot_similarity
>         the similarities between reference samples and negative samples of shape `(n, n)`.
>     """
>     pos_dist = torch.einsum("ni,ni->n", ref, pos)
>                ~~~~~~~~~~~~ <--- HERE
>     neg_dist = torch.einsum("ni,mi->nm", ref, neg)
>     return pos_dist, neg_dist
> RuntimeError: einsum(): the number of subscripts in the equation (2) does not match the number of dimensions (3) for operand 0 and no ellipsis was given
>   File "C:\Users\ICN_admin\Anaconda3\envs\pn_env\Lib\site-packages\cebra\models\criterions.py", line 268, in _distance
>     pos, neg = dot_similarity(ref, pos, neg)
>   File "C:\Users\ICN_admin\Anaconda3\envs\pn_env\Lib\site-packages\cebra\models\criterions.py", line 159, in forward
>     pos_dist, neg_dist = self._distance(ref, pos, neg)
>   File "C:\Users\ICN_admin\Anaconda3\envs\pn_env\Lib\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl
>     return forward_call(*args, **kwargs)
>   File "C:\Users\ICN_admin\Anaconda3\envs\pn_env\Lib\site-packages\cebra\solver\base.py", line 430, in step
>     behavior_loss, behavior_align, behavior_uniform = self.criterion(
>   File "C:\Users\ICN_admin\Anaconda3\envs\pn_env\Lib\site-packages\cebra\solver\base.py", line 184, in fit
>     stats = self.step(batch)
>   File "C:\Users\ICN_admin\Anaconda3\envs\pn_env\Lib\site-packages\cebra\integrations\sklearn\cebra.py", line 933, in _partial_fit
>     solver.fit(
>   File "C:\Users\ICN_admin\Anaconda3\envs\pn_env\Lib\site-packages\cebra\integrations\sklearn\cebra.py", line 996, in partial_fit
>     self._partial_fit(*self.state_,
>   File "C:\Users\ICN_admin\Anaconda3\envs\pn_env\Lib\site-packages\cebra\integrations\sklearn\cebra.py", line 1086, in fit
>     self.partial_fit(X,
>   File "C:\Users\ICN_admin\Documents\Cebra_RatStatesWenger\report_cebra_bug.py", line 53, in <module>
>     cebra_model.fit(X, y)
>   File "C:\Users\ICN_admin\Anaconda3\envs\pn_env\Lib\runpy.py", line 86, in _run_code
>     exec(code, run_globals)
>   File "C:\Users\ICN_admin\Anaconda3\envs\pn_env\Lib\runpy.py", line 196, in _run_module_as_main (Current frame)
>     return _run_code(code, main_globals, None,

Anything else?

No response


stes commented 1 year ago

@timonmerk, thanks for your question and the kind words! I will try to reproduce your error tomorrow, but my guess is that the offset does not match the actual receptive field of the model you defined.

You need to make sure that your model's receptive field reduces an input of len(offset) time steps to a single time step (which is then squeezed automatically). Is this the case in your example?
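The requirement can be written as a simple check: for stride-1, unpadded convolutions the receptive field is 1 + sum(k - 1) over all kernel sizes, and it must equal the offset window. The sketch below is a hypothetical validation helper, not part of CEBRA; it assumes len(Offset(left, right)) == left + right and that skip blocks crop like plain convolutions. It also shows the kind of descriptive error message that would make the mismatch obvious.

```python
from typing import Sequence


def receptive_field(kernel_sizes: Sequence[int]) -> int:
    """Input samples needed to produce one output sample (stride 1, no padding)."""
    return 1 + sum(k - 1 for k in kernel_sizes)


def check_offset(kernel_sizes: Sequence[int], left: int, right: int) -> None:
    """Hypothetical helper (not a CEBRA API): fail loudly on a window mismatch."""
    window = left + right  # assumption: len(Offset(left, right)) == left + right
    field = receptive_field(kernel_sizes)
    if window != field:
        raise ValueError(
            f"Offset({left}, {right}) spans {window} samples, but the model's "
            f"receptive field is {field} samples. They must be equal so that "
            "the output has a single time step."
        )


kernels = [2, 3, 3, 3, 3]  # Offset10Model from the issue
check_offset(kernels, 5, 5)  # passes silently
try:
    check_offset(kernels, 10, 10)  # window of 20 vs. receptive field of 10
except ValueError as err:
    print(err)
```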

Let me know if this makes sense. In any case, we'll improve the error message to make this requirement clearer. Thanks for flagging!

timonmerk commented 1 year ago

Thanks a lot for the comment, @stes. It took me some time to debug the model's forward calls, but with Offset(10,10) I needed to change the kernel size of the first nn.Conv1d to 12. In combination with the three skip connections and the final convolution (each with kernel size 3 and stride 1), this reduces the output to a "time dimension" of 1, and training then succeeds.
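The kernel size of 12 follows from the same length arithmetic: with a window of L samples and the later kernels fixed, the first kernel k1 must satisfy L - (k1 - 1) - sum(k - 1) = 1. A small sketch, assuming stride 1, no padding, skip blocks that crop like plain convolutions, and a 20-sample window for Offset(10, 10); the helper name is hypothetical.

```python
from typing import Sequence


def first_kernel_size(offset_length: int, later_kernels: Sequence[int]) -> int:
    """Hypothetical helper: first Conv1d kernel that maps offset_length samples to 1.

    Solves offset_length - (k1 - 1) - sum(k - 1 for later kernels) == 1 for k1.
    """
    k1 = offset_length - sum(k - 1 for k in later_kernels)
    if k1 < 1:
        raise ValueError(
            f"The later layers alone already need more than {offset_length} samples."
        )
    return k1


later = [3, 3, 3, 3]  # three _Skip blocks plus the final Conv1d
print(first_kernel_size(10, later))  # Offset(5, 5): 2, as in the original model
print(first_kernel_size(20, later))  # Offset(10, 10): 12, as found above
```

In the Offset10Model from the issue, this corresponds to changing nn.Conv1d(num_neurons, num_units, 2) to nn.Conv1d(num_neurons, num_units, 12) while get_offset keeps returning cebra.data.Offset(10, 10). Enlarging the first kernel is one option; adding more convolutional layers would grow the receptive field as well.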

stes commented 1 year ago

Great to hear! I will close this issue for now.