Open martin-g opened 2 years ago
Hi, contributions are welcome. See this initial PR for an attempt: https://github.com/conda-forge/jaxlib-feedstock/pull/105. I can help/guide you if you would like to try to add it.
BTW, I verified that jaxlib builds fine at head on a GCP t2a VM just now with no special treatment, so I'm pretty confident things should work if we simply change the conda-forge build configuration to also build for aarch64 on Linux. I'd be happy to help here but @ngam would need to tell me what to do to make that happen...
@hawkinsp you're everywhere all of a sudden (in a good way!) and I am even getting a 1-question survey to fill out about you in my edu inbox :P (relatedly, #105 was in reaction to the same people who invited you to give that talk wanting a ppc build for jax...)
We have some options --- I will explain the process in more detail, but briefly:
We also have CUDA builds on aarch and ppc now, so we could go all out and add those too... but probably we should take care of the cpu ones first 😅
Yes, I think it's a great idea to have Linux CUDA aarch64 builds at least because of the upcoming https://www.nvidia.com/en-us/data-center/grace-cpu/ which I'm sure someone will want to use with JAX...
Well... the bot failed, so starting manually #127
Just a heads up: I was able to cross compile jaxlib for AArch64 easily enough, but the JIT compiler target detection isn't correct without making some upstream TensorFlow changes (https://github.com/tensorflow/tensorflow/pull/57182).
So we will not be able to get a working cross-compiled AArch64 build under 0.3.15 as is and it will need a new jaxlib release or some patching.
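One simple sanity check for a cross-compiled package, once it is installed on the target machine, is to compare the architecture Python reports with the one the package was built for. The sketch below is only an illustration: `normalize_arch` and `running_on` are hypothetical helper names, the alias table is an assumption and incomplete, and the check covers the interpreter's view of the host, not the JIT target detection addressed by the TensorFlow PR.

```python
import platform

# Illustrative, incomplete table of aliases for the same CPU architecture.
_ARCH_ALIASES = {
    "arm64": "aarch64",
    "aarch64": "aarch64",
    "amd64": "x86_64",
    "x86_64": "x86_64",
    "ppc64le": "ppc64le",
}


def normalize_arch(name: str) -> str:
    """Map an architecture name to a canonical spelling (unknown names pass through)."""
    key = name.strip().lower()
    return _ARCH_ALIASES.get(key, key)


def running_on(expected: str) -> bool:
    """Return True if the current interpreter reports the expected architecture."""
    return normalize_arch(platform.machine()) == normalize_arch(expected)


if __name__ == "__main__":
    # On an AArch64 Linux machine this is expected to report a match.
    print(platform.machine(), running_on("aarch64"))
```

A check like this only confirms that the package is running where it was meant to run; whether the compiler inside jaxlib picks the right target still has to be tested by actually running a jitted function.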
Thank you for the heads up. Did you use our tooling (conda-forge) or something else for this? How about when you built the native aarch64 version? Did you use our setup here?
I will apply your method here later in the week to see if we can get this sorted.
And you're correct, we can test if the build time is reasonable.
Yup, that's what I did. I suspect the bazel_toolchain package may do the right thing in the conda build already for cross-compilation, I'd certainly try that first.
BTW, there is a way to publish pypi wheels here too. I am not sure if the core team is okay with that, but if you want we can make the pypi wheels here too. An example is numba publishing some of their wheels on anaconda.org: https://anaconda.org/numba/numba/files?type=pypi
I also need jaxlib for Linux ARM64! Is there any progress on this issue? Thank you!
I don't speak for the conda-forge jaxlib package maintainers, but jaxlib should work fine on ARM64 if you build it from source. So hopefully that can unblock you in the meantime!
Let's see how #147 pans out (contributions welcome!)
@hawkinsp here's where it stops in #155:
WARNING: Download from https://storage.googleapis.com/mirror.tensorflow.org/github.com/tensorflow/runtime/archive/c27b720c93f76662ab6d0e0e507d1fc66ab22119.tar.gz failed: class java.io.FileNotFoundException GET returned 404 Not Found
WARNING: Download from https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/001d18664f8bcf63af64f10688809f7681dfbf0b.tar.gz failed: class java.io.FileNotFoundException GET returned 404 Not Found
Loading:
Loading: 1 packages loaded
Analyzing: target //build:build_wheel (2 packages loaded, 0 targets configured)
INFO: ToolchainResolution: Target platform @local_config_platform//:host: Selected execution platform @local_execution_config_platform//:platform,
INFO: ToolchainResolution: Type @bazel_tools//tools/python:toolchain_type: target platform @local_config_platform//:host: Rejected toolchain @local_execution_config_python//:py_runtime_pair; mismatching values: platform_constraint
INFO: ToolchainResolution: Type @bazel_tools//tools/python:toolchain_type: target platform @local_config_platform//:host: execution @local_execution_config_platform//:platform: Selected toolchain @local_config_python//:py_runtime_pair
INFO: ToolchainResolution: Type @bazel_tools//tools/python:toolchain_type: target platform @local_config_platform//:host: execution @local_config_platform//:host: Selected toolchain @local_config_python//:py_runtime_pair
INFO: ToolchainResolution: Type @bazel_tools//tools/cpp:toolchain_type: target platform @local_config_platform//:host: execution @local_execution_config_platform//:platform: Selected toolchain @local_config_cc//:cc-compiler-aarch64
INFO: ToolchainResolution: Type @bazel_tools//tools/cpp:toolchain_type: target platform @local_config_platform//:host: execution @local_config_platform//:host: Selected toolchain @local_config_cc//:cc-compiler-aarch64
INFO: ToolchainResolution: Type @bazel_tools//tools/cpp:toolchain_type: target platform @local_config_platform//:host: Rejected toolchain @local_config_cc//:cc-compiler-armeabi-v7a; mismatching values: arm, android
INFO: ToolchainResolution: Target platform @local_config_platform//:host: Selected execution platform @local_execution_config_platform//:platform, type @bazel_tools//tools/cpp:toolchain_type -> toolchain @local_config_cc//:cc-compiler-aarch64, type @bazel_tools//tools/python:toolchain_type -> toolchain @local_config_python//:py_runtime_pair
ERROR: /home/conda/feedstock_root/build_artifacts/jaxlib_1668531884423/_build_env/share/bazel/348a535f6622893b4d0b436c261ed568/external/bazel_tools/tools/zip/BUILD:11:1: indentation error
ERROR: /home/conda/feedstock_root/build_artifacts/jaxlib_1668531884423/_build_env/share/bazel/348a535f6622893b4d0b436c261ed568/external/bazel_tools/tools/zip/BUILD:14:2: Trailing comma is allowed only in parenthesized tuples.
ERROR: /home/conda/feedstock_root/build_artifacts/jaxlib_1668531884423/_build_env/share/bazel/348a535f6622893b4d0b436c261ed568/external/bazel_tools/tools/zip/BUILD:14:3: syntax error at 'outdent': expected expression
WARNING: Download from https://mirror.bazel.build/github.com/bazelbuild/rules_cc/archive/081771d4a0e9d7d3aa0eed2ef389fa4700dfb23e.tar.gz failed: class java.io.FileNotFoundException GET returned 404 Not Found
ERROR: /home/conda/feedstock_root/build_artifacts/jaxlib_1668531884423/work/build/BUILD.bazel:38:10: every rule of type py_binary implicitly depends upon the target '@bazel_tools//tools/zip:zipper', but this target could not be found because of: no such target '@bazel_tools//tools/zip:zipper': target 'zipper' not declared in package 'tools/zip' defined by /home/conda/feedstock_root/build_artifacts/jaxlib_1668531884423/_build_env/share/bazel/348a535f6622893b4d0b436c261ed568/external/bazel_tools/tools/zip/BUILD
ERROR: Analysis of target '//build:build_wheel' failed; build aborted: Analysis failed
INFO: Elapsed time: 58.229s
INFO: 0 processes.
FAILED: Build did NOT complete successfully (45 packages loaded, 139 targets configured)
ERROR: Build failed. Not running target
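With logs this noisy, it can help to filter out everything except the ERROR lines before reading them. The following is a minimal sketch of such a filter, not part of the feedstock tooling: `errors_only` is a hypothetical helper name, and the regular expression assumes the `SEVERITY: message` layout seen above, optionally preceded by a numeric prefix such as a CI log viewer might add.

```python
import re
from typing import List

# Optional numeric prefix, then a severity label and the message text.
_LOG_LINE = re.compile(r"^\d*\s*(ERROR|WARNING|INFO|FAILED): (.*)$")


def errors_only(log_text: str) -> List[str]:
    """Return the message part of every ERROR line in a Bazel-style log."""
    messages = []
    for line in log_text.splitlines():
        match = _LOG_LINE.match(line.strip())
        if match and match.group(1) == "ERROR":
            messages.append(match.group(2))
    return messages


if __name__ == "__main__":
    sample = "\n".join([
        "1090ERROR: Analysis of target '//build:build_wheel' failed; build aborted: Analysis failed",
        "1091INFO: Elapsed time: 58.229s",
        "ERROR: Build failed. Not running target",
    ])
    for message in errors_only(sample):
        print(message)
```

Applied to the log above, the first ERROR lines point at the `tools/zip/BUILD` file inside `bazel_tools`, and the later `zipper` error follows from those.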
Link to PPC/arm64 builds: https://app.travis-ci.com/github/conda-forge/jaxlib-feedstock/builds/257816827
Note we have our own toolchain that may need thorough updating ... I can work on that
It's now slightly clearer how we have your customizations (see collapsed code below) from https://github.com/google/jax/issues/7097#issuecomment-1216826398 in our tooling in #157, but I am getting an error:
ERROR: /home/conda/feedstock_root/build_artifacts/jaxlib_1668572828003/_build_env/share/bazel/c990790b248f6c0b6a739e7d6f0ff41b/external/bazel_tools/tools/zip/BUILD:11:1: indentation error
ERROR: /home/conda/feedstock_root/build_artifacts/jaxlib_1668572828003/_build_env/share/bazel/c990790b248f6c0b6a739e7d6f0ff41b/external/bazel_tools/tools/zip/BUILD:14:2: Trailing comma is allowed only in parenthesized tuples.
ERROR: /home/conda/feedstock_root/build_artifacts/jaxlib_1668572828003/_build_env/share/bazel/c990790b248f6c0b6a739e7d6f0ff41b/external/bazel_tools/tools/zip/BUILD:14:3: syntax error at 'outdent': expected expression
WARNING: Download from https://mirror.bazel.build/github.com/bazelbuild/rules_cc/archive/081771d4a0e9d7d3aa0eed2ef389fa4700dfb23e.tar.gz failed: class java.io.FileNotFoundException GET returned 404 Not Found
ERROR: /home/conda/feedstock_root/build_artifacts/jaxlib_1668572828003/work/build/BUILD.bazel:38:10: every rule of type py_binary implicitly depends upon the target '@bazel_tools//tools/zip:zipper', but this target could not be found because of: no such target '@bazel_tools//tools/zip:zipper': target 'zipper' not declared in package 'tools/zip' defined by /home/conda/feedstock_root/build_artifacts/jaxlib_1668572828003/_build_env/share/bazel/c990790b248f6c0b6a739e7d6f0ff41b/external/bazel_tools/tools/zip/BUILD
ERROR: Analysis of target '//build:build_wheel' failed; build aborted: Analysis failed
INFO: Elapsed time: 51.199s
INFO: 0 processes.
FAILED: Build did NOT complete successfully (45 packages loaded, 139 targets configured)
ERROR: Build failed. Not running target
FAILED: Build did NOT complete successfully (45 packages loaded, 139 targets configured)
Traceback (most recent call last):
 File "/home/conda/feedstock_root/build_artifacts/jaxlib_1668572828003/work/build/build.py", line 572, in <module>
b''
> I don't speak for the conda-forge jaxlib package maintainers, but jaxlib should work fine on ARM64 if you build it from source. So hopefully that can unblock you in the meantime!
Thanks, @hawkinsp !
That's true but Alphafold uses
RUN pip3 install --upgrade pip --no-cache-dir \
    && pip3 install -r /app/alphafold/requirements.txt --no-cache-dir \
    && pip3 install --upgrade --no-cache-dir \
      jax==0.3.17 \
      jaxlib==0.3.15+cuda11.cudnn805 \
      -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
and there are only x86_64 wheels at https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
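Whether a given wheel targets a particular architecture can be read from its filename, since the wheel naming convention (PEP 427) puts the platform tag in the last dash-separated field. The sketch below illustrates that check. `wheel_platform_tags` and `supports_arch` are hypothetical helper names, and the filenames in the usage example are made up from the version pinned above; they are not verified entries from the release index.

```python
from typing import List


def wheel_platform_tags(filename: str) -> List[str]:
    """Return the platform tags encoded in a wheel filename.

    Wheel names follow
    {distribution}-{version}(-{build})?-{python}-{abi}-{platform}.whl,
    and the platform field may hold several tags joined by dots.
    """
    if not filename.endswith(".whl"):
        raise ValueError("not a wheel filename: " + filename)
    parts = filename[: -len(".whl")].split("-")
    if len(parts) not in (5, 6):
        raise ValueError("unexpected number of fields in: " + filename)
    return parts[-1].split(".")


def supports_arch(filename: str, arch: str) -> bool:
    """Return True if any platform tag of the wheel ends with the given architecture."""
    return any(tag.endswith("_" + arch) for tag in wheel_platform_tags(filename))


if __name__ == "__main__":
    # Illustrative filename only, constructed for this example.
    name = "jaxlib-0.3.15+cuda11.cudnn805-cp39-none-manylinux2014_x86_64.whl"
    print(wheel_platform_tags(name), supports_arch(name, "aarch64"))
```

Running a check like this over the links on the release page would show at a glance which architectures are covered; pip performs an equivalent tag match itself when it decides which wheels are installable.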
Just a few lines above in the Dockerfile they install some more dependencies from conda-forge, but I'm afraid that even if we solve this issue it won't help, because it won't depend on the correct CUDA version. I hope I am wrong though!
Sorry for reviving this old issue.
With the introduction of linux-aarch64 support for bioconda, my package colabfold should also work on ARM, except for the missing jaxlib on linux-aarch64.
Even compilation without CUDA would be quite useful to me, as I could point users to install Colabfold through conda on e.g. a cloud ARM machine for the MSA generation part, and then run the GPU inference separately on a different machine.
However, since I can't selectively disable conda dependencies, I would still need jaxlib to be installable on ARM.
I don't speak for the conda-forge maintainers, but upstream we ship a linux aarch64 pip wheel.
I would still prefer to provide a single conda command for installation to users, since I have a few dependencies that are not pip installable.
I am very thankful for all the jaxlib pip variants though! They are super useful!
The main issue here is that we have currently reached the CI time limit. Cross-compiled builds, e.g. for linux-aarch64, will take even longer. Once this is fixed, we can look into enabling this here.
> Even compilation without CUDA would be quite useful to me, as I could point users to install Colabfold through conda on e.g. a cloud ARM machine for the MSA generation part, and then run the GPU inference separately on a different machine.
Am I missing something, or are conda packages for jaxlib on linux-aarch64 without CUDA actually available?
They have been available for a year: https://github.com/conda-forge/jaxlib-feedstock/pull/183
Indeed, that was my understanding, but this was not clear from @milot-mirdita in https://github.com/conda-forge/jaxlib-feedstock/issues/125#issuecomment-1995200416 . Could it make sense to rename the issue to "[Feature request]: Add support for CUDA builds on Linux ARM64"?
Solution to issue cannot be found in the documentation.
Issue
According to https://anaconda.org/conda-forge/jaxlib the current supported OS+CPU architectures are:
I'd like to request adding Linux ARM64 to this list.
At the moment the AlphaFold project cannot be used on Linux ARM64 due to a missing jaxlib+cuda Python wheel (https://github.com/deepmind/alphafold/issues/528). Currently AlphaFold uses pip3 to install jaxlib from https://storage.googleapis.com/jax-releases/jax_cuda_releases.html It would be nice if it could use conda-forge instead!
Installed packages
Environment info