Closed: gonlairo closed this 1 year ago
Based on the branch from this pull request, I've also extended device support to more generic GPU devices with torch-directml. It's still under active development but already has most of the functionality. I've tested on Windows 10 22H2 with an AMD RX 5500 XT, and the Demo_decoding notebook works (see image below).
Is the dev team also interested in making CEBRA available to AMD/Intel GPU users? If so, I'm happy to add my changes to this pull request. The changes to the CEBRA codebase are minimal, and torch-directml needs to be imported at runtime.
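To illustrate what "imported at runtime" could look like, here is a minimal sketch of an optional import, so that torch-directml stays a soft dependency. The helper name `load_directml_device` is hypothetical and not part of CEBRA. The sketch assumes the `torch_directml` package exposes `is_available()` and `device()`, which should be checked against the installed version.

```python
import importlib
import importlib.util


def load_directml_device():
    """Return a DirectML device if torch-directml is usable, else None.

    Hypothetical helper for illustration; not part of the CEBRA codebase.
    """
    # Check for the package without importing it, so that environments
    # without torch-directml are unaffected.
    if importlib.util.find_spec("torch_directml") is None:
        return None

    # Import only when needed (the "runtime import" mentioned above).
    torch_directml = importlib.import_module("torch_directml")

    # Assumed API: is_available() and device() from torch-directml.
    if not torch_directml.is_available():
        return None
    return torch_directml.device()


device = load_directml_device()
print("DirectML device:", device if device is not None else "not available")
```

With this pattern, users on CUDA or CPU-only machines never trigger the import, and AMD/Intel GPU users get a device object they can pass to `.to(device)`.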
test with CUDA:
This pull request improves device support for sklearn API models, enabling transfer between CPU, CUDA and Apple Silicon GPU environments.
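As a rough sketch of the idea (not the implementation in this pull request), moving a model between environments comes down to mapping the device it was saved with onto one that exists on the current machine. The function names `detect_backends` and `resolve_device` below are hypothetical, and falling back to CPU is an assumption chosen for illustration.

```python
import importlib.util


def detect_backends():
    """Return the set of device types usable on this machine."""
    backends = {"cpu"}
    # Without torch installed, only the CPU can be reported.
    if importlib.util.find_spec("torch") is None:
        return backends

    import torch

    if torch.cuda.is_available():
        backends.add("cuda")
    # torch.backends.mps is missing in older torch releases.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        backends.add("mps")
    return backends


def resolve_device(requested, backends):
    """Keep the requested device if its type is available, else use CPU.

    `requested` is a device string such as "cuda", "cuda:0", "mps" or "cpu".
    """
    kind = requested.split(":", 1)[0]
    return requested if kind in backends else "cpu"


# Example: a model saved on a CUDA machine and loaded elsewhere.
print(resolve_device("cuda:0", detect_backends()))
```

The resolved string can then be passed to `torch.load(..., map_location=...)` or to `.to(...)` when restoring a model on a machine without the original accelerator.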
Fix https://github.com/AdaptiveMotorControlLab/CEBRA-dev/pull/631