AdaptiveMotorControlLab / CEBRA

Learnable latent embeddings for joint behavioral and neural analysis - Official implementation of CEBRA
https://cebra.ai

Last complete batch indexes for batched inference can go above the input length #199

Open CeliaBenquet opened 1 week ago

CeliaBenquet commented 1 week ago

When using the code implemented in #168, I get an error whenever the input size is such that len(inputs) % batch_size < offset.right, i.e. the last (incomplete) batch is smaller than offset.right. In that case, the penultimate batch fails because its batch_end_idx is larger than the input length.

Code to reproduce (from branch #168):

import cebra
import cebra.data
import numpy as np

train = cebra.data.TensorDataset(neural=np.random.rand(20111, 100),
                                 continuous=np.random.rand(20111, 2))

model = cebra.CEBRA(max_iterations=10, verbose=True, model_architecture="offset36-model-more-dropout", device="cuda_if_available")
model.fit(train.neural, train.continuous)

embedding = model.transform(train.neural, batch_size=300)
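
With these sizes, the last batch only contains 20111 % 300 = 11 samples, which is smaller than the right offset of an offset-36 model. Below is a minimal sketch of the index arithmetic, under the assumption that every non-final batch is extended on the right by offset.right before being passed to the model; offset_right = 18 is a placeholder for illustration, not necessarily the exact value used by offset36-model-more-dropout.

n_samples = 20111
batch_size = 300
offset_right = 18  # assumed value, for illustration only

remainder = n_samples % batch_size  # 11
print(remainder < offset_right)     # True: the failing configuration

for start in range(0, n_samples, batch_size):
    end = min(start + batch_size, n_samples)
    padded_end = end + offset_right  # right context needed by the conv model
    if end < n_samples and padded_end > n_samples:
        # The penultimate batch already needs indices past the end of the
        # input, which is what raises the ValueError below.
        print(f"batch [{start}, {end}) needs index {padded_end} > {n_samples}")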

Error:

  File "/CEBRA-dev/cebra/solver/base.py", line 634, in transform
    output = _batched_transform(
             ^^^^^^^^^^^^^^^^^^^
  File "/CEBRA-dev/cebra/solver/base.py", line 248, in _batched_transform
    batched_data = _get_batch(inputs=inputs,
                   ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/CEBRA-dev/cebra/solver/base.py", line 153, in _get_batch
    _check_indices(indices[0], indices[1], offset, len(inputs))
  File "/CEBRA-dev/cebra/solver/base.py", line 81, in _check_indices
    raise ValueError(
ValueError: batch_end_idx (20117) cannot exceed the length of inputs (20111).

I will propose a solution.
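
Purely for illustration (this is not the upcoming patch, and the helper below is hypothetical, not part of the CEBRA API): one possible direction is to fold a trailing remainder that is shorter than offset.right into the previous batch, so that no batch requires indices past the end of the input.

def _merge_short_last_batch(n_samples, batch_size, offset_right):
    # Hypothetical sketch, not CEBRA code: build batch boundaries such that
    # a trailing remainder shorter than the right offset is absorbed into
    # the previous batch instead of forming its own, too-short batch.
    bounds = []
    start = 0
    while start < n_samples:
        end = min(start + batch_size, n_samples)
        if n_samples - end < offset_right:
            # Too few samples would remain for a valid last batch;
            # extend this batch to the end of the input and stop.
            end = n_samples
        bounds.append((start, end))
        start = end
    return bounds

# For the example above, the last batch becomes (19800, 20111), i.e. 311 samples:
print(_merge_short_last_batch(20111, 300, 18)[-1])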