apivovarov opened 3 days ago
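Possible workaround: install the CPU-only build of torch after the AXLearn [gpu] dependencies are installed. The CPU wheels do not depend on the nvidia-* CUDA packages, so they leave the cuDNN version that jax needs in place (torch itself then runs on CPU only):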
pip install -e .[dev]
pip install -e .[gpu]
pip install torch==2.1.1 torchvision==0.16.1 --index-url https://download.pytorch.org/whl/cpu
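After running the workaround, a quick way to confirm what each framework sees is a small import check. This is a minimal sketch, not part of AXLearn: the script and its helper names (describe, backend_report) are hypothetical. It relies on torch.version.cuda being None for CPU-only builds and on jax.default_backend() returning "gpu" when the CUDA plugin loads. A hard crash such as a segfault inside a native library cannot be caught this way.

```python
"""Report which torch and jax builds import cleanly in the current environment.

Hypothetical helper script (not part of AXLearn); it only imports the two
packages and reports what they say about themselves.
"""
import importlib


def describe(module_name: str, summarize) -> str:
    """Import `module_name` and return summarize(module), or a short failure note."""
    try:
        module = importlib.import_module(module_name)
    except ModuleNotFoundError as exc:
        if exc.name != module_name:
            # The package exists but one of its own dependencies is missing.
            return f"import failed: {exc!r}"
        return "not installed"
    except Exception as exc:  # e.g. an error while loading mismatched CUDA/cuDNN libraries
        return f"import failed: {exc!r}"
    try:
        return summarize(module)
    except Exception as exc:  # e.g. jax failing to initialize its GPU backend
        return f"loaded, but backend check failed: {exc!r}"


def backend_report() -> dict:
    return {
        # torch.version.cuda is None for CPU-only builds such as the /whl/cpu wheels.
        "torch": describe("torch", lambda m: f"{m.__version__} (cuda={m.version.cuda})"),
        # jax.default_backend() returns "gpu" when the CUDA plugin initialized.
        "jax": describe("jax", lambda m: f"{m.__version__} (backend={m.default_backend()})"),
    }


if __name__ == "__main__":
    for name, status in backend_report().items():
        print(f"{name}: {status}")
```

With the workaround applied, one would expect torch to report cuda=None and jax to report backend=gpu.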
I am trying to run the full pytest suite on a GPU instance.
To set up the environment, I installed the [dev] and [gpu] dependencies, but this leads to a conflict: torch and jax need different versions of nvidia-cudnn-cu12, so I am unable to have both installed and working at the same time.
When nvidia-cudnn-cu12-9.5.1.17 (the newer version) is installed, torch-2.1.1 crashes with the following error:
When nvidia-cudnn-cu12-8.9.2.26 (the older version) is installed, jax crashes with this error:
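To see which installed packages pull nvidia-cudnn-cu12 in which direction, the declared requirements can be read from the installed package metadata. This is a stdlib-only sketch with hypothetical helper names; it matches requirement strings by name with a simple regex rather than a full PEP 508 parser, and it reports whatever each package declares without interpreting version ranges.

```python
"""Show the installed nvidia-cudnn-cu12 version and which packages declare it.

Hypothetical diagnostic script; requirement names are matched with a simple
regex instead of a full PEP 508 parser.
"""
import re
from importlib.metadata import PackageNotFoundError, distributions, version


def normalize(name: str) -> str:
    # PEP 503 normalization: runs of "-", "_" and "." compare equal, case-insensitively.
    return re.sub(r"[-_.]+", "-", name).lower()


def installed_version(name: str):
    try:
        return version(name)
    except PackageNotFoundError:
        return None


def requirers_of(target: str) -> dict:
    """Map each installed distribution to the requirement strings it declares on `target`."""
    wanted = normalize(target)
    found = {}
    for dist in distributions():
        for req in dist.requires or []:
            # The project name ends at the first space, extras bracket, marker or version operator.
            req_name = re.split(r"[\s\[;<>=!~(]", req, maxsplit=1)[0]
            if normalize(req_name) == wanted:
                owner = dist.metadata["Name"] or "<unknown>"
                found.setdefault(owner, []).append(req)
    return found


if __name__ == "__main__":
    target = "nvidia-cudnn-cu12"
    print(f"{target} installed: {installed_version(target)}")
    for owner, reqs in sorted(requirers_of(target).items()):
        print(f"  required by {owner}: {' | '.join(reqs)}")
```

Running this in each of the two broken environments should show the torch pin and the jax plugin requirement side by side.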
Approximately 30 test files use the torch package.
How should I run the full pytest suite on the GPU instance, given that torch/torchvision and jax[cuda12] cannot be installed together because of these conflicts?
OS: Ubuntu 22.04