AdaptiveMotorControlLab / CEBRA

Learnable latent embeddings for joint behavioral and neural analysis - Official implementation of CEBRA
https://cebra.ai

Sampling around reference samples (i.e. conditional='delta') with auxiliary dimension > 1 #56

Closed azylbertal closed 1 year ago

azylbertal commented 1 year ago

Bug description

Fitting a 'CEBRA-behaviour' model with the 'delta' sampling strategy (i.e. sampling positive examples around the auxiliary-variable values of the reference examples) works when the auxiliary variable (context) has one dimension, but fails when it has more than one.

Operating System

Linux

CEBRA version

0.2.0

Device type

gpu

Steps To Reproduce

Define and fit a model as follows, with continuous_label.shape = [n_time_points, 2]:

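from cebra import CEBRA

# neural_data: recorded neural activity, one row per time point
# continuous_label: auxiliary variable of shape [n_time_points, 2]
cebra_model = CEBRA(
    model_architecture="offset10-model",
    batch_size=1024,
    temperature_mode="auto",
    learning_rate=0.001,
    max_iterations=2000,
    time_offsets=10,
    output_dimension=3,
    device="cuda_if_available",
    conditional="delta",  # sample positives around the reference labels
    delta=0.1,
    verbose=True,
)

cebra_model.fit(neural_data, continuous_label)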

Relevant log output

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[4], line 1
----> 1 cebra_model.fit(neural_data, continuous_label)

File c:\Users\azylb\miniconda3\envs\cebra\lib\site-packages\cebra\integrations\sklearn\cebra.py:1086, in CEBRA.fit(self, X, adapt, callback, callback_frequency, *y)
   1081     self._adapt_fit(X,
   1082                     *y,
   1083                     callback=callback,
   1084                     callback_frequency=callback_frequency)
   1085 else:
-> 1086     self.partial_fit(X,
   1087                      *y,
   1088                      callback=callback,
   1089                      callback_frequency=callback_frequency)
   1090     del self.state_
   1092 return self

File c:\Users\azylb\miniconda3\envs\cebra\lib\site-packages\cebra\integrations\sklearn\cebra.py:996, in CEBRA.partial_fit(self, X, callback, callback_frequency, *y)
    994 if not hasattr(self, "state_") or self.state_ is None:
    995     self.state_ = self._prepare_fit(X, *y)
--> 996 self._partial_fit(*self.state_,
    997                   callback=callback,
    998                   callback_frequency=callback_frequency)
    999 return self

File c:\Users\azylb\miniconda3\envs\cebra\lib\site-packages\cebra\integrations\sklearn\cebra.py:933, in CEBRA._partial_fit(self, solver, model, loader, is_multisession, callback, callback_frequency)
    928         raise ValueError(
    929             "callback_frequency requires to specify a callback.")
    931 model.train()
--> 933 solver.fit(
    934     loader,
    935     valid_loader=None,
    936     save_frequency=callback_frequency,
    937     valid_frequency=None,
    938     decode=False,
    939     logdir=None,
    940     save_hook=callback,
    941 )
    943 # Save variables of interest as semi-private attributes
    944 self.model_ = model

File c:\Users\azylb\miniconda3\envs\cebra\lib\site-packages\cebra\solver\base.py:183, in Solver.fit(self, loader, valid_loader, save_frequency, valid_frequency, decode, logdir, save_hook)
    181 iterator = self._get_loader(loader)
    182 self.model.train()
--> 183 for num_steps, batch in iterator:
    184     stats = self.step(batch)
    185     iterator.set_description(stats)

File c:\Users\azylb\miniconda3\envs\cebra\lib\site-packages\cebra\solver\util.py:85, in ProgressBar.__iter__(self)
     83 if self.use_tqdm:
     84     self.iterator = tqdm.tqdm(self.iterator)
---> 85 for num_batch, batch in enumerate(self.iterator):
     86     yield num_batch, batch

File c:\Users\azylb\miniconda3\envs\cebra\lib\site-packages\tqdm\std.py:1180, in tqdm.__iter__(self)
   1177 time = self._time
   1179 try:
-> 1180     for obj in iterable:
   1181         yield obj
   1182         # Update and possibly print the progressbar.
   1183         # Note: does not call self.update(1) for speed optimisation.

File c:\Users\azylb\miniconda3\envs\cebra\lib\site-packages\cebra\data\base.py:217, in Loader.__iter__(self)
    215 def __iter__(self) -> Batch:
    216     for _ in range(len(self)):
--> 217         index = self.get_indices(num_samples=self.batch_size)
    218         yield self.dataset.load_batch(index)

File c:\Users\azylb\miniconda3\envs\cebra\lib\site-packages\cebra\data\single_session.py:238, in ContinuousDataLoader.get_indices(self, num_samples)
    236 negative_idx = reference_idx[num_samples:]
    237 reference_idx = reference_idx[:num_samples]
--> 238 positive_idx = self.distribution.sample_conditional(reference_idx)
    239 return BatchIndex(reference=reference_idx,
    240                   positive=positive_idx,
    241                   negative=negative_idx)

File c:\Users\azylb\miniconda3\envs\cebra\lib\site-packages\cebra\distributions\continuous.py:281, in DeltaDistribution.sample_conditional(self, reference_idx)
    275     raise ValueError(
    276         f"Reference indices have wrong shape: {reference_idx.shape}. "
    277         "Pass a 1D array of indices of reference samples.")
    279 # TODO(stes): Set seed
--> 281 query = torch.distributions.Normal(
    282     self.data[reference_idx].squeeze(),
    283     torch.ones_like(reference_idx, device=self.device) * self.std,
    284 ).sample()
    286 return self.index.search(query.unsqueeze(-1))

File c:\Users\azylb\miniconda3\envs\cebra\lib\site-packages\torch\distributions\normal.py:51, in Normal.__init__(self, loc, scale, validate_args)
     50 def __init__(self, loc, scale, validate_args=None):
---> 51     self.loc, self.scale = broadcast_all(loc, scale)
     52     if isinstance(loc, Number) and isinstance(scale, Number):
     53         batch_shape = torch.Size()

File c:\Users\azylb\miniconda3\envs\cebra\lib\site-packages\torch\distributions\utils.py:42, in broadcast_all(*values)
     39     new_values = [v if is_tensor_like(v) else torch.tensor(v, **options)
     40                   for v in values]
     41     return torch.broadcast_tensors(*new_values)
---> 42 return torch.broadcast_tensors(*values)

File c:\Users\azylb\miniconda3\envs\cebra\lib\site-packages\torch\functional.py:74, in broadcast_tensors(*tensors)
     72 if has_torch_function(tensors):
     73     return handle_torch_function(broadcast_tensors, tensors, *tensors)
---> 74 return _VF.broadcast_tensors(tensors)

RuntimeError: The size of tensor a (2) must match the size of tensor b (1024) at non-singleton dimension 1
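
The message is consistent with the usual broadcasting rule, in which shapes are compared starting from the trailing dimension. With 2-D labels, `self.data[reference_idx].squeeze()` presumably keeps the shape (1024, 2), while `torch.ones_like(reference_idx)` has shape (1024,), so the sizes 2 and 1024 are compared. The sketch below illustrates that rule in plain Python so that it runs without PyTorch; `broadcast_shape` is a hypothetical helper written for this illustration and is not part of PyTorch or CEBRA.

```python
from itertools import zip_longest


def broadcast_shape(a, b):
    """Return the broadcast shape of two shapes, or raise ValueError.

    Hypothetical helper that mimics the NumPy/PyTorch broadcasting rule:
    shapes are compared from the trailing dimension, and two sizes are
    compatible if they are equal or one of them is 1.
    """
    result = []
    for dim_a, dim_b in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if dim_a != dim_b and 1 not in (dim_a, dim_b):
            raise ValueError(f"sizes {dim_a} and {dim_b} do not match")
        # A size of 1 stretches to the other size; otherwise both are equal.
        result.append(dim_b if dim_a == 1 else dim_a)
    return tuple(reversed(result))


batch_size, label_dim = 1024, 2

# 1-D labels: squeeze() turns (1024, 1) into (1024,), which matches the scale.
print(broadcast_shape((batch_size,), (batch_size,)))

# 2-D labels: loc stays (1024, 2) while the scale is still (1024,).
try:
    broadcast_shape((batch_size, label_dim), (batch_size,))
except ValueError as err:
    print("mismatch:", err)

# A scale with a trailing singleton axis would broadcast against (1024, 2).
print(broadcast_shape((batch_size, label_dim), (batch_size, 1)))
```

The last call only shows that the shapes can be made compatible; whether that alone would be the right fix inside `DeltaDistribution.sample_conditional` is not something this thread establishes.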

azylbertal commented 1 year ago

Passing a 2-element vector as 'delta' also did not work.

stes commented 1 year ago

Hi @azylbertal, this is indeed not implemented right now, but planned for an upcoming release. I'll ping you here once I have updates on a release time-frame, ok?
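
For readers wondering what 'delta' sampling could look like for labels with more than one dimension, here is a minimal sketch under stated assumptions. It is not CEBRA's implementation and not the planned feature: `sample_positive_indices` is a hypothetical function, it perturbs each label dimension independently with the same standard deviation `delta`, and it uses a brute-force nearest-neighbour search in plain Python rather than an index structure.

```python
import math
import random


def sample_positive_indices(labels, reference_idx, delta, seed=None):
    """Pick one positive index per reference index (hypothetical sketch).

    labels: list of equal-length label vectors, one per time point.
    reference_idx: indices of the reference samples.
    delta: standard deviation of the Gaussian noise added to each dimension.
    """
    rng = random.Random(seed)
    positive_idx = []
    for ref in reference_idx:
        # Perturb every label dimension of the reference independently.
        query = [rng.gauss(value, delta) for value in labels[ref]]
        # Brute-force search for the time point whose label is closest
        # to the perturbed query (Euclidean distance).
        nearest = min(range(len(labels)),
                      key=lambda i: math.dist(labels[i], query))
        positive_idx.append(nearest)
    return positive_idx


# Toy 2-D labels, standing in for continuous_label.
toy_labels = [[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [5.1, 4.9]]
print(sample_positive_indices(toy_labels, [0, 2], delta=0.1, seed=0))
```

The brute-force search is linear in the number of time points for every reference sample, so it only serves to make the idea concrete on small inputs.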