conda-forge / jaxlib-feedstock

A conda-smithy repository for jaxlib.
BSD 3-Clause "New" or "Revised" License

JAX with GPU support #34

Closed: gkaissis closed this issue 2 years ago

gkaissis commented 4 years ago

It would be a huge benefit if the JAX conda-forge package would come precompiled for GPU and automatically install the necessary packages for GPU support. From my understanding, cudnn and cudatoolkit from the TF 2.2 package should work and if nvcc is needed, it's available in cudatoolkit-dev. This would remove the requirement for a system-wide CUDA installation and allow usage in isolated environments. Would probably also be hugely beneficial for Windows users who want to use JAX on GPU.

Thanks to all!

ericmjl commented 4 years ago

@gkaissis thanks for pinging in. A few intricacies make it difficult to provide a GPU package that is built from source. A full build of JAX on conda-forge's systems would probably be infeasible, given that an entire toolchain involving Bazel is necessary. I've done GPU builds myself on my home GPU tower; they are tricky and take time :(, and when something failed, the compilation errors coming from Bazel were cryptic enough that I simply gave up. I think @ocefpaf did attempt building JAX for GPUs on conda-forge's compute resources, but he also ran into similar issues and then ran out of bandwidth to continue.

The hack that @ocefpaf used to get the CPU-only package into conda-forge was simply to copy over the pre-built packages that the JAX team uploaded to PyPI. For me, this was good enough, and I cannot thank him enough for making it happen! The same might be possible with the pre-built GPU-enabled packages, but I don't know how the namespacing on conda-forge would work out, as I'm not a core conda-forge developer.

One thing you wrote stood out for me, and I think I need a bit of clarification from the conda-forge folks.

From my understanding, cudnn and cudatoolkit from the TF 2.2 package should work and if nvcc is needed, it's available in cudatoolkit-dev. This would remove the requirement for a system-wide CUDA installation and allow usage in isolated environments.

I had tried building a CUDA-enabled CentOS 8 container so that I could run GPU-enabled JAX in it, but after installing cudnn, cudatoolkit-dev and cudatoolkit into the container, I couldn't get the pip-installed GPU-enabled JAX to recognize them. It was as if "something needs something linked somewhere but couldn't find it". Once again, I ran out of steam and never sat down to document what I did, since it was so frustrating to handle (I seriously think NVIDIA needs to simplify the way their stuff gets installed and built upon ☹️). Container woes aside, I didn't realize that nvcc was packaged in cudatoolkit-dev, but doesn't NVCC have to be built to match the driver version on the host? Or else nvidia-smi errors out? As you can see, I'm also ignorant about this matter and don't have the bandwidth to dive in and contribute productively, since my day job is to solve science problems with models rather than build packages.

Anyways, a wishlist of things I am hoping to see:

As for your other point:

Would probably also be hugely beneficial for Windows users who want to use JAX on GPU.

That would have to be raised directly with the JAX developers. To the best of my knowledge... they have no plans for this. My hunch is that XLA is complicated enough on *nix systems that they probably don't want to deal with Windows. My hands are tied there too, as I've been an exclusively macOS/Linux user for close to 14 years now, and Windows has become a distant and hazy memory.

Sorry if I come across as sounding rambly, sour and salty about JAX GPU on conda-forge. It's something I would like to see happen too, but I have neither the knowledge nor bandwidth to make it happen, despite seeing the many requests come in here and on the JAX issue tracker. (If I were in grad school still, I might be able to dive really deep here, but alas, in my day job now I use computers to help make medicines, and have traded away the luxury of academic research time.)

And as should be evident, I'm also quite ignorant about the state of GPU packaging on conda-forge, having only just heard @jakirkham's SciPy 2020 talk today. I also know that @jakevdp joined the JAX team recently. Perhaps they could chime in on what needs to happen to make JAX on GPU installation as easy as conda install -c conda-forge jax[cuda]? No pressure though, I know you both have your own day jobs to think about too.

gkaissis commented 4 years ago

Dear Eric,

thanks very much for the extensive comments! What can I do to help? I'm happy to help build the packages, as we've got free compute capacity, but it would have to be automated in some way, since I cannot manually monitor the release cycle and I honestly haven't the slightest clue about conda packaging. What do you propose?

Regarding Windows, I agree, and I have been in the same boat. But with the frankly desolate state of machine learning on macOS and Microsoft's encouraging move towards supporting CUDA containers in WSL2, I see myself moving away from Apple more and more, and I believe that offering a Windows package would be a service to the community (not all of whom are hardcore *nix users) in the name of democratisation.

Thanks a lot again and all the best,

George

ericmjl commented 4 years ago

Hi George,

Here are some things I can propose.

Firstly, I’d suggest engaging with @jakirkham and trying to work with him to get the pre-built JAX GPU packages onto conda-forge. I’m sure he’d be happy to work with someone on this issue. You might also end up having to engage the JAX team directly on their repo to debug anything that shows up.

Secondly, I’d suggest again working with @jakirkham and @ocefpaf to see if you can work through getting the Bazel-based builds (i.e. build from scratch) working on conda-forge.

Thirdly, on Windows, this issue tracker isn’t the right place to ask for a Windows-enabled JAX package; the JAX issue tracker is. If you can convince them that it’s worth their time, or offer help to build it on Windows, you’ll be in luck. Just so you don’t get your hopes too high, you should read through their docs to see whether they’ve documented why they don’t have a Windows package. From my limited understanding, it has to do with XLA not being available on Windows. JAX uses XLA, so if XLA isn’t available for Windows, then JAX isn’t going to be available on Windows either. You should check and confirm with them. “In the name of democratization” is a lofty goal; you’ve also got to know the landscape of limitations to know which requests can and cannot be fulfilled.

I’m going to close this issue, as I think you have enough information to work with, and there’s no “fix”/“resolution” available from me. Feel free to ping back regardless if you’d like further advice on what to do.

Cheers, Eric

ocefpaf commented 4 years ago

Thanks @ericmjl! We, conda-forge, just had a meeting about these "hard to build" packages and, even though you were not in that meeting, you made a nice summary of what needs to happen ;-p

beckermr commented 3 years ago

Hi friends! I am going to reopen this issue and see if I can get it to work properly by repackaging the wheels from google.

rabernat commented 2 years ago

Just adding a big 👍 and 🙏 to @beckermr's comment above. Getting a GPU-compatible Jax installable from Conda Forge would be huge! Thanks to all the volunteers here.

wolfv commented 2 years ago

We could try this again with the changes from the bazel-toolchain in the tensorflow recipe, and the 32-core cloud machines to build the recipes.

I am busy with PackagingCon for a while though, and then a week on vacation :)

wolfv commented 2 years ago

trying stuff over in #72

ngam commented 2 years ago

With a lot of tweaks, this is almost ready, pending some reviews. I continued what @wolfv started, and there was a very annoying bug/behavior where enabling CUDA was overriding our custom toolchain...

https://anaconda.org/ngam/jaxlib

ngam commented 2 years ago

If people are eager to test, I only uploaded the py38 version. Happy to upload py39 and py310 if people want; just let me know. I am actually not 100% sure everything worked correctly, so testing is very important. Specifically, I don't know how many cycles the compilation should be taking (to my knowledge there is no record of this upstream). Hopefully @xhochy can verify and make sure everything is good to go. (Edit: most artifacts are here: https://dev.azure.com/conda-forge/feedstock-builds/_build/results?buildId=506049&view=artifacts&pathAsName=false&type=publishedArtifacts. We are still dealing with a few timeouts and dropped connections, but we may be able to get this all on the CI!)
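For anyone testing a build like this, one quick smoke test is to ask JAX which backend and devices it picked up after installation. The sketch below assumes a recent JAX where `jax.default_backend()` and `jax.devices()` exist; the helper name `gpu_check` is hypothetical and not part of JAX or this feedstock.

```python
import importlib.util


def gpu_check() -> str:
    """Report which backend JAX picked up, or note that JAX is missing.

    Hypothetical smoke-test helper: it only reads what JAX reports and
    does not run any computation on the device.
    """
    if importlib.util.find_spec("jax") is None:
        return "jax is not installed in this environment"
    import jax  # imported lazily so the check still runs without JAX

    backend = jax.default_backend()  # platform name, e.g. "cpu" or "gpu"
    devices = ", ".join(str(d) for d in jax.devices())
    return f"backend={backend}; devices=[{devices}]"


if __name__ == "__main__":
    print(gpu_check())
```

If the GPU-enabled jaxlib is working, the reported backend should be a GPU platform rather than `cpu`; a `cpu` result on a GPU machine suggests JAX could not find or load the CUDA libraries.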

rabernat commented 2 years ago

Hi @ngam - thanks for working on this! I would happily test it in our pangeo images; however, we would need a python 3.9 version. If you can upload one, I will test it asap.

andsmi97 commented 2 years ago

Any updates on making this publicly available?