conda-forge / jaxlib-feedstock

A conda-smithy repository for jaxlib.
BSD 3-Clause "New" or "Revised" License

[Feature request]: Add support for Linux ARM64 #125

Open martin-g opened 2 years ago

martin-g commented 2 years ago

Solution to issue cannot be found in the documentation.

Issue

According to https://anaconda.org/conda-forge/jaxlib the currently supported OS+CPU architectures are:

I'd like to request adding Linux ARM64 to this list.

At the moment the AlphaFold project cannot be used on Linux ARM64 due to a missing jaxlib+CUDA Python wheel (https://github.com/deepmind/alphafold/issues/528). Currently AlphaFold uses pip3 to install jaxlib from https://storage.googleapis.com/jax-releases/jax_cuda_releases.html. It would be nice if it could use conda-forge instead!

Installed packages

# packages in environment at /home/mgrigorov/devel/conda:
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                        main  
_openmp_mutex             5.1                      51_gnu  
brotlipy                  0.7.0           py39hfd63f10_1002  
ca-certificates           2022.3.29            hd43f75c_1  
certifi                   2021.10.8        py39hd43f75c_2  
cffi                      1.15.0           py39h9a3cfec_1  
charset-normalizer        2.0.4              pyhd3eb1b0_0  
colorama                  0.4.4              pyhd3eb1b0_0  
conda                     4.12.0           py39hd43f75c_0  
conda-content-trust       0.1.1              pyhd3eb1b0_0  
conda-package-handling    1.8.1            py39h2f4d8fa_0  
cryptography              36.0.0           py39h3d58568_0  
idna                      3.3                pyhd3eb1b0_0  
ld_impl_linux-aarch64     2.36.1               h0ab8de2_3  
libffi                    3.3                  h7c1a80f_2  
libgcc-ng                 10.2.0              h1234567_51  
libgomp                   10.2.0              h1234567_51  
libstdcxx-ng              10.2.0              h1234567_51  
ncurses                   6.3                  h2f4d8fa_2  
openssl                   1.1.1n               h2f4d8fa_0  
pip                       21.2.4           py39hd43f75c_0  
pycosat                   0.6.3            py39hfd63f10_2  
pycparser                 2.21               pyhd3eb1b0_0  
pyopenssl                 22.0.0             pyhd3eb1b0_0  
pysocks                   1.7.1            py39hd43f75c_0  
python                    3.9.12               hc137634_0  
readline                  8.1.2                h2f4d8fa_1  
requests                  2.27.1             pyhd3eb1b0_0  
ruamel_yaml               0.15.100         py39h2f4d8fa_0  
setuptools                61.2.0           py39hd43f75c_0  
six                       1.16.0             pyhd3eb1b0_1  
sqlite                    3.38.2               h6632b73_0  
tk                        8.6.11               h241ca14_0  
tqdm                      4.63.0             pyhd3eb1b0_0  
tzdata                    2022a                hda174b7_0  
urllib3                   1.26.8             pyhd3eb1b0_0  
wheel                     0.37.1             pyhd3eb1b0_0  
xz                        5.2.5                hfd63f10_1  
yaml                      0.2.5                hfd63f10_0  
zlib                      1.2.12               h2f4d8fa_1

Environment info

     active environment : None
       user config file : /home/mgrigorov/.condarc
 populated config files : 
          conda version : 4.12.0
    conda-build version : not installed
         python version : 3.9.12.final.0
       virtual packages : __linux=4.19.90=0
                          __glibc=2.28=0
                          __unix=0=0
                          __archspec=1=aarch64
       base environment : /home/mgrigorov/devel/conda  (writable)
      conda av data dir : /home/mgrigorov/devel/conda/etc/conda
  conda av metadata url : None
           channel URLs : https://repo.anaconda.com/pkgs/main/linux-aarch64
                          https://repo.anaconda.com/pkgs/main/noarch
                          https://repo.anaconda.com/pkgs/r/linux-aarch64
                          https://repo.anaconda.com/pkgs/r/noarch
          package cache : /home/mgrigorov/devel/conda/pkgs
                          /home/mgrigorov/.conda/pkgs
       envs directories : /home/mgrigorov/devel/conda/envs
                          /home/mgrigorov/.conda/envs
               platform : linux-aarch64
             user-agent : conda/4.12.0 requests/2.27.1 CPython/3.9.12 Linux/4.19.90-2207.4.0.0160.oe1.aarch64 openeuler/20.03 glibc/2.28
                UID:GID : 1000:1000
             netrc file : None
           offline mode : False
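
The `platform : linux-aarch64` line above is the conda subdirectory this machine installs from; a jaxlib build has to be published under that subdirectory on conda-forge before `conda install jaxlib` can find it. A minimal sketch of how an interpreter's OS/CPU pair maps onto subdir names, assuming the usual values that Python's `platform.system()` and `platform.machine()` return on each OS. The table and the `conda_subdir` helper are hypothetical illustrations, not conda's own implementation, and cover only a handful of common platforms.

```python
from __future__ import annotations

import platform

# Illustrative subset of conda platform subdirectories, keyed by what
# platform.system() / platform.machine() report on each OS (assumed values).
_CONDA_SUBDIRS = {
    ("Linux", "x86_64"): "linux-64",
    ("Linux", "aarch64"): "linux-aarch64",
    ("Linux", "ppc64le"): "linux-ppc64le",
    ("Darwin", "x86_64"): "osx-64",
    ("Darwin", "arm64"): "osx-arm64",
    ("Windows", "AMD64"): "win-64",
}


def conda_subdir(system: str | None = None, machine: str | None = None) -> str:
    """Hypothetical helper: map an (OS, CPU) pair to a conda subdir name.

    Defaults to the running interpreter's platform and raises ValueError
    for combinations outside the small table above.
    """
    system = system or platform.system()
    machine = machine or platform.machine()
    try:
        return _CONDA_SUBDIRS[(system, machine)]
    except KeyError:
        raise ValueError(f"no conda subdir known for {system}/{machine}") from None


print(conda_subdir("Linux", "aarch64"))  # linux-aarch64
print(conda_subdir("Linux", "x86_64"))   # linux-64
```

On the machine in this report, `conda info` already shows the same value in its `platform` field.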

ngam commented 2 years ago

Hi, contributions are welcome. See this initial PR for an attempt: https://github.com/conda-forge/jaxlib-feedstock/pull/105. I can help/guide you if you would like to try to add it.

hawkinsp commented 2 years ago

BTW, I verified that jaxlib builds fine at head on a GCP T2A VM just now with no special treatment, so I'm pretty confident things should work if we simply change the conda-forge build configuration to also build for aarch64 on Linux. I'd be happy to help here, but @ngam would need to tell me what to do to make that happen...

ngam commented 2 years ago

@hawkinsp you're everywhere all of a sudden (in a good way!) and I am even getting a 1-question survey to fill out about you in my edu inbox :P (relatedly, #105 was in reaction to the same people who invited you to give that talk wanting a ppc build for jax...)

We have some options; I will explain the process in more detail, but briefly:

ngam commented 2 years ago

We also have CUDA builds on aarch and ppc now, so we could go all out and add those too... but probably we should take care of the cpu ones first 😅

hawkinsp commented 2 years ago

Yes, I think it's a great idea to have Linux CUDA aarch64 builds at least because of the upcoming https://www.nvidia.com/en-us/data-center/grace-cpu/ which I'm sure someone will want to use with JAX...

ngam commented 2 years ago

Well... the bot failed, so I'm starting it manually in #127

hawkinsp commented 2 years ago

Just a heads-up: I was able to cross-compile jaxlib for AArch64 easily enough, but the JIT compiler target detection isn't correct without making some upstream TensorFlow changes (https://github.com/tensorflow/tensorflow/pull/57182).

So we will not be able to get a working cross-compiled AArch64 build of 0.3.15 as is; it will need a new jaxlib release or some patching.

ngam commented 2 years ago

Thank you for the heads-up. Did you use our tooling (conda-forge) or something else for this? And when you built the native aarch64 version, did you use our setup here?

ngam commented 2 years ago

I will apply your method here later in the week to see if we can get this sorted.

ngam commented 2 years ago

And you're correct, we can test if the build time is reasonable.

hawkinsp commented 2 years ago

Yup, that's what I did. I suspect the bazel_toolchain package may do the right thing in the conda build already for cross-compilation, I'd certainly try that first.

ngam commented 2 years ago

BTW, there is a way to publish PyPI wheels here too. I am not sure whether the core team is okay with that, but if you want, we can build the PyPI wheels here as well. An example is numba publishing some of their wheels on anaconda.org: https://anaconda.org/numba/numba/files?type=pypi

julien-faye commented 1 year ago

I also need jaxlib for Linux ARM64! Is there any progress on this issue? Thank you!

hawkinsp commented 1 year ago

I don't speak for the conda-forge jaxlib package maintainers, but jaxlib should work fine on ARM64 if you build it from source. So hopefully that can unblock you in the meantime!

ngam commented 1 year ago

Let's see how #147 pans out (contributions welcome!)

ngam commented 1 year ago

@hawkinsp here's where it stops in #155:

WARNING: Download from https://storage.googleapis.com/mirror.tensorflow.org/github.com/tensorflow/runtime/archive/c27b720c93f76662ab6d0e0e507d1fc66ab22119.tar.gz failed: class java.io.FileNotFoundException GET returned 404 Not Found
WARNING: Download from https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/001d18664f8bcf63af64f10688809f7681dfbf0b.tar.gz failed: class java.io.FileNotFoundException GET returned 404 Not Found
Loading: 
Loading: 1 packages loaded
Analyzing: target //build:build_wheel (2 packages loaded, 0 targets configured)
INFO: ToolchainResolution: Target platform @local_config_platform//:host: Selected execution platform @local_execution_config_platform//:platform, 
INFO: ToolchainResolution:     Type @bazel_tools//tools/python:toolchain_type: target platform @local_config_platform//:host: Rejected toolchain @local_execution_config_python//:py_runtime_pair; mismatching values: platform_constraint
INFO: ToolchainResolution:   Type @bazel_tools//tools/python:toolchain_type: target platform @local_config_platform//:host: execution @local_execution_config_platform//:platform: Selected toolchain @local_config_python//:py_runtime_pair
INFO: ToolchainResolution:   Type @bazel_tools//tools/python:toolchain_type: target platform @local_config_platform//:host: execution @local_config_platform//:host: Selected toolchain @local_config_python//:py_runtime_pair
INFO: ToolchainResolution:   Type @bazel_tools//tools/cpp:toolchain_type: target platform @local_config_platform//:host: execution @local_execution_config_platform//:platform: Selected toolchain @local_config_cc//:cc-compiler-aarch64
INFO: ToolchainResolution:   Type @bazel_tools//tools/cpp:toolchain_type: target platform @local_config_platform//:host: execution @local_config_platform//:host: Selected toolchain @local_config_cc//:cc-compiler-aarch64
INFO: ToolchainResolution:     Type @bazel_tools//tools/cpp:toolchain_type: target platform @local_config_platform//:host: Rejected toolchain @local_config_cc//:cc-compiler-armeabi-v7a; mismatching values: arm, android
INFO: ToolchainResolution: Target platform @local_config_platform//:host: Selected execution platform @local_execution_config_platform//:platform, type @bazel_tools//tools/cpp:toolchain_type -> toolchain @local_config_cc//:cc-compiler-aarch64, type @bazel_tools//tools/python:toolchain_type -> toolchain @local_config_python//:py_runtime_pair
ERROR: /home/conda/feedstock_root/build_artifacts/jaxlib_1668531884423/_build_env/share/bazel/348a535f6622893b4d0b436c261ed568/external/bazel_tools/tools/zip/BUILD:11:1: indentation error
ERROR: /home/conda/feedstock_root/build_artifacts/jaxlib_1668531884423/_build_env/share/bazel/348a535f6622893b4d0b436c261ed568/external/bazel_tools/tools/zip/BUILD:14:2: Trailing comma is allowed only in parenthesized tuples.
ERROR: /home/conda/feedstock_root/build_artifacts/jaxlib_1668531884423/_build_env/share/bazel/348a535f6622893b4d0b436c261ed568/external/bazel_tools/tools/zip/BUILD:14:3: syntax error at 'outdent': expected expression
WARNING: Download from https://mirror.bazel.build/github.com/bazelbuild/rules_cc/archive/081771d4a0e9d7d3aa0eed2ef389fa4700dfb23e.tar.gz failed: class java.io.FileNotFoundException GET returned 404 Not Found
ERROR: /home/conda/feedstock_root/build_artifacts/jaxlib_1668531884423/work/build/BUILD.bazel:38:10: every rule of type py_binary implicitly depends upon the target '@bazel_tools//tools/zip:zipper', but this target could not be found because of: no such target '@bazel_tools//tools/zip:zipper': target 'zipper' not declared in package 'tools/zip' defined by /home/conda/feedstock_root/build_artifacts/jaxlib_1668531884423/_build_env/share/bazel/348a535f6622893b4d0b436c261ed568/external/bazel_tools/tools/zip/BUILD
ERROR: Analysis of target '//build:build_wheel' failed; build aborted: Analysis failed
INFO: Elapsed time: 58.229s
INFO: 0 processes.
FAILED: Build did NOT complete successfully (45 packages loaded, 139 targets configured)
ERROR: Build failed. Not running target

ngam commented 1 year ago

Link to PPC/arm64 builds: https://app.travis-ci.com/github/conda-forge/jaxlib-feedstock/builds/257816827

Note that we have our own toolchain, which may need thorough updating; I can work on that:

https://github.com/conda-forge/bazel-toolchain-feedstock

ngam commented 1 year ago

It's now slightly clearer how to apply your customizations (see collapsed code below) from https://github.com/google/jax/issues/7097#issuecomment-1216826398 in our tooling in #157, but I get an error:

ERROR: /home/conda/feedstock_root/build_artifacts/jaxlib_1668572828003/_build_env/share/bazel/c990790b248f6c0b6a739e7d6f0ff41b/external/bazel_tools/tools/zip/BUILD:11:1: indentation error
ERROR: /home/conda/feedstock_root/build_artifacts/jaxlib_1668572828003/_build_env/share/bazel/c990790b248f6c0b6a739e7d6f0ff41b/external/bazel_tools/tools/zip/BUILD:14:2: Trailing comma is allowed only in parenthesized tuples.
ERROR: /home/conda/feedstock_root/build_artifacts/jaxlib_1668572828003/_build_env/share/bazel/c990790b248f6c0b6a739e7d6f0ff41b/external/bazel_tools/tools/zip/BUILD:14:3: syntax error at 'outdent': expected expression
WARNING: Download from https://mirror.bazel.build/github.com/bazelbuild/rules_cc/archive/081771d4a0e9d7d3aa0eed2ef389fa4700dfb23e.tar.gz failed: class java.io.FileNotFoundException GET returned 404 Not Found
ERROR: /home/conda/feedstock_root/build_artifacts/jaxlib_1668572828003/work/build/BUILD.bazel:38:10: every rule of type py_binary implicitly depends upon the target '@bazel_tools//tools/zip:zipper', but this target could not be found because of: no such target '@bazel_tools//tools/zip:zipper': target 'zipper' not declared in package 'tools/zip' defined by /home/conda/feedstock_root/build_artifacts/jaxlib_1668572828003/_build_env/share/bazel/c990790b248f6c0b6a739e7d6f0ff41b/external/bazel_tools/tools/zip/BUILD
ERROR: Analysis of target '//build:build_wheel' failed; build aborted: Analysis failed
INFO: Elapsed time: 51.199s
INFO: 0 processes.
FAILED: Build did NOT complete successfully (45 packages loaded, 139 targets configured)
ERROR: Build failed. Not running target
FAILED: Build did NOT complete successfully (45 packages loaded, 139 targets configured)
Traceback (most recent call last):
  File "/home/conda/feedstock_root/build_artifacts/jaxlib_1668572828003/work/build/build.py", line 572, in <module>
b''

```
load("@local_config_cc//:cc_toolchain_config.bzl", "cc_toolchain_config")

package(default_visibility = ["//visibility:public"])

cc_toolchain_suite(
    name = "toolchain",
    toolchains = {
        "k8|compiler": "@local_config_cc//:cc-compiler-k8",
        "k8": "@local_config_cc//:cc-compiler-k8",
        "aarch64": ":cc-compiler-aarch64",
    },
)

cc_toolchain(
    name = "cc-compiler-aarch64",
    all_files = "@local_config_cc//:compiler_deps",
    ar_files = "@local_config_cc//:compiler_deps",
    as_files = "@local_config_cc//:compiler_deps",
    compiler_files = "@local_config_cc//:compiler_deps",
    dwp_files = ":empty",
    linker_files = "@local_config_cc//:compiler_deps",
    module_map = None,
    objcopy_files = ":empty",
    strip_files = ":empty",
    supports_param_files = 1,
    toolchain_config = ":cross_aarch64",
    toolchain_identifier = "cross_aarch64",
)

cc_toolchain_config(
    name = "cross_aarch64",
    abi_libc_version = "local",
    abi_version = "local",
    compile_flags = [
        "-U_FORTIFY_SOURCE",
        "-fstack-protector",
        "-Wall",
        "-Wunused-but-set-parameter",
        "-Wno-free-nonheap-object",
        "-fno-omit-frame-pointer",
    ],
    compiler = "compiler",
    coverage_compile_flags = ["--coverage"],
    coverage_link_flags = ["--coverage"],
    cpu = "aarch64",
    cxx_builtin_include_directories = [
        "/usr/aarch64-linux-gnu/include",
        "/usr/lib/gcc-cross/aarch64-linux-gnu/11/include",
        "/usr/local/include",
        "/usr/include",
        "/usr/include/c++/11",
        "/usr/include/c++/11/backward",
    ],
    cxx_flags = ["-std=c++0x"],
    dbg_compile_flags = ["-g"],
    host_system_name = "local",
    link_flags = [
        "-fuse-ld=gold",
        "-Wl,-no-as-needed",
        "-Wl,-z,relro,-z,now",
        "-B/usr/bin/aarch64-linux-gnu-",
        "-pass-exit-codes",
    ],
    link_libs = [
        "-lstdc++",
        "-lm",
    ],
    opt_compile_flags = [
        "-g0",
        "-O2",
        "-D_FORTIFY_SOURCE=1",
        "-DNDEBUG",
        "-ffunction-sections",
        "-fdata-sections",
    ],
    opt_link_flags = ["-Wl,--gc-sections"],
    supports_start_end_lib = True,
    target_libc = "local",
    target_system_name = "local",
    tool_paths = {
        "ar": "/usr/bin/ar",
        "ld": "/usr/bin/aarch64-linux-gnu-ld",
        "llvm-cov": "/usr/bin/llvm-cov",
        "cpp": "/usr/bin/aarch64-linux-gnu-cpp",
        "gcc": "/usr/bin/aarch64-linux-gnu-gcc",
        "dwp": "/usr/bin/aarch64-linux-gnu-dwp",
        "gcov": "/usr/bin/aarch64-linux-gnu-gcov",
        "nm": "/usr/bin/aarch64-linux-gnu-nm",
        "objcopy": "/usr/bin/aarch64-linux-gnu-objcopy",
        "objdump": "/usr/bin/aarch64-linux-gnu-objdump",
        "strip": "/usr/bin/aarch64-linux-gnu-strip",
    },
    toolchain_identifier = "cross_aarch64",
    unfiltered_compile_flags = [
        "-fno-canonical-system-headers",
        "-Wno-builtin-macro-redefined",
        "-D__DATE__=\"redacted\"",
        "-D__TIMESTAMP__=\"redacted\"",
        "-D__TIME__=\"redacted\"",
    ],
)
```

martin-g commented 1 year ago

I don't speak for the conda-forge jaxlib package maintainers, but jaxlib should work fine on ARM64 if you build it from source. So hopefully that can unblock you in the meantime!

Thanks, @hawkinsp!

That's true, but AlphaFold's Dockerfile uses

RUN pip3 install --upgrade pip --no-cache-dir \
    && pip3 install -r /app/alphafold/requirements.txt --no-cache-dir \
    && pip3 install --upgrade --no-cache-dir \
      jax==0.3.17 \
      jaxlib==0.3.15+cuda11.cudnn805 \
      -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

and there are only x86_64 wheels at https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
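
Whether an index like that offers an ARM64 build can be checked from the file names alone, because each wheel name ends in a platform tag (for example `manylinux2014_x86_64` versus `manylinux2014_aarch64`). A minimal stdlib sketch, assuming the standard wheel naming scheme `{name}-{version}(-{build})?-{python}-{abi}-{platform}.whl`; the helper names and the sample file names are hypothetical illustrations, not copied from the jax-releases page.

```python
from __future__ import annotations


def wheel_platform_tags(filename: str) -> set[str]:
    """Return the platform tag(s) of a wheel file name.

    Assumes {name}-{version}(-{build})?-{python}-{abi}-{platform}.whl,
    where the platform field may hold several tags joined by '.'.
    """
    if not filename.endswith(".whl"):
        raise ValueError(f"not a wheel file name: {filename}")
    fields = filename[: -len(".whl")].split("-")
    if len(fields) not in (5, 6):  # 6 when an optional build tag is present
        raise ValueError(f"unexpected wheel file name: {filename}")
    return set(fields[-1].split("."))


def has_linux_aarch64_wheel(filenames) -> bool:
    """True if any wheel carries a Linux ARM64 platform tag (*_aarch64)."""
    return any(
        tag.endswith("_aarch64")
        for name in filenames
        for tag in wheel_platform_tags(name)
    )


# Hypothetical file names in the style of a CUDA release index.
x86_only = ["jaxlib-0.3.15+cuda11.cudnn805-cp39-none-manylinux2014_x86_64.whl"]
with_arm = x86_only + ["jaxlib-0.3.15-cp39-none-manylinux2014_aarch64.whl"]

print(has_linux_aarch64_wheel(x86_only))  # False
print(has_linux_aarch64_wheel(with_arm))  # True
```

The `packaging` library's `packaging.utils.parse_wheel_filename` performs the same parsing with full validation; the hand-rolled split only keeps this sketch dependency-free.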

Just a few lines above in the Dockerfile they install some more dependencies from conda-forge, but I'm afraid that even if we solve this issue it won't help, because the package won't depend on the correct CUDA version. I hope I am wrong though!

milot-mirdita commented 7 months ago

Sorry for reviving this old issue.

With the introduction of linux-aarch64 support for bioconda, my package colabfold should also work on ARM, except for the missing jaxlib on linux-aarch64.

Even compilation without CUDA would be quite useful to me, as I could point users to install Colabfold through conda on e.g. a cloud ARM machine for the MSA generation part, and then run the GPU inference separately on a different machine.

However, since I can't selectively disable conda dependencies, I would still need jaxlib to be installable on ARM.

hawkinsp commented 7 months ago

I don't speak for the conda-forge maintainers, but upstream we ship a linux aarch64 pip wheel.

milot-mirdita commented 7 months ago

I would still prefer to provide a single conda command for installation to users, since I have a few dependencies that are not pip installable.

I am very thankful for all the jaxlib pip variants though! They are super useful!

xhochy commented 7 months ago

The main issue here is that we have currently reached the time limit of our CI. Cross-compiled builds, e.g. for linux-aarch64, will take even longer. Once this is fixed, we can look into enabling this here.

traversaro commented 4 months ago

Even compilation without CUDA would be quite useful to me, as I could point users to install Colabfold through conda on e.g. a cloud ARM machine for the MSA generation part, and then run the GPU inference separately on a different machine.

Am I missing something, or are conda packages for jaxlib on linux-aarch64 without CUDA actually available already?

xhochy commented 4 months ago

They have been available for a year: https://github.com/conda-forge/jaxlib-feedstock/pull/183

traversaro commented 4 months ago

Indeed, that was my understanding, but it was not clear from @milot-mirdita's comment in https://github.com/conda-forge/jaxlib-feedstock/issues/125#issuecomment-1995200416. Could it make sense to rename the issue to "[Feature request]: Add support for CUDA builds on Linux ARM64"?