jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Add support for other GPUs (than NVIDIA) #2012

Closed ricardobarroslourenco closed 1 year ago

ricardobarroslourenco commented 4 years ago

Is it possible to run JAX on GPU architectures other than NVIDIA (e.g., Intel, AMD)?

reza-amd commented 3 years ago

@proutrc I just checked this with our released Docker container, and your code snippet works fine there. I will work on this issue and provide a fix soon, but in the meantime I highly recommend using our released Docker container. Please check out https://hub.docker.com/repository/docker/rocm/jax to use the container.
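For reference, pulling and running a ROCm container generally looks like the following; the :latest tag is an assumption (check Docker Hub for the actual tags), and the --device/--group-add flags are the usual requirements for giving a container access to AMD GPUs:

# Pull the prebuilt ROCm JAX image (tag is illustrative)
docker pull rocm/jax:latest
# Expose the ROCm kernel driver and render devices to the container
docker run -it --device=/dev/kfd --device=/dev/dri --group-add video rocm/jax:latest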

proutrc commented 3 years ago

@reza-amd any update on the out-of-container use case? I can see the system-provided ld.lld and have confirmed it is in my $PATH, but I am unsure why JAX cannot find it. It seemed like you alluded to knowing the issue. Anyway, just checking in.

Further, we are pursuing the container route as well, which does not appear to have this issue, but we are interested in the out-of-container use case too.
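(For anyone else hitting this symptom, a quick sanity check that the linker is visible from the exact shell where jaxlib is being built:)

which ld.lld
ld.lld --version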

secondspass commented 3 years ago

@reza-amd Do you know where I can find the Dockerfile that was used to build https://hub.docker.com/repository/docker/rocm/jax ? I'm trying to build JAX from source in my own container and run it with Singularity, but JAX in the container is unable to detect the GPU (even though rocm-smi and TensorFlow in the container can detect it). Your rocm/jax container (pulled from Docker Hub and converted with Singularity), however, does detect the GPU properly, so I'm trying to figure out what I missed in my attempt.

reza-amd commented 3 years ago

@secondspass Did you run build/build.py as follows?

python3 build/build.py --enable_rocm --rocm_path=/opt/rocm-x.x.x  --rocm_amdgpu_targets=gfx9xx --bazel_options=--override_repository=org_tensorflow=/path/to/tensorflow
pip3 install --force-reinstall dist/*.whl && pip3 install --force-reinstall .

Please note that all the Xs need to be determined based on your setup. To find the rocm_amdgpu_targets value, look at the contents of /opt/rocm-x.x.x/bin/target.lst; it should show all the supported backends for your ROCm version. You can pass all of them as a comma-separated string, e.g. --rocm_amdgpu_targets=gfx900,gfx906,gfx908.
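If target.lst is not available, another way to discover the gfx target of the GPU actually installed is via rocminfo, which ships with ROCm (the grep is just a convenience filter):

# Each GPU agent reports its architecture as a gfx... name
rocminfo | grep -o "gfx[0-9a-f]*" | sort -u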

secondspass commented 3 years ago

@reza-amd ah okay, that makes sense. I'll try that.

reza-amd commented 3 years ago

@proutrc sorry for my late response; I missed the GitHub notification for your message.

Installing outside of the container should be no different from installing inside it, as long as ROCm has been installed properly and build/build.py is called with the correct arguments.

I would recommend starting with an Ubuntu container, installing ROCm in it, and building JAX from source there. If everything goes well inside the container, you should be able to repeat the process on bare metal.
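Either way, a minimal smoke test after the build (exact device names vary across jaxlib versions):

# A working ROCm build should list GPU devices here, not just CPU
python3 -c "import jax; print(jax.devices())"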

secondspass commented 2 years ago

@reza-amd I'm building with Podman and then converting with Singularity to run on our machine. Below is the Dockerfile that I use. I'm having to build both JAX and TensorFlow from scratch because I can't use the existing ROCm JAX or ROCm TensorFlow base containers. Using the rocm/tensorflow container and building JAX complains that the Python version needs to be greater than 3.7; using the rocm/jax container as the base and building TensorFlow breaks JAX, because JAX is built with NumPy 1.21 while TensorFlow forcibly replaces that with NumPy 1.19, which creates errors when loading JAX. With the Dockerfile below, I get the error shown at the bottom when I check whether JAX can see the GPUs (even though TensorFlow can see them just fine). It would be helpful if the Dockerfile you use to build your rocm/jax container could be shared, so I can see what I might have missed that is causing this error. I have tried your rocm/jax container on its own and it does detect the GPU, so it's not an issue with containers or with JAX itself on our machine; it's something in the way I'm building, and I don't know what is causing it.

I do wonder if I have to build the container image on the same machine that the GPUs are on. Right now, I'm building on a VM on a MacBook, pushing to Docker Hub, and pulling the image on our AMD machine. Could that be a source of error?

Let me know if we should take this conversation elsewhere; I can open an issue on the ROCm fork of JAX if needed.

Dockerfile:

FROM docker.io/rocm/dev-ubuntu-20.04:4.2-complete

# installing bazel and others
RUN apt-get -y install curl gnupg
RUN curl -fsSL https://bazel.build/bazel-release.pub.gpg | gpg --dearmor > bazel.gpg
RUN mv bazel.gpg /etc/apt/trusted.gpg.d/ \
    && echo "deb [arch=amd64] https://storage.googleapis.com/bazel-apt stable jdk1.8" | tee /etc/apt/sources.list.d/bazel.list 
RUN curl -sL http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - \
    && sh -c 'echo deb [arch=amd64] http://repo.radeon.com/rocm/apt/4.2/ ubuntu main > /etc/apt/sources.list.d/rocm.list' \
    && apt-get -y --allow-unauthenticated update && DEBIAN_FRONTEND=noninteractive apt-get -y install python3-pip git  \
        bazel \
        rocblas \
        rocfft \
        rocrand \
        rccl \
        miopen-hip \
        hipblas \
        hipfft \
        hipsparse \
        miopengemm \
        miopenkernels-gfx900-56kdb \
        miopenkernels-gfx900-64kdb \
        miopenkernels-gfx906-60kdb \
        miopenkernels-gfx906-64kdb \
        miopenkernels-gfx908-120kdb \
        miopentensile \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/*

RUN pip3 install tensorflow-rocm==2.5.0 scipy==1.7.0

RUN mkdir /code
WORKDIR /code

RUN cd /code && git clone https://github.com/ROCmSoftwarePlatform/tensorflow-upstream && cd tensorflow-upstream && git checkout jax_preview_release \
    && cd /code && git clone https://github.com/ROCmSoftwarePlatform/jax && cd jax && git checkout jax_preview_release
RUN cd /code/jax \
    && python3 build/build.py --enable_rocm --rocm_path=/opt/rocm-4.2.0/ --rocm_amdgpu_targets=gfx900,gfx906,gfx908 --bazel_options=--override_repository=org_tensorflow=/code/tensorflow-upstream
RUN cd /code/jax && pip3 install dist/*.whl && pip3 install  .

# install chex and tree 
RUN cd /tmp && git clone https://github.com/deepmind/tree && cd /tmp/tree && git checkout 0.1.6 && python3 setup.py install && rm -rf /tmp/tree
RUN cd /tmp && git clone https://github.com/deepmind/chex && cd /tmp/chex && git checkout v0.0.7 && python3 setup.py install && rm -rf /tmp/chex
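# NB: the PATH set below mixes rocm-4.2.0 and rocm-4.3.1; the 4.3.1 entry looks like a leftover and may not exist in this image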

ENV PATH=/opt/rocm-4.2.0/hip/bin:/opt/rocm-4.3.1/bin:$PATH
ENV ROCM_PATH=/opt/rocm-4.2.0
ENV HIP_PATH=/opt/rocm-4.2.0/hip

ENTRYPOINT ["/bin/bash"]

The error I'm getting on my image (I don't get it on the official rocm/jax image, which seems to work fine):

$ singularity shell --cleanenv --rocm rocmtensorflow.sif
Singularity> python3
Python 3.8.10 (default, Sep 28 2021, 16:10:42)
[GCC 9.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> jax.devices()[0].device_kind
"hipErrorNoBinaryForGpu: Unable to find code object for all current devices!"
Aborted

reza-amd commented 2 years ago

@secondspass thanks for providing the details. I will go ahead and try your Dockerfile to see if I can reproduce this issue and provide a solution.

I do wonder if I have to build the container image on the same machine that the GPUs are on. Right now, I'm trying to build on a VM on a macbook, pushing to Dockerhub, and pulling on our AMD machine. Could that be a source of error?

You have provided the --rocm_amdgpu_targets=gfx900,gfx906,gfx908, so this should not be an issue.

What is the AMD GPU model that you are using?

secondspass commented 2 years ago

@reza-amd It's an AMD Instinct MI100

secondspass commented 2 years ago

@reza-amd have you had a chance to build and test the Dockerfile?

reza-amd commented 2 years ago

@secondspass I did, and I was able to reproduce the issue with your Dockerfile. I will provide some updates here by the end of this week.

reza-amd commented 2 years ago

@secondspass could you please try the following Dockerfile?

FROM docker.io/rocm/tensorflow-autobuilds:latest

# install jax requirements
RUN pip3 install six wheel

# Build and install jax
RUN mkdir /code
WORKDIR /code
RUN cd /code && git clone https://github.com/ROCmSoftwarePlatform/tensorflow-upstream  \
        && cd /code && git clone https://github.com/google/jax.git \
        && cd /code/jax && git checkout 402e3c128d3d211733cabd099089e080146322b6

RUN cd /code/jax \
        && python3 build/build.py --enable_rocm --rocm_path=/opt/rocm-4.5.0/ --rocm_amdgpu_targets=gfx900,gfx906,gfx908,gfx90a,gfx1030 --bazel_options=--override_repository=org_tensorflow=/code/tensorflow-upstream
RUN cd /code/jax && pip3 install dist/*.whl && pip3 install  .

# install chex and tree
RUN cd /tmp && git clone https://github.com/deepmind/tree && cd /tmp/tree && git checkout 0.1.6 && python3 setup.py install && rm -rf /tmp/tree
RUN cd /tmp && git clone https://github.com/deepmind/chex && cd /tmp/chex && git checkout v0.0.7 && python3 setup.py install && rm -rf /tmp/chex

ENTRYPOINT ["/bin/bash"]

secondspass commented 2 years ago

@reza-amd Hi Reza, it failed during the python3 build/build.py step:

[5,921 / 6,437] Compiling mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp; 11s local
ERROR: /root/.cache/bazel/_bazel_root/ce5d68b00643b5fb24a0c8f31c678264/external/org_tensorflow/tensorflow/compiler/mlir/tensorflow/BUILD:581:11: Compiling tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc failed: (Exit 4): crosstool_wrapper_driver_is_not_gcc failed: error executing command
  (cd /root/.cache/bazel/_bazel_root/ce5d68b00643b5fb24a0c8f31c678264/execroot/__main__ && \
  exec env - \
    PATH=/root//bin:/root/.local/bin:/opt/rocm-4.5.0/opencl/bin:/opt/rocm-4.5.0/bin:/opt/rocm-4.5.0/hcc/bin:/opt/rocm-4.5.0/hip/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin \
    PWD=/proc/self/cwd \
    ROCM_PATH=/opt/rocm-4.5.0/ \
    TF_ROCM_AMDGPU_TARGETS=gfx900,gfx906,gfx908,gfx90a,gfx1030 \
  external/local_config_rocm/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc -U_FORTIFY_SOURCE -fstack-protector -Wall -Wunused-but-set-parameter -Wno-free-nonheap-object -fno-omit-frame-pointer -g0 -O2 '-D_FORTIFY_SOURCE=1' -DNDEBUG -ffunction-sections -fdata-sections '-std=c++14' -MD -MF bazel-out/k8-opt/bin/external/org_tensorflow/tensorflow/compiler/mlir/tensorflow/_objs/tensorflow_ops/tf_ops.pic.d '-frandom-seed=bazel-out/k8-opt/bin/external/org_tensorflow/tensorflow/compiler/mlir/tensorflow/_objs/tensorflow_ops/tf_ops.pic.o' -fPIC '-DLLVM_ON_UNIX=1' '-DHAVE_BACKTRACE=1' '-DBACKTRACE_HEADER=<execinfo.h>' '-DLTDL_SHLIB_EXT=".so"' '-DLLVM_PLUGIN_EXT=".so"' '-DLLVM_ENABLE_THREADS=1' '-DHAVE_DEREGISTER_FRAME=1' '-DHAVE_LIBPTHREAD=1' '-DHAVE_PTHREAD_GETNAME_NP=1' '-DHAVE_PTHREAD_GETSPECIFIC=1' '-DHAVE_PTHREAD_H=1' '-DHAVE_PTHREAD_SETNAME_NP=1' '-DHAVE_REGISTER_FRAME=1' '-DHAVE_SETENV_R=1' '-DHAVE_STRERROR_R=1' '-DHAVE_SYSEXITS_H=1' '-DHAVE_UNISTD_H=1' -D_GNU_SOURCE '-DHAVE_LINK_H=1' '-DHAVE_LSEEK64=1' '-DHAVE_MALLINFO=1' '-DHAVE_POSIX_FALLOCATE=1' '-DHAVE_SBRK=1' '-DHAVE_STRUCT_STAT_ST_MTIM_TV_NSEC=1' '-DLLVM_NATIVE_ARCH="X86"' '-DLLVM_NATIVE_ASMPARSER=LLVMInitializeX86AsmParser' '-DLLVM_NATIVE_ASMPRINTER=LLVMInitializeX86AsmPrinter' '-DLLVM_NATIVE_DISASSEMBLER=LLVMInitializeX86Disassembler' '-DLLVM_NATIVE_TARGET=LLVMInitializeX86Target' '-DLLVM_NATIVE_TARGETINFO=LLVMInitializeX86TargetInfo' '-DLLVM_NATIVE_TARGETMC=LLVMInitializeX86TargetMC' '-DLLVM_NATIVE_TARGETMCA=LLVMInitializeX86TargetMCA' '-DLLVM_HOST_TRIPLE="x86_64-unknown-linux-gnu"' '-DLLVM_DEFAULT_TARGET_TRIPLE="x86_64-unknown-linux-gnu"' -D__STDC_LIMIT_MACROS -D__STDC_CONSTANT_MACROS -D__STDC_FORMAT_MACROS -DTF_USE_SNAPPY -DEIGEN_MPL2_ONLY '-DEIGEN_MAX_ALIGN_BYTES=64' -DHAVE_SYS_UIO_H -iquote external/org_tensorflow -iquote bazel-out/k8-opt/bin/external/org_tensorflow -iquote external/llvm-project -iquote bazel-out/k8-opt/bin/external/llvm-project -iquote external/llvm_terminfo -iquote bazel-out/k8-opt/bin/external/llvm_terminfo -iquote external/llvm_zlib -iquote bazel-out/k8-opt/bin/external/llvm_zlib -iquote external/com_google_absl -iquote bazel-out/k8-opt/bin/external/com_google_absl -iquote external/nsync -iquote bazel-out/k8-opt/bin/external/nsync -iquote external/eigen_archive -iquote bazel-out/k8-opt/bin/external/eigen_archive -iquote external/gif -iquote bazel-out/k8-opt/bin/external/gif -iquote external/libjpeg_turbo -iquote bazel-out/k8-opt/bin/external/libjpeg_turbo -iquote external/com_google_protobuf -iquote bazel-out/k8-opt/bin/external/com_google_protobuf -iquote external/zlib -iquote bazel-out/k8-opt/bin/external/zlib -iquote external/com_googlesource_code_re2 -iquote bazel-out/k8-opt/bin/external/com_googlesource_code_re2 -iquote external/farmhash_archive -iquote bazel-out/k8-opt/bin/external/farmhash_archive -iquote external/fft2d -iquote bazel-out/k8-opt/bin/external/fft2d -iquote external/highwayhash -iquote bazel-out/k8-opt/bin/external/highwayhash -iquote external/double_conversion -iquote bazel-out/k8-opt/bin/external/double_conversion -iquote external/snappy -iquote bazel-out/k8-opt/bin/external/snappy -iquote external/local_config_rocm -iquote bazel-out/k8-opt/bin/external/local_config_rocm -Ibazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/BuiltinAttributeInterfacesIncGen -Ibazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/BuiltinAttributesIncGen -Ibazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/BuiltinDialectIncGen 
-Ibazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/BuiltinLocationAttributesIncGen -Ibazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/BuiltinOpsIncGen -Ibazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/BuiltinTypeInterfacesIncGen -Ibazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/BuiltinTypesIncGen -Ibazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/CallOpInterfacesIncGen -Ibazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/CastOpInterfacesIncGen -Ibazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/InferTypeOpInterfaceIncGen -Ibazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/OpAsmInterfaceIncGen -Ibazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/RegionKindInterfaceIncGen -Ibazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/SideEffectInterfacesIncGen -Ibazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/SubElementInterfacesIncGen -Ibazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/SymbolInterfacesIncGen -Ibazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/TensorEncodingIncGen -Ibazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/ParserTokenKinds -Ibazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/ControlFlowInterfacesIncGen -Ibazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/DerivedAttributeOpInterfaceIncGen -Ibazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/LoopLikeInterfaceIncGen -Ibazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/ArithmeticBaseIncGen -Ibazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/ArithmeticCanonicalizationIncGen -Ibazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/ArithmeticOpsIncGen -Ibazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/VectorInterfacesIncGen -Ibazel-out/k8-opt/bin/external/llvm-project/mlir/_virtual_includes/StandardOpsIncGen -isystem external/llvm-project/llvm/include -isystem bazel-out/k8-opt/bin/external/llvm-project/llvm/include -isystem external/llvm-project/mlir/include -isystem bazel-out/k8-opt/bin/external/llvm-project/mlir/include -isystem external/nsync/public -isystem bazel-out/k8-opt/bin/external/nsync/public -isystem external/org_tensorflow/third_party/eigen3/mkl_include -isystem bazel-out/k8-opt/bin/external/org_tensorflow/third_party/eigen3/mkl_include -isystem external/eigen_archive -isystem bazel-out/k8-opt/bin/external/eigen_archive -isystem external/gif -isystem bazel-out/k8-opt/bin/external/gif -isystem external/com_google_protobuf/src -isystem bazel-out/k8-opt/bin/external/com_google_protobuf/src -isystem external/zlib -isystem bazel-out/k8-opt/bin/external/zlib -isystem external/farmhash_archive/src -isystem bazel-out/k8-opt/bin/external/farmhash_archive/src -isystem external/double_conversion -isystem bazel-out/k8-opt/bin/external/double_conversion -isystem external/local_config_rocm/rocm -isystem bazel-out/k8-opt/bin/external/local_config_rocm/rocm -isystem external/local_config_rocm/rocm/rocm/include -isystem bazel-out/k8-opt/bin/external/local_config_rocm/rocm/rocm/include -isystem external/local_config_rocm/rocm/rocm/include/rocrand -isystem bazel-out/k8-opt/bin/external/local_config_rocm/rocm/rocm/include/rocrand -isystem external/local_config_rocm/rocm/rocm/include/roctracer -isystem bazel-out/k8-opt/bin/external/local_config_rocm/rocm/rocm/include/roctracer -Wno-sign-compare 
-Wno-stringop-truncation -mavx '-std=c++14' -fno-canonical-system-headers -Wno-builtin-macro-redefined '-D__DATE__="redacted"' '-D__TIMESTAMP__="redacted"' '-D__TIME__="redacted"' '-DTENSORFLOW_USE_ROCM=1' -D__HIP_PLATFORM_HCC__ -DEIGEN_USE_HIP -no-canonical-prefixes -fno-canonical-system-headers -c external/org_tensorflow/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc -o bazel-out/k8-opt/bin/external/org_tensorflow/tensorflow/compiler/mlir/tensorflow/_objs/tensorflow_ops/tf_ops.pic.o)
Execution platform: @local_execution_config_platform//:platform
In file included from external/org_tensorflow/tensorflow/core/common_runtime/inline_function_utils.h:25:0,
                 from external/org_tensorflow/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc:71:
external/org_tensorflow/tensorflow/core/common_runtime/lower_function_call_inline_policy.h:38:1: warning: multi-line comment [-Wcomment]
 // LINT.ThenChange(inline_function_utils.h,\
 ^
In file included from external/org_tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint:46:0,
                 from external/org_tensorflow/tensorflow/core/framework/numeric_types.h:24,
                 from external/org_tensorflow/tensorflow/core/framework/allocator.h:26,
                 from external/org_tensorflow/tensorflow/core/framework/tensor.h:23,
                 from external/org_tensorflow/tensorflow/core/framework/attr_value_util.h:24,
                 from external/org_tensorflow/tensorflow/core/framework/node_def_util.h:23,
                 from external/org_tensorflow/tensorflow/core/framework/shape_inference.h:22,
                 from external/org_tensorflow/tensorflow/core/framework/common_shape_fns.h:20,
                 from external/org_tensorflow/tensorflow/core/framework/resource_mgr.h:27,
                 from external/org_tensorflow/tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h:29,
                 from external/org_tensorflow/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h:39,
                 from external/org_tensorflow/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc:16:
external/org_tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX.h:14:41: warning: ignoring attributes on template argument '__m256i {aka __vector(4) long long int}' [-Wignored-attributes]
 typedef eigen_packet_wrapper<__m256i, 10> Packet32q8i;
                                         ^
external/org_tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX.h:15:41: warning: ignoring attributes on template argument '__m128i {aka __vector(2) long long int}' [-Wignored-attributes]
 typedef eigen_packet_wrapper<__m128i, 11> Packet16q8i;
                                         ^
gcc: internal compiler error: Killed (program cc1plus)
Please submit a full bug report,
with preprocessed source if appropriate.
See <file:///usr/share/doc/gcc-7/README.Bugs> for instructions.
Target //build:build_wheel failed to build
INFO: Elapsed time: 16850.261s, Critical Path: 1367.46s
INFO: 5991 processes: 621 internal, 5370 local.
FAILED: Build did NOT complete successfully
ERROR: Build failed. Not running target
FAILED: Build did NOT complete successfully

     _   _  __  __
    | | / \ \ \/ /
 _  | |/ _ \ \  /
| |_| / ___ \/  \
 \___/_/   \/_/\_\

ERROR: The project you're trying to build requires Bazel 4.1.0 (specified in /code/jax/.bazelversion), but it wasn't found in /usr/local/lib/bazel/bin.

Bazel binaries for all official releases can be downloaded from here:
  https://github.com/bazelbuild/bazel/releases

You can download the required version directly using this command:
  (cd "/usr/local/lib/bazel/bin" && curl -fLO https://releases.bazel.build/4.1.0/release/bazel-4.1.0-linux-x86_64 && chmod +x bazel-4.1.0-linux-x86_64)
Downloading bazel from: https://github.com/bazelbuild/bazel/releases/download/4.1.0/bazel-4.1.0-linux-x86_64

Bazel binary path: ./bazel-4.1.0-linux-x86_64
Bazel version: 4.1.0
Python binary path: /usr/bin/python3
Python version: 3.9
NumPy version: 1.19.5
MKL-DNN enabled: yes
Target CPU: x86_64
Target CPU features: release
CUDA enabled: no
TPU enabled: no
ROCm enabled: yes
ROCm toolkit path: /opt/rocm-4.5.0/
ROCm amdgpu targets: gfx900,gfx906,gfx908,gfx90a,gfx1030

Building XLA and installing it in the jaxlib source tree...
./bazel-4.1.0-linux-x86_64 run --verbose_failures=true --override_repository=org_tensorflow=/code/tensorflow-upstream --config=avx_posix --config=mkl_open_source_only --config=rocm --config=nonccl :build_wheel -- --output_path=/code/jax/dist --cpu=x86_64
b''
Traceback (most recent call last):
  File "/code/jax/build/build.py", line 524, in <module>
    main()
  File "/code/jax/build/build.py", line 519, in main
    shell(command)
  File "/code/jax/build/build.py", line 53, in shell
    output = subprocess.check_output(cmd)
  File "/usr/lib/python3.9/subprocess.py", line 424, in check_output
    return run(*popenargs, stdout=PIPE, timeout=timeout, check=True,
  File "/usr/lib/python3.9/subprocess.py", line 528, in run
    raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['./bazel-4.1.0-linux-x86_64', 'run', '--verbose_failures=true', '--override_repository=org_tensorflow=/code/tensorflow-upstream', '--config=avx_posix', '--config=mkl_open_source_only', '--config=rocm', '--config=nonccl', ':build_wheel', '--', '--output_path=/code/jax/dist', '--cpu=x86_64']' returned non-zero exit status 1.
Error: error building at STEP "RUN cd /code/jax         && python3 build/build.py --enable_rocm --rocm_path=/opt/rocm-4.5.0/ --rocm_amdgpu_targets=gfx900,gfx906,gfx908,gfx90a,gfx1030 --bazel_options=--override_repository=org_tensorflow=/code/tensorflow-upstream": error while running runtime: exit status 1

So it looks like an internal compiler error with GCC. (The "Killed (program cc1plus)" message means the OS killed the compiler process, which often indicates the build ran out of memory.)
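(If memory pressure in the VM is the culprit, one common workaround, not something verified in this thread, is to cap Bazel's parallelism via build.py's pass-through options:)

python3 build/build.py --enable_rocm --rocm_path=/opt/rocm-4.5.0/ \
    --rocm_amdgpu_targets=gfx900,gfx906,gfx908,gfx90a,gfx1030 \
    --bazel_options=--override_repository=org_tensorflow=/code/tensorflow-upstream \
    --bazel_options=--jobs=4
# --jobs limits concurrent compile actions, trading build time for lower peak memory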

reza-amd commented 2 years ago

@secondspass I just tried it again and it is working fine on my side. Could you please run the following to make sure you have the latest TensorFlow ROCm Docker image:

docker pull rocm/tensorflow-autobuilds:latest

Also run docker build with --no-cache, for example:

 docker build --no-cache -f my_docker.dockerfile .

reza-amd commented 2 years ago

@secondspass were you able to fix it? I discussed this with my colleagues, and your issue might be related to the "copts" argument configured in .bazelrc: https://github.com/google/jax/blob/f1261000d2ac1dddfecfe33f33bfeeaf192b747f/.bazelrc#L43

Please try changing the Dockerfile so that build.py is called with a different target_cpu_features option; by default it is set to release, which might not be correct for your VM-based build (one way to pass it is sketched below). https://github.com/google/jax/blob/f1261000d2ac1dddfecfe33f33bfeeaf192b747f/build/build.py#L324
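For instance (a sketch; --target_cpu_features is the build.py flag linked above, and per that file, default avoids the CPU-feature assumptions that release makes):

python3 build/build.py --enable_rocm --rocm_path=/opt/rocm-4.5.0/ \
    --rocm_amdgpu_targets=gfx900,gfx906,gfx908,gfx90a,gfx1030 \
    --bazel_options=--override_repository=org_tensorflow=/code/tensorflow-upstream \
    --target_cpu_features=default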

Also, could you please help me reproduce this on my side? What VM are you using? What is the CPU model of your MacBook? And can you build the Dockerfile directly on your target machine (the host machine with the AMD GPUs)?

secondspass commented 2 years ago

@reza-amd Hi Reza, thanks for the info and for all the help so far. I tried building with both Docker and Podman on the VM, and both failed with the same error I listed before. A colleague built it natively on their own machine elsewhere, and that succeeded, so your suggestion that it is the target_cpu_features argument is likely the right answer.

VM: VirtualBox VM running Ubuntu 20.04.3 LTS
MacBook processor: 2.6 GHz 6-Core Intel Core i7

At the moment, I am restricted from building directly on the target machine but I'll see if I can get access there. At least it is good to know that trying to build this on a VM can produce different results than building natively.

brettkoonce commented 2 years ago

@reza-amd thank you very much for the notes on how to get JAX working with ROCm. I was able to get ResNet50 training working with a W6800 GPU!

coversb commented 2 years ago

Hi team,

Thanks a lot for supporting ROCm in JAX. I have now run into some issues:

I don't know the right way to build JAX from source (I saw https://hub.docker.com/r/rocm/jax and checked out the jax_preview_release branch). I build with ROCm 4.0.1, and the device is gfx906.

1. Running the unit tests (https://jax.readthedocs.io/en/latest/developer.html?highlight=pytest#running-the-tests):

python tests/lax_numpy_test.py --num_generated_cases=5

it shows

2022-01-27 09:21:25.249401: F external/org_tensorflow/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc:78] Check failed: buffer_slice.offset() + buffer_slice.size() <= base.size() (4 vs. 0)

If I run a program with this jaxlib, it also hits this assertion. I think maybe the source code or the way I built it is wrong?

2. Trying to build JAX from source with ROCm 4.1 fails in the mlir/xla/operator_writer_gen part; I don't know how to get the right LLVM tar package:

bazel-out/k8-opt-exec-50AE0418/bin/external/org_tensorflow/tensorflow/compiler/mlir/xla/operator_writer_gen: symbol lookup error: bazel-out/k8-opt-exec-50AE0418/bin/external/org_tensorflow/tensorflow/compiler/mlir/xla/operator_writer_gen: undefined symbol: _ZTINSt3_V214error_categoryE

3. When I set XLA_PYTHON_CLIENT_MEM_FRACTION or XLA_PYTHON_CLIENT_PREALLOCATE, JAX does not seem to preallocate GPU memory according to the fraction I set; it only allocates 2%-14% of GPU RAM (see the example below).

Can you help me? Thanks a lot!!!

Update: I fixed issue 2; it was a libstdc++ version problem. But I still have the 'Check failed' error.
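(Regarding issue 3: these are the standard XLA client allocator environment variables; a typical invocation looks like the following, with your_script.py as a placeholder. The exact preallocation behavior varies by jaxlib version.)

# Ask the XLA client to preallocate 90% of GPU memory up front
XLA_PYTHON_CLIENT_PREALLOCATE=true XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 python3 your_script.py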

reza-amd commented 2 years ago

@coversb, is it possible for you to upgrade your ROCm version?

coversb commented 2 years ago

@coversb, is it possible for you to upgrade your ROCm version?

@reza-amd Thanks for your reply! Yes. Which version do you think is OK? I don't know whether there are incompatibilities between newer versions and ROCm 4.0; it seems a lot of header files and library files changed.

brettkoonce commented 2 years ago

@reza-amd is it possible to get a 5.0 series build (drun image + JAX)? Is there anything similar for PyTorch? Thanks in advance!

reza-amd commented 2 years ago

Sorry for my slow response. We have recently released ROCm 5.0 and have updated JAX accordingly. You can track the status of the PR here: https://github.com/google/jax/pull/9584. In the PR source branch, I have provided utility scripts to build a ROCm container with JAX; please take a look at https://github.com/ROCmSoftwarePlatform/jax/tree/rocm_refactor_jaxlib/build/rocm for more details.

brettkoonce commented 2 years ago

@reza-amd Thanks for the update! I will test things as soon as possible. More broadly, what are the criteria for closing this bug? Things seem to be working reasonably well!

brettkoonce commented 2 years ago

See also: #9864

brettkoonce commented 2 years ago

@reza-amd Thank you again for the help getting Docker working! I am able to use JAX to build a Docker image and then train networks locally. I ran a benchmark using Flax + ResNet50 + ImageNet with a batch size of 256 in fp16 mode.

Here are the results on a W6800:

I0320 11:30:18.089150 140318779365120 logging_writer.py:35] [500400] steps_per_second=1.096407, train_accuracy=0.8140624761581421, train_learning_rate=3.6358835941996404e-09, train_loss=0.7592905163764954, train_scale=65536.0
I0320 11:30:53.003456 140385056196416 train.py:364] eval epoch: 99, loss: 0.9520, accuracy: 76.26
I0320 11:30:53.004454 140318779365120 logging_writer.py:35] [500400] eval_accuracy=0.7626402378082275, eval_loss=0.9520021080970764

Here are the results of the same code (fp16 + batch size 256) on a dual NVIDIA 3060 (CUDA) setup:

I0319 08:20:43.101390 140559431702272 logging_writer.py:35] [500400] steps_per_second=3.186386, train_accuracy=0.8107030987739563, train_learning_rate=3.6358835941996404e-09, train_loss=0.7701781988143921, train_scale=65536.0
I0319 08:20:57.018532 140610530277184 train.py:364] eval epoch: 99, loss: 0.9472, accuracy: 76.45
I0319 08:20:57.019293 140559431702272 logging_writer.py:35] [500400] eval_accuracy=0.7645031809806824, eval_loss=0.9471861124038696

What else would be needed to mark this bug as resolved? I will start a ViT run next, but that will take a few days to complete!

reza-amd commented 2 years ago

@brettkoonce Thanks very much for your update and for testing our recent changes in ROCm 5.0.

brettkoonce commented 2 years ago

@reza-amd I have made a little progress with ViT but am having some issues with numerical precision on the W6800. It can train models with a batch size of 128, but I get reduced accuracy compared to a reference run on a TPU.

W6800 results:

I0415 08:36:18.115095 139965291800320 logging_writer.py:35] [900810] valid_loss=3.451379, valid_prec@1=0.423600

TPU-v2, batch size of 128 (all other code identical):

I0504 16:20:31.839697 140592298116864 logging_writer.py:35] [900810] valid_loss=3.074142, valid_prec@1=0.491300

Second run with different TPU, same code config:

I0504 16:20:37.937061 140423741003520 logging_writer.py:35] [900810] valid_loss=3.013428, valid_prec@1=0.498340

I had similar results (i.e., lower accuracy on AMD) when I did my tests with 4.5.0 last year. Do you have any ideas about why this would happen, or suggestions for how to improve things?

hawkinsp commented 2 years ago

@brettkoonce When comparing against TPU, a key thing to be careful of is that the default matmul and convolution precision on TPU is bfloat16 inputs with float32 accumulation. Try setting jax_default_matmul_precision to float32 (see the sketch below), which, although slower, should give numerics closer to typical GPUs. Just because the AMD GPU loss is worse doesn't necessarily mean the AMD GPU implementation is doing something wrong. (It might be! But I'd try to rule out the known quantities first.)
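A minimal sketch of the two standard ways to set this (both the config flag and the context manager are part of JAX's public API; the array shapes here are arbitrary):

import jax
import jax.numpy as jnp

# Option 1: set the precision globally via the config flag
jax.config.update("jax_default_matmul_precision", "float32")

# Option 2: scope the setting to a region of code
x = jnp.ones((128, 128))
with jax.default_matmul_precision("float32"):
    y = x @ x  # matmuls in this block accumulate in full float32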

brettkoonce commented 2 years ago

@hawkinsp I am using the Scenic ViT demo in float32 mode (set like so for data / model), for what it's worth. Are there additional settings I should investigate?

I am doing a nvidia run currently with the same configuration and will report when that is finished.

brettkoonce commented 2 years ago

Here is the result when using NVIDIA hardware (4x 3060) with the same configuration:

I0518 07:44:17.425569 140335808247552 logging_writer.py:35] [900810] valid_loss=3.027929, valid_prec@1=0.495140

hawkinsp commented 2 years ago

@brettkoonce Perhaps move this to a new bug? My suggestion would be: can you minimize it to a small, self-contained test case? That's what I would do if I had access to the hardware and were debugging it. You might consider comparing the results of a single training step between CPU and GPU, or between the two GPUs, as in the sketch below.
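A rough sketch of that kind of single-step comparison (step is a placeholder for the real computation):

import jax
import jax.numpy as jnp
import numpy as np

x = np.random.RandomState(0).randn(256, 256).astype(np.float32)

def step(a):
    # Stand-in for one training step; substitute the real loss/grad computation
    return jnp.sum(a @ a.T)

# device_put commits the input to a device, so jit runs the computation there
cpu_out = jax.jit(step)(jax.device_put(x, jax.devices("cpu")[0]))
default_out = jax.jit(step)(jax.device_put(x, jax.devices()[0]))  # GPU if available
print("max abs diff:", np.abs(np.asarray(cpu_out) - np.asarray(default_out)).max())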

brettkoonce commented 2 years ago

ROCm build scripts have been failing for ~2 months, see #10162.

brettkoonce commented 2 years ago

JAX 04b751c549a2c03681bd95cb3796df78cb26bc5b builds with ROCm 5.2!

brettkoonce commented 1 year ago

JAX 09794bee5f6c33d2ae6552e337e27f97916e14f9 builds with ROCm 5.4!

stephensrmmartin commented 1 year ago

Hey @brettkoonce

I am trying to compile JAX with jaxlib for ROCm on Arch Linux, and I just cannot get a functional combination of things to work.

I was able to compile jaxlib 4.6 and 4.9, but errors occurred at runtime (including segfaults).

Are you able to share which commits/releases/tags you used for jax/jaxlib, XLA (and TensorFlow, if you still used that repo), and which build options you used?

ricardobarroslourenco commented 1 year ago

Pinging back on this issue after some time: what an excellent discussion. Has anyone been lucky enough to run JAX on ARM architectures (such as the Apple Silicon processors)?

hawkinsp commented 1 year ago

@ricardobarroslourenco Yes. JAX has supported CPU-only execution on Apple hardware for many releases, and there is a new and experimental Apple GPU plugin (https://github.com/google/jax#pip-installation-apple-gpus). (Note: experimental).
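(At the time of writing, the linked guide installs the experimental Metal plugin from PyPI; the package name below is taken from that guide:)

pip install jax-metal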

In fact, I think I'm going to declare this issue fixed: by my last count, four GPU vendors (NVIDIA, AMD, Apple, Intel) now support JAX to some degree, so I think we can say "we support multiple GPU vendors". We're working on better integration, better testing, and easier release processes for all of them.

Feel free to file new bugs specific to particular hardware vendors!

JoeyTeng commented 1 year ago

Just a quick comment: would it be better to mention the installation guide for ROCm devices in the README, right before the Apple Metal devices section? What do you think, @hawkinsp @brettkoonce?

brettkoonce commented 1 year ago

Grabbag of responses:

@hawkinsp +1 closing this as well, glad to have helped!

@stephensrmmartin

Are you able to share which commits/releases/tags you used for jax/jaxlib, xla, (tensorflow if you still used that repo), and which build options you used?

The pattern I have had luck with (ROCm 4.5 and up) is:

1. Latest Ubuntu Linux LTS (supported by ROCm) with the HWE addon.
2. Full ROCm install using the installer.
3. Install Docker with hardware extensions enabled.
4. Build ROCm + JAX inside said container, able to talk to the device using the instructions in the AMD ROCm guide.
5. You should now be able to run Python inside the Docker environment, import jax, and call jax.devices() to verify things are working together.
6. (Optional) Pin/freeze said image and use it as a base for experiments.

It's not super-turnkey but it definitely works!

@JoeyTeng With the amount of customization ROCm requires, keeping it inside the docker build sub-folder (i.e., where it is right now) is where I would keep it going forward. The JAX part works fine, but ROCm needs more maturity in general before I can recommend it to new ML practitioners (e.g., by featuring it in the primary README).