AdaptiveMotorControlLab / CEBRA

Learnable latent embeddings for joint behavioral and neural analysis - Official implementation of CEBRA
https://cebra.ai
Other
884 stars 72 forks

Fix Device in Multisession Training #44

Closed: gonlairo closed this 1 year ago

gonlairo commented 1 year ago

This pull request fixes a bug in multisession training in both the sklearn and torch APIs. Previously, training always ran on the CPU, even when the user requested GPU execution. With this fix, the requested accelerator (e.g. `cuda`) is actually used. The root cause was in `DatasetCollection`, which hard-coded the CPU as its default device and gave users no way to override it.

Fixes https://github.com/AdaptiveMotorControlLab/CEBRA-dev/pull/651.
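The kind of change described in this pull request can be sketched in plain Python. This is a hypothetical illustration, not CEBRA's code: `SketchDataset` and `SketchDatasetCollection` are invented names, and device strings stand in for the tensors a real implementation would move (for example with PyTorch's `Tensor.to`). The idea is that the collection keeps `cpu` only as a default, exposes a settable device, and forwards it to every session.

```python
from typing import List


class SketchDataset:
    """Hypothetical stand-in for a single-session dataset (not a CEBRA class)."""

    def __init__(self, name: str, device: str = "cpu") -> None:
        self.name = name
        self.device = device

    def to(self, device: str) -> "SketchDataset":
        # A real dataset would move its tensors here; this sketch only records the target.
        self.device = device
        return self


class SketchDatasetCollection:
    """Hypothetical multisession container; assumed design, not CEBRA's DatasetCollection."""

    def __init__(self, *datasets: SketchDataset) -> None:
        self._datasets: List[SketchDataset] = list(datasets)
        self._device = "cpu"  # CPU is only the default, no longer a fixed setting

    @property
    def device(self) -> str:
        return self._device

    @device.setter
    def device(self, device: str) -> None:
        # Propagate the requested device to every session instead of
        # silently keeping the CPU default.
        self._device = device
        for dataset in self._datasets:
            dataset.to(device)

    def to(self, device: str) -> "SketchDatasetCollection":
        self.device = device
        return self

    def session_devices(self) -> List[str]:
        return [dataset.device for dataset in self._datasets]


collection = SketchDatasetCollection(SketchDataset("session_1"), SketchDataset("session_2"))
print(collection.session_devices())  # ['cpu', 'cpu']
collection.to("cuda")
print(collection.device, collection.session_devices())  # cuda ['cuda', 'cuda']
```

Routing the change through a property setter means that any code path that sets `collection.device` (directly or via `to`) keeps every session consistent, so no session is left behind on the CPU.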