conda-forge / jaxlib-feedstock

A conda-smithy repository for jaxlib.
BSD 3-Clause "New" or "Revised" License

Possible to remove dependency on cuda-version / __cuda? #254

Closed. garymm closed this issue 4 months ago.

garymm commented 4 months ago

Comment:

pytorch-cuda can install all of PyTorch's CUDA dependencies without depending on a system-installed CUDA (AKA __cuda). Is it possible to do the same for jaxlib?

Requiring a system-installed CUDA seems unnecessary, since all of the CUDA libraries are available as conda packages; the only thing the user needs to do is make sure the drivers are installed. This works fine for PyTorch, and it is nicer because that package can be installed without any system requirements.

vyasr commented 4 months ago

I don't know the jaxlib build, but just to clear up one bit of confusion above: __cuda refers only to the system-installed driver, as noted in the conda-forge docs:

Conda exposes the maximum CUDA version supported by the installed Nvidia drivers through a virtual package named __cuda

Using that won't impose any extra requirement on what is installed on the system. Such requirements would only arise if the libraries contained in the package tried to load library dependencies that aren't provided by conda packages.
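To make the semantics concrete, here is a minimal sketch, assuming a package declares a run requirement of the form `__cuda >=X`. It models only the check against the virtual package and is not conda's actual solver. The function names `parse_version` and `satisfies_cuda_constraint` are hypothetical. `detected_cuda` stands for the version conda exposes as `__cuda`, which is the maximum CUDA version the installed driver supports. `None` means no driver was detected, so the virtual package does not exist.

```python
from typing import Optional, Tuple


def parse_version(v: str, width: int = 3) -> Tuple[int, ...]:
    # Hypothetical, simplified parser: numeric components only, zero-padded
    # so that "12" and "12.0" compare as equal.
    parts = [int(p) for p in v.split(".")]
    return tuple(parts + [0] * (width - len(parts)))


def satisfies_cuda_constraint(min_cuda: Optional[str], detected_cuda: Optional[str]) -> bool:
    """Model a run requirement `__cuda >=min_cuda` (simplified illustration).

    min_cuda: None means the package has no __cuda requirement at all.
    detected_cuda: the version exposed as __cuda (the driver's maximum
    supported CUDA version), or None when no NVIDIA driver is present.
    """
    if min_cuda is None:
        return True  # nothing to check: the package does not depend on the driver
    if detected_cuda is None:
        return False  # __cuda is absent, so the constraint cannot be met
    return parse_version(detected_cuda) >= parse_version(min_cuda)
```

For example, `satisfies_cuda_constraint("12", "12.2")` is `True`, while `satisfies_cuda_constraint("12", None)` is `False`. The only thing being checked is the driver's advertised capability. No CUDA toolkit installation is involved.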

garymm commented 4 months ago

Ah OK, well that is badly named, but thanks for clearing it up! Then IIUC PyTorch allows the CUDA libraries to be installed on machines that don't have the NVIDIA drivers, and it may fail at runtime if you try to use CUDA, whereas jaxlib only allows installation on machines with the driver. I am not sure which is better, but I guess the inconsistency is what confused me. I'll close this since it was based on a misunderstanding.

vyasr commented 4 months ago

jaxlib actually produces separate builds for machines with and without CUDA. You can see in the recipe that the __cuda dependency is only specified conditionally. If you look at the list of packages produced, you'll see that some names include "cuda" while others say "cpu". If you try to install on a machine without CUDA drivers, you should install one of the cpu variants of the jaxlib build.
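The selection between those variants can be sketched as follows. This is a simplified illustration under stated assumptions, not how conda's solver is implemented. The `Variant` type, the `pick_variant` function and the labels `cpu`, `cuda11` and `cuda12` are all hypothetical. Only the idea that CUDA builds carry a `__cuda` requirement and CPU builds do not comes from the thread.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Variant:
    label: str               # hypothetical label, e.g. "cpu" or "cuda12"
    min_cuda: Optional[str]  # None for CPU builds, which carry no __cuda dependency


def _ver(v: str, width: int = 3) -> tuple:
    # Zero-pad numeric components so "12" and "12.0" compare as equal.
    parts = [int(p) for p in v.split(".")]
    return tuple(parts + [0] * (width - len(parts)))


def pick_variant(variants: List[Variant], detected_cuda: Optional[str]) -> Variant:
    """Prefer the newest CUDA build the driver supports; otherwise fall back to CPU.

    detected_cuda is the version exposed as __cuda, or None if no driver is present.
    """
    usable_cuda = [
        v for v in variants
        if v.min_cuda is not None
        and detected_cuda is not None
        and _ver(detected_cuda) >= _ver(v.min_cuda)
    ]
    if usable_cuda:
        return max(usable_cuda, key=lambda v: _ver(v.min_cuda))
    cpu_builds = [v for v in variants if v.min_cuda is None]
    if not cpu_builds:
        raise LookupError("no installable variant for this machine")
    return cpu_builds[0]
```

With `[Variant("cpu", None), Variant("cuda12", "12.0")]`, a machine without a driver (`detected_cuda=None`) falls back to the `cpu` variant. In practice you would either let the solver decide or pin a variant explicitly with a build-string match such as `jaxlib=*=*cpu*`.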