AdaptiveMotorControlLab / CEBRA

Learnable latent embeddings for joint behavioral and neural analysis - Official implementation of CEBRA
https://cebra.ai

Improve device support and add support for Apple Silicon chipset (`mps`) #34

Closed gonlairo closed 1 year ago

gonlairo commented 1 year ago

This pull request improves device support for sklearn API models, enabling transfer between CPU, CUDA and Apple Silicon GPU environments.

Fix https://github.com/AdaptiveMotorControlLab/CEBRA-dev/pull/631
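For readers unfamiliar with moving between CPU, CUDA and Apple Silicon (`mps`) backends, the selection logic can be sketched roughly as below. This is only a minimal illustration, not the code in this pull request: `pick_device` is a hypothetical helper name, and the `"auto"` default is an assumption made for the example. It relies only on `torch.cuda.is_available()` and `torch.backends.mps.is_available()`, and falls back to `"cpu"` when PyTorch is not installed.

```python
import importlib.util


def pick_device(preferred: str = "auto") -> str:
    """Return a device string: "cuda", "mps" or "cpu" (hypothetical helper)."""
    if preferred != "auto":
        # The caller asked for a specific device; pass it through unchanged.
        return preferred
    if importlib.util.find_spec("torch") is None:
        # PyTorch is not installed, so CPU is the only sensible answer.
        return "cpu"
    import torch

    if torch.cuda.is_available():
        return "cuda"
    # torch.backends.mps exists in PyTorch >= 1.12; guard for older versions.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return "mps"
    return "cpu"


device = pick_device()
print(device)
```

The returned string can then be passed to `torch.device(...)` or to `.to(...)` on a module or tensor.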

rob-the-bot commented 1 year ago

Based on the branch from this pull request, I've also extended the device support to more generic GPU devices with torch-directml. It's still under active development but already has most of the functionality. I've tested it on Windows 10 22H2 with an AMD RX 5500 XT, and the Demo_decoding notebook works (see image below).

Is the dev team also interested in making CEBRA available to AMD/Intel GPU users? If so, I'm happy to add my changes to this pull request. The changes to the CEBRA codebase are minimal, and torch-directml needs to be imported at runtime.

image
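A runtime import of torch-directml, as described above, might look roughly like the following. This is a sketch under assumptions rather than the actual change: `directml_device` is a hypothetical helper, and it assumes the package is importable as `torch_directml` and exposes `torch_directml.device()`. When the package is missing, the helper returns `None` so callers can fall back to another device.

```python
import importlib


def directml_device():
    """Return a DirectML device if torch-directml is installed, else None.

    Hypothetical helper: the import happens only when this is called,
    so torch-directml stays an optional dependency.
    """
    try:
        torch_directml = importlib.import_module("torch_directml")
    except ImportError:
        return None
    # Assumes torch_directml.device() returns the default DirectML device.
    return torch_directml.device()


dml = directml_device()
if dml is None:
    print("torch-directml not available; use CPU, CUDA or MPS instead")
else:
    print("DirectML device:", dml)
```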

gonlairo commented 1 year ago

Test with CUDA:

Screen Shot 2023-07-17 at 7 17 29 PM