apple / axlearn

An Extensible Deep Learning Library
Apache License 2.0

Version Conflict Between Torch and JAX for NVIDIA cuDNN-cu12 #858

Open apivovarov opened 3 days ago

apivovarov commented 3 days ago

I am trying to run the full pytest suite on a GPU instance.

To set up the environment, I installed the [dev] and [gpu] dependencies, but encountered the following issue:

pip install -e .[dev]   # torchvision 0.16.1 requires torch 2.1.1, which requires nvidia-cudnn-cu12==8.9.2.26
pip install -e .[gpu]   # jax[cuda12] 0.4.33 needs nvidia-cudnn-cu12 9.5.1.17

This leads to the following conflict:

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torch 2.1.1 requires nvidia-cudnn-cu12==8.9.2.26; platform_system == "Linux" and platform_machine == "x86_64", but you have nvidia-cudnn-cu12 9.5.1.17 which is incompatible.
Successfully installed nvidia-cudnn-cu12-9.5.1.17

As a result, I cannot get torch and jax working in the same environment.
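`pip check` reports this kind of conflict for the whole environment. To look at just the cuDNN pin, the comparison can be approximated from installed package metadata. This is a minimal sketch that assumes torch declares the pin as an exact `==` requirement in its metadata, as the resolver message above suggests. `find_cudnn_pin` and `check_torch_cudnn` are hypothetical helper names and are not part of AXLearn, torch, or JAX.

```python
import re
from importlib import metadata

# Assumption: torch's Requires-Dist entry looks like
# "nvidia-cudnn-cu12==8.9.2.26; ..." or "nvidia-cudnn-cu12 (==8.9.2.26) ; ...".
PIN_RE = re.compile(r"^nvidia-cudnn-cu12\s*\(?\s*==\s*([0-9.]+)")


def find_cudnn_pin(requirements):
    """Return the exact nvidia-cudnn-cu12 version pinned in a list of requirement strings."""
    for req in requirements or []:
        # Normalize case and underscores so "nvidia_cudnn_cu12" also matches.
        match = PIN_RE.match(req.strip().lower().replace("_", "-"))
        if match:
            return match.group(1)
    return None


def check_torch_cudnn():
    """Compare torch's cuDNN pin with the nvidia-cudnn-cu12 wheel actually installed."""
    try:
        pinned = find_cudnn_pin(metadata.requires("torch"))
    except metadata.PackageNotFoundError:
        return "torch is not installed"
    if pinned is None:
        return "torch does not pin nvidia-cudnn-cu12 (for example, a CPU-only build)"
    try:
        installed = metadata.version("nvidia-cudnn-cu12")
    except metadata.PackageNotFoundError:
        return f"torch pins nvidia-cudnn-cu12=={pinned}, but it is not installed"
    if installed != pinned:
        return f"conflict: torch pins {pinned}, installed is {installed}"
    return f"ok: nvidia-cudnn-cu12 {installed}"


if __name__ == "__main__":
    print(check_torch_cudnn())
```

Running it after `pip install -e .[gpu]` should report the 8.9.2.26 vs 9.5.1.17 mismatch from the resolver message above.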

When nvidia-cudnn-cu12-9.5.1.17 (the newer version) is installed, torch-2.1.1 crashes with the following error:

import torch

ImportError: libcudnn.so.8: cannot open shared object file: No such file or directory

When nvidia-cudnn-cu12-8.9.2.26 (the older version) is installed, jax crashes with this error:

import jax.numpy as jnp
x = jnp.ones((1000, 1000))

FAILED_PRECONDITION: DNN library initialization failed

Approximately 30 test files use the torch package.

How are all the pytests meant to be run on a GPU instance, given that torch/torchvision and jax[cuda12] cannot be installed together because of this conflict?

OS: Ubuntu 22.04

apivovarov commented 3 days ago

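Possible workaround: install the CPU-only build of torch after the AXLearn [gpu] dependencies are installed. The CPU wheels do not depend on nvidia-cudnn-cu12, so the cuDNN 9 wheel that jax[cuda12] needs stays in place, and the torch-based tests run torch on the CPU.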

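pip install -e .[dev]
pip install -e .[gpu]
# Remove the CUDA build of torch first; otherwise pip treats torch==2.1.1 as already satisfied
pip uninstall -y torch torchvision
pip install torch==2.1.1 torchvision==0.16.1 --index-url https://download.pytorch.org/whl/cpu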
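You can confirm that the workaround took effect by checking the installed versions. Wheels from the PyTorch CPU index carry a `+cpu` local version label, for example `2.1.1+cpu`. This is a minimal sketch that assumes the expected end state is `+cpu` builds of torch/torchvision plus a 9.x nvidia-cudnn-cu12 wheel for JAX. `workaround_applied`, `installed_versions`, and `local_label` are hypothetical helper names.

```python
from importlib import metadata


def local_label(version):
    """Return the local version label, e.g. 'cpu' for '2.1.1+cpu', or '' if there is none."""
    return version.split("+", 1)[1] if "+" in version else ""


def workaround_applied(versions):
    """Check a {package: version} mapping against the assumed post-workaround state.

    Returns a list of problems; an empty list means the state matches.
    """
    problems = []
    for pkg in ("torch", "torchvision"):
        if local_label(versions.get(pkg, "")) != "cpu":
            problems.append(f"{pkg} is not a +cpu build: {versions.get(pkg)}")
    cudnn = versions.get("nvidia-cudnn-cu12", "")
    if not cudnn.startswith("9."):
        problems.append(f"nvidia-cudnn-cu12 is not 9.x: {cudnn or None}")
    return problems


def installed_versions(packages=("torch", "torchvision", "nvidia-cudnn-cu12")):
    """Collect installed versions, skipping packages that are not installed."""
    found = {}
    for pkg in packages:
        try:
            found[pkg] = metadata.version(pkg)
        except metadata.PackageNotFoundError:
            pass
    return found


if __name__ == "__main__":
    issues = workaround_applied(installed_versions())
    print("\n".join(issues) if issues else "CPU torch + cuDNN 9 for JAX: OK")
```

At runtime, `torch.version.cuda` is `None` on a CPU-only build. `jax.devices()` should still list the GPUs, and it does so without the DNN initialization error.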