jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Add support for other GPUs (than NVIDIA) #2012

Closed: ricardobarroslourenco closed this issue 1 year ago

ricardobarroslourenco commented 4 years ago

Is it possible to run JAX on GPU architectures other than NVIDIA (e.g. Intel, AMD)?

hawkinsp commented 4 years ago

In principle, sure! All we need is XLA to support that architecture.

In practice that means we support at the moment: CPU, NVidia GPU, and TPU.

Happily AMD has been contributing support for AMD GPUs to XLA. We haven't tried it out in JAX, but assuming the XLA support is complete, I see no good reason it wouldn't work with a few small JAX changes. If you are excited about AMD GPUs, we'd certainly welcome contributions enabling that functionality in JAX.

I don't think Intel GPUs have XLA support at the moment, but I wouldn't rule it out in the future as the various compiler toolchains (e.g., XLA, MLIR) progress.
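
For anyone wondering which platform their JAX install actually picked up, here is a quick check (a small sketch; jax.devices() is the same call used in examples further down this thread):

import jax

devices = jax.devices()       # e.g. [CpuDevice(id=0)] or [GpuDevice(id=0)]
print(devices)
print(devices[0].platform)    # 'cpu', 'gpu', or 'tpu'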

jekbradbury commented 4 years ago

The AMDGPU backend for XLA is being actively developed; these PRs probably have the most up-to-date status (seems like many but not all tests pass?)

One thing to note is that the AMD integrations require that you rebuild XLA from source; there's no way to build a single TF or XLA binary that can use both NVIDIA CUDA and AMD ROCm.

For Intel hardware, I imagine we'd need something like MLIR translation from HLO dialect to nGraph dialect. I'm guessing nobody is actively working on that, but ccing @nmostafa in case Intel has plans in that area.

EelcoHoogendoorn commented 4 years ago

Glad to see this ROCm effort seems to be funded with full-time developers by AMD. Better late than never, I suppose. I hope they learned at least a little from their misadventures in GPGPU, with OpenCL being half-assedly supported; in practice, if you wanted to get anything done, you had no choice but to go with the platform that didn't require you to, say, reinvent your FFT libraries from scratch. I hope this time around they realize there is some minimum investment in software they'd be smart to make if they want to offer a competitive ecosystem. It's crazy to see how much money NVIDIA has made off this; in the meantime Google has added a completely new, viable hardware and software alternative in the form of TPUs, and AMD is still working on getting compatibility with any of the software out there. It does not inspire much confidence, to be honest; it seems wise to bet against them ever getting out a robust, feature-complete alternative if they couldn't even get anything out four years ago. But I'd love to be wrong about this, and for there to be some genuine competition in desktop ML acceleration in the future.

Cvikli commented 4 years ago

Can someone help me with how to use JAX on AMD GPUs? Are there any code snippets we can start with?

Sixzero commented 4 years ago

Any update on this topic? How can it be that TensorFlow supports AMD GPUs but JAX doesn't? Isn't ROCm the CUDA equivalent for AMD GPUs, and a drop-in replacement for it?

hawkinsp commented 4 years ago

There's no technical blocker to using JAX on AMD GPUs. We on the JAX team simply don't have access to any AMD GPUs at the moment to develop or test the necessary changes (which are probably not that large, given that most of the necessary work has been done in the context of TensorFlow).

Contributions are welcome!

8bitmp3 commented 4 years ago

The AMDGPU backend for XLA is being actively developed

That's good to know @jekbradbury, thanks

akuz commented 4 years ago

I just wanted to ask: when we are talking about AMD GPUs being supported, is it going to be on all platforms (i.e. including macOS), or are we talking Linux/Windows only?

jekbradbury commented 4 years ago

I believe the AMDGPU backend support for XLA is based on ROCm, which doesn't support macOS.

inailuig commented 3 years ago

I was able to build jax with initial support for ROCm (AMD GPUs) by compiling it using XLA from ROCmSoftwarePlatform/tensorflow-upstream (update: after https://github.com/tensorflow/tensorflow/pull/45344 you can use upstream TF) and adding a few options to the build scripts.

The code can be found here: inailuig/jax (update: after https://github.com/google/jax/pull/5114 you can use upstream jax)

Executing

import jax
print(jax.devices())
print(jax.devices()[0].device_kind)
x = jax.numpy.array([1.2, 3.4, 5.6])
y = jax.numpy.exp(x)
print(y)

on my RX480 outputs

[GpuDevice(id=0)]
Ellesmere [Radeon RX 470/480/570/570X/580/580X/590]
'+code-object-v3' is not a recognized feature for this target (ignoring feature)
'+code-object-v3' is not a recognized feature for this target (ignoring feature)
'+code-object-v3' is not a recognized feature for this target (ignoring feature)
[  3.3201168  29.964104  270.4264   ]
2020-11-22 20:40:04.841794: E external/org_tensorflow/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc:614] Deallocating stream with pending work
2020-11-22 20:40:04.842168: E external/org_tensorflow/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc:614] Deallocating stream with pending work
2020-11-22 20:40:04.842517: E external/org_tensorflow/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc:614] Deallocating stream with pending work
2020-11-22 20:40:04.842866: E external/org_tensorflow/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc:614] Deallocating stream with pending work
2020-11-22 20:40:04.844206: E external/org_tensorflow/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc:614] Deallocating stream with pending work

which already looks very promising. However, there are still things missing, such as the custom GPU kernels in jaxlib (cublas, cuda_prng, cusolver).

For those who want to build this: I am running Ubuntu 20.04.1 with ROCm 3.9.0 installed using the official instructions. It is also necessary to install these additional packages: rocm-dev miopen-hip rocfft rocblas rccl hipsparse rocrand rocsolver hipblas. Then the whole thing can be built with

python3 build/build.py --enable_rocm --rocm_path /opt/rocm-3.9.0

Optionally, different amdgpu targets can be specified with --rocm_amdgpu_targets (see here). For now I put in some default targets, but autodetection also works (by passing "" (an empty string), which overrides the default).

hawkinsp commented 3 years ago

@inailuig That's exciting progress! Nice work! (Sorry for the slow response, many of us were on vacation this last week.)

Technically speaking the cublas/cusolver and cuda_prng kernels are somewhat optional. The cuda_prng kernel is a compile-time optimization and can be safely omitted (at the cost of increased compile time), and cublas/cusolver are only needed for linear algebra support. So it might be possible to check things in even before those pieces work.
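
As a concrete illustration of what should still work without those kernels (a hedged sketch; behavior on a ROCm build is an assumption based on the paragraph above):

import jax
import jax.numpy as jnp

# Without the custom PRNG kernel, jax.random still works: Threefry is lowered
# to ordinary XLA ops, it just compiles more slowly.
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (4, 4))

# Plain element-wise math needs no BLAS/solver glue at all.
print(jnp.exp(x).sum())

# Calls like the one below are the kind that rely on the cublas/cusolver-style
# kernels, so they may not work until the ROCm equivalents are hooked up:
# q, r = jnp.linalg.qr(x)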

I'm curious: is it possible to use upstream TF instead of the ROCm fork? We frequently update our TF (XLA) version, so any ROCm specific fork is likely to be stale.

inailuig commented 3 years ago

@hawkinsp It turns out all that is missing in upstream TF is looking for devices with the right platform, i.e. some changes in tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc (from this commit: https://github.com/ROCmSoftwarePlatform/tensorflow-upstream/commit/0ba02369635a60dfbc28d5583e521999f519c9f1):

diff --git a/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc b/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc
index 4863e5e8165..870007f1dca 100644
--- a/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc
+++ b/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc
@@ -57,11 +57,19 @@ xla::StatusOr<xla::DeviceAssignment> GpuClient::GetDefaultDeviceAssignment(

 // Builds an xla::LocalClient for the GPU platform.
 StatusOr<LocalClient*> GetGpuXlaClient() {
+#if GOOGLE_CUDA
   TF_ASSIGN_OR_RETURN(se::Platform * platform,
                       PlatformUtil::GetPlatform("CUDA"));
   if (platform->VisibleDeviceCount() <= 0) {
     return FailedPrecondition("No visible NVidia GPU devices.");
   }
+#else
+  TF_ASSIGN_OR_RETURN(se::Platform * platform,
+                      PlatformUtil::GetPlatform("ROCm"));
+  if (platform->VisibleDeviceCount() <= 0) {
+    return FailedPrecondition("No visible AMD GPU devices.");
+  }
+#endif
   LocalClientOptions options;
   options.set_platform(platform);
   return ClientLibrary::GetOrCreateLocalClient(options);

Do you think we could get something like that upstreamed into TF?

For cuda_prng and the cublas/cusolver kernels, I was also able to get them running (2 or 3 of the LAPACK functions (cusolver) are not yet implemented in rocsolver, but everything else is there; this also requires a few more changes to TF; I will post more once I have cleaned it up a bit).

hawkinsp commented 3 years ago

We certainly can upstream something like that. That file is really part of JAX so we can change it as we see fit. You can send PRs to TensorFlow and assign me; I can review.

deven-amd commented 3 years ago

@hawkinsp @inailuig

Thank you for trying out JAX on AMD GPUs. I am on the TF framework team at AMD, and would like to get a better understanding of the TF changes that are required to get JAX working. We would be more than happy to help out.

I also had a question for you: does JAX have unit tests that run on GPUs, and if so, can you point me to the directions to run them? I would like to get them running internally on our platform.

thanks again

deven

hawkinsp commented 3 years ago

@deven-amd We'll need to wait for @inailuig to send out their remaining changes to get things to build.

Once those changes are checked in, the best way to do this is probably something like this:

git clone https://github.com/google/jax.git
git clone https://github.com/tensorflow/tensorflow.git /mydir/tensorflow
cd jax
python build/build.py --bazel_options=--override_repository=org_tensorflow=/mydir/tensorflow --enable_rocm
pip install dist/*.whl
pip install -e .
XLA_PYTHON_CLIENT_ALLOCATOR=platform pytest -n auto tests examples

This builds and installs jaxlib with TF (XLA) from head (rather than whatever version we have pinned in our WORKSPACE file). (You can also achieve this by editing the WORKSPACE file; see the comments in that file.)

The XLA_PYTHON_CLIENT_ALLOCATOR setting avoids the BFC allocator, which preallocates GPU memory; that means we should be able to run tests in parallel using multiple processes (-n auto enables this).
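
If it is easier to set from Python than from the shell, the same allocator choice can be applied via os.environ, as long as it happens before the backend is initialized (a small sketch):

import os

# Equivalent to prefixing the command with XLA_PYTHON_CLIENT_ALLOCATOR=platform.
# Set it before the first call that touches a device; doing it before importing
# jax is the safe option.
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

import jax
print(jax.devices())  # the backend now uses the non-preallocating "platform" allocator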

I should note there are probably a few tests that fail at head on Nvidia GPUs also (https://github.com/google/jax/issues/5067).

hawkinsp commented 3 years ago

See also https://jax.readthedocs.io/en/latest/developer.html#running-the-tests

deven-amd commented 3 years ago

Hi Peter,

Thanks for the quick response.

I will try out the directions you have provided + the docs, to get the JAX unit tests working on the ROCm platform. I expect to work on this next week, will ping you if I run into any issues. In case I do, would you rather I email you directly or file an issue on the JAX github repo?

Thanks

deven

hawkinsp commented 3 years ago

@deven-amd

If there's no reason otherwise, we like to do development in the open so the community can be involved. So I'd file issues/PRs or use Github discussions. You can ping me in any issues or PRs if you want to make sure I take a look!

inailuig commented 3 years ago

@deven-amd Thanks for reaching out; it would be great if you could help in particular with fixing the tests that are still failing.

I just opened https://github.com/google/jax/pull/5114 for the remaining build related stuff in jax. In general things seem to be working.

However, there are still some tests failing because of bugs (e.g. stuff related to conv, dot_general, triangular solve, ...). Other features are simply not implemented yet for ROCm in XLA (e.g. TRSM for complex args); for the latter we will have to identify and skip them.

Also there is this error message

 E external/org_tensorflow/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc:614] Deallocating stream with pending work

which keeps popping up when the program terminates. @deven-amd would you be able to look into this?


For the BLAS/LAPACK wrappers (i.e. the ROCm equivalent of jaxlib/cusolver.py and the related pybind modules) I mostly followed what @hawkinsp did for CUDA here, since it's just lots of glue code around roc/cu BLAS/Solver routines. This can be found in https://github.com/google/jax/pull/5115

For this to work we still need a few changes in TF:

  1. custom_call_thunk needs to be enabled and built for ROCm: https://github.com/inailuig/tensorflow/commit/44d3a233c6971344d595aacbad1459e1822264cd
  2. In xla_client.py, "CUDA" is hardcoded when you try to register a custom call target. I suggest we fix this like so: https://github.com/inailuig/tensorflow/commit/be9602a7666eb05edc33faae7825f8401968e885 (this still keeps CUDA as the default when you pass 'gpu', unfortunately). Then we can register functions for ROCm like this: xla_client.register_custom_call_target(_name, _value, platform="ROCM"). Everywhere else in jax we can keep 'gpu'.
  3. We need to add rocSolver targets to the build scripts somewhere (I think we should add this to TF, although I guess it would also be possible to add them just to jax). For my attempt at this see: https://github.com/inailuig/tensorflow/commit/606d7933b39f4115f8aea61e25bceb906855b5bf
  4. Not strictly necessary but nice to have: rocm_library, see https://github.com/inailuig/tensorflow/commit/e08f34ca8fe49056407eeaa706556af891d6857d

All of this can be found in https://github.com/inailuig/tensorflow/tree/jax-rocm-gpukernels (there are 2 more commits which are useful for debugging, but not necessary)

@hawkinsp How should we proceed?

hawkinsp commented 3 years ago

  1. Seems fine: I'd send that as a PR.
  2. Also looks fine to me. I might be tempted to change "gpu" to mean "register both CUDA and ROCM", which we could do by making xla_platform_names a dictionary whose values are lists of names and then registering all of them (see the sketch after this list).
  3. Seems plausible, and adding it to TF is probably the better place (that way, TF can share the build rules). I'm a bit surprised that TF doesn't have ROCSolver hooked up already.
  4. Also seems reasonable to me, but I'm not as sure about this.
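
A minimal sketch of the idea in item 2 (illustration only; xla_platform_names as a dict of lists and the helper below are assumed names, not the actual jaxlib code):

from jaxlib import xla_client

# Let the user-facing "gpu" platform fan out to every XLA GPU platform name,
# so kernels get registered for whichever backend the wheel was built against.
xla_platform_names = {
    "cpu": ["Host"],
    "gpu": ["CUDA", "ROCM"],
}

def register_kernels(registrations, platform="gpu"):
    # `registrations` is a dict of {custom_call_name: capsule}, as returned by
    # the pybind11 kernel modules discussed in this thread.
    for xla_name in xla_platform_names.get(platform, [platform]):
        for name, capsule in registrations.items():
            xla_client.register_custom_call_target(name, capsule, platform=xla_name)
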
hawkinsp commented 3 years ago

Retitling this bug to focus on AMD GPUs only; we can open new bugs for other hardware vendors if needed.

deven-amd commented 3 years ago

@inailuig

@deven-amd would you be able to look into this?

How do I go about reproducing the error on my end?

Following the directions provided by @hawkinsp, I am able to build jax and run the tests on the CPU platform. The next step is to reproduce the behaviour you see on ROCm. If I understand correctly, this requires changes both to TF and JAX.

I am building with

As for the changes on the TF side (1 through 4) in your post:

1, 3 & 4: I am in the process of making those changes in the ROCm fork of TF. Let me know if you plan on creating PRs to get those changes into the TF repo, or if you want me to push them out once they are in the ROCm fork (i.e. ROCm fork ---> upstream TF).

2: will changing the mapping from 'gpu': 'CUDA' to 'gpu': 'GPU' also achieve the same effect?

inailuig commented 3 years ago

@deven-amd I suspect your GPU is not detected for some reason, so the tests run on the CPU. (Did you compile with --enable_rocm?)

See my initial example from above: https://github.com/google/jax/issues/2012#issuecomment-731843904 for how to check if the gpu is detected and used. Also you could try to reproduce the error with it.

If you want to test my rocblas/cublas wrappers as well you should use https://github.com/inailuig/tensorflow/tree/jax-rocm-gpukernels and https://github.com/inailuig/jax/tree/rocm-gpukernels

Otherwise you can use upstream TF and upstream jax (everything necessary is merged now; edit: except rng, which won't work unless you remove cuda_prng.py).

deven-amd commented 3 years ago

this is what I get with the simple example

rocm-user@prj47-rack-15:/common/JAX$ python3 simple.py 
2020-12-07 22:25:07.839411: E external/org_tensorflow/tensorflow/stream_executor/rocm/rocm_driver.cc:982] could not retrieve ROCM device count: HIP_ERROR_NoDevice
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[CpuDevice(id=0)]
cpu
[  3.320117  29.964104 270.4264  ]

let me debug this further

deven-amd commented 3 years ago

@inailuig

I am now able to reproduce the error you get with the simple testcase

rocm-user@rocm-framework-14:/common/JAX$ python3 simple.py 
[GpuDevice(id=0)]
Vega 20
'+code-object-v3' is not a recognized feature for this target (ignoring feature)
'+code-object-v3' is not a recognized feature for this target (ignoring feature)
'+code-object-v3' is not a recognized feature for this target (ignoring feature)
[  3.320117  29.964104 270.42636 ]
2020-12-08 13:12:36.643101: E external/org_tensorflow/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc:614] Deallocating stream with pending work
2020-12-08 13:12:36.643758: E external/org_tensorflow/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc:614] Deallocating stream with pending work
2020-12-08 13:12:36.644247: E external/org_tensorflow/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc:614] Deallocating stream with pending work
2020-12-08 13:12:36.646026: E external/org_tensorflow/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc:614] Deallocating stream with pending work
2020-12-08 13:12:36.650082: E external/org_tensorflow/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc:614] Deallocating stream with pending work
2020-12-08 13:12:36.654172: E external/org_tensorflow/tensorflow/stream_executor/rocm/rocm_gpu_executor.cc:614] Deallocating stream with pending work

The '+code-object-v3' is not a recognized feature for this target (ignoring feature) warning is a known issue and will be gone soon (the fix is in the ROCm TF fork and will be upstreamed shortly).

looking into the error messages now

deven-amd commented 3 years ago

@inailuig @hawkinsp

I have added all the TF side changes in the ROCm TF fork

https://github.com/ROCmSoftwarePlatform/tensorflow-upstream/pull/1198 https://github.com/ROCmSoftwarePlatform/tensorflow-upstream/pull/1201

I will file a PR soon to push those changes from ROCm fork --> upstream TF and cc you guys on it. Until that PR is accepted, pointing org_tensorflow to the develop-upstream branch in the ROCm fork will help with testing for ROCm support.

The Deallocating stream with pending work error message can be ignored for now. It is being incorrectly issued in this case, and we are looking into isolating the cause and fixing it.

hawkinsp commented 3 years ago

@deven-amd That's great! We do need these upstream changes though, because as I mentioned earlier, we often track upstream XLA changes pretty closely. So any fork will rapidly become stale.

Out of curiosity, are there docker containers for building with ROCm? We do our NVidia release builds inside a Docker container (https://cs.opensource.google/jax/jax/+/master:build/build_jaxlib_wheels.sh). No promises, but we might also be able to build AMD linux wheels together with our NVidia wheels, although we would have no way to test them.

inailuig commented 3 years ago

The Deallocating stream with pending work error message can be ignored for now. It is being incorrectly issued in this case, and we are looking into isolating the cause and fixing it.

alright, thanks!

@deven-amd If you want to open a combined PR for everything, then that is also fine by me; we can have a discussion there, especially about 2. (Otherwise I also would not mind submitting 1. and/or 2. myself.)

@deven-amd I have another question: currently some tests fail because of "rocBLAS does not currently support the TRSM operation for ... complex ..." (here: https://github.com/tensorflow/tensorflow/blob/6859f52a3fba6714b5360262f190c9649613ac5c/tensorflow/stream_executor/rocm/rocm_blas.cc#L2432 and also the next one).

Afaik rocBLAS does support ctrsm and ztrsm now, so would you be able to add and upstream them to TF? In the meantime I have a workaround on the jax side: https://github.com/inailuig/jax/commit/6a9b49d3ee1045986a6100ee48017333ea927ee1

Artem-B commented 3 years ago

Out of curiosity, are there docker containers for building with ROCm?

@hawkinsp : docker pull rocm/rocm-terminal may work. See https://github.com/RadeonOpenCompute/ROCm-docker

deven-amd commented 3 years ago

@hawkinsp , you can use containers from the following docker repo, as base containers for building with ROCm

https://hub.docker.com/r/rocm/dev-ubuntu-18.04/tags?page=1&ordering=last_updated

/cc @sunway513 for awareness

deven-amd commented 3 years ago

We do need these upstream changes though

@deven-amd If you want to open a combined PR for everything then that is also fine by me

just filed a combined PR - https://github.com/tensorflow/tensorflow/pull/45583.


btw, when I tried to build using the tip of TF, I am getting the following error

ERROR: /home/rocm-user/.cache/bazel/_bazel_root/15199cbabc3b1eef2a9e7002ce358bc9/external/org_tensorflow/tensorflow/compiler/xla/python/BUILD:162:1: C++ compilation of rule '@org_tensorflow//tensorflow/compiler/xla/python:py_client' failed (Exit 1): crosstool_wrapper_driver_is_not_gcc failed: error executing command 
...
external/org_tensorflow/tensorflow/compiler/xla/python/py_buffer.cc:90:16: error: 'bit_cast' is not a member of 'absl'
   return absl::bit_cast<std::uintptr_t>(ptr);
                ^~~~~~~~
...

hoping that this error is transient, and will be resolved on the TF side.


Afaik rocBLAS does support ctrsm and ztrsm now, so would you be able to add and upstream them to TF?

will look into this next.

inailuig commented 3 years ago

btw, when I tried to build using the tip of TF, I am getting the following error

ERROR: /home/rocm-user/.cache/bazel/_bazel_root/15199cbabc3b1eef2a9e7002ce358bc9/external/org_tensorflow/tensorflow/compiler/xla/python/BUILD:162:1: C++ compilation of rule '@org_tensorflow//tensorflow/compiler/xla/python:py_client' failed (Exit 1): crosstool_wrapper_driver_is_not_gcc failed: error executing command 
...
external/org_tensorflow/tensorflow/compiler/xla/python/py_buffer.cc:90:16: error: 'bit_cast' is not a member of 'absl'
   return absl::bit_cast<std::uintptr_t>(ptr);
                ^~~~~~~~
...

hoping that this error is transient, and will be resolved on the TF side.

Same for me. I have to use the version pinned in the jax WORKSPACE for now.

hawkinsp commented 3 years ago

@zhangqiaorjc has a fix coming for that error.

hawkinsp commented 3 years ago

Try at https://github.com/tensorflow/tensorflow/commit/62e3126273046ad0bb0b5837e7b7e4ec1d53598c

inailuig commented 3 years ago

Just a quick status update:

Currently there are only ~60 tests left which fail: failed.txt examples.txt

and some tests which we should skip on ROCm for now: not_implemented.txt other.txt

roblem commented 3 years ago

Just built from jax master and tensorflow master with no errors using the build steps from https://github.com/google/jax/issues/2012#issuecomment-738896364, but I am getting undefined symbol: rocblas_status_to_string when issuing from jaxlib import rocblas_kernel as described in https://github.com/google/jax/pull/5115#issuecomment-747737669.

inailuig commented 3 years ago

@roblem you can patch the BUILD like this for now:

diff --git a/jaxlib/BUILD b/jaxlib/BUILD
index 2a594a1c..1545e198 100644
--- a/jaxlib/BUILD
+++ b/jaxlib/BUILD
@@ -244,6 +244,8 @@ pybind_extension(
         "@com_google_absl//absl/strings:str_format",
         "@com_google_absl//absl/synchronization",
         "@local_config_rocm//rocm:rocm_headers",
+        "@local_config_rocm//rocm:rocblas",
+        "@local_config_rocm//rocm:rocsolver",
         "@pybind11",
     ],
 )

roblem commented 3 years ago

Can confirm that fixes the build and I can run the code from https://github.com/google/jax/issues/2012#issuecomment-731843904 (and I built against tensorflow-rocm). Thanks.

hawkinsp commented 3 years ago

That patch has now been merged; you should be able to build the ROCm support from head.

inailuig commented 3 years ago

One of the tests which is still failing for me is testExpm in linalg_test.py for float32. (JAX_ENABLE_X64=1 JAX_NUM_GENERATED_CASES=10000 python3 linalg_test.py --test_targets="Expm" --exclude_test_targets='ExpmFrechet|ExpmGrad')

While investigating this a little further, I noticed that repeated calls to jax.scipy.linalg.expm sometimes return the correct result and sometimes just garbage. This occurs only for single precision (float32); double precision works just fine.

Here is a snippet of code to consistently reproduce this:

testjustxla = False
testx64 = False
testsimplified = False
testnojit = False
teststore = False

# bug happens with both the rocsolver wrapper, and just XLA directly
if testjustxla:
    # disable the jaxlib.rocsolver wrapper and use QR from XLA directly for testing:
    import jaxlib; jaxlib.rocsolver = None
    import jax.lib; jax.lib.rocsolver = None

from jax import config
config.update('jax_enable_x64', testx64) # bug only happens with single precision
config.update('jax_log_compiles', True)

# bug happens both with and without jit
config.update('jax_disable_jit', testnojit) # disable jit completely

import scipy; import numpy as np
import jax; import jax.numpy as jnp
from jax._src.scipy.linalg import _calc_P_Q

x = np.array([[0.85039127, 0.28177845, 0.7522477, 0.3023423, 0.20289063],
[0.00165319, 0.20157862, 0.10795689, 0.64915824, 0.8004855],
[0.49672258, 0.2703879,0.29651344, 0.05654442, 0.47397232],
[0.3563708,  0.5359229,0.81798935, 0.37546575, 0.36782074],
[0.33748698, 0.27594268, 0.54881346, 0.7032875, 0.24230194]])

# simplified expm using QR
def expm4(A):
    P, Q, n_squarings = _calc_P_Q(A) # this is already jitted in its source
    qq, rr = jnp.linalg.qr(Q)
    return jax.scipy.linalg.solve_triangular(rr, qq.T@P) # this is already jitted in its source

# bug happens with both functions:
if testsimplified:
    f = expm4
else:
    f = jax.scipy.linalg.expm # this is already jitted in its source

fgpu = jax.jit(f, backend='gpu')
fcpu = jax.jit(f, backend='cpu')
print('scipy:\n', scipy.linalg.expm(x), '\n')
if not testnojit:
    print('jax cpu:\n',fcpu(x), '\n')

x_gpu = jax.device_put(x).block_until_ready()
print('x_gpu: ', x_gpu.device_buffer.device().device_kind)

bak = []
for i in range(1,12):
    res = fgpu(x_gpu).block_until_ready()
    print('jax gpu {}:\n'.format(i), res)
    if teststore:
        bak.append(res) # sometimes bug happens more frequently if we keep the result around
    else:
        del res

and this is the (manually shortened) output on my RX480:

scipy:
 [[3.16592573 1.09933328 2.14301701 1.23653902 1.24814494]
 [0.76042022 1.89144541 1.20449995 1.56615736 1.60165392]
 [1.31930065 0.80527063 2.21742704 0.78474264 1.16609782]
 [1.4727921  1.36080886 2.10902964 2.40654609 1.50741066]
 [1.36134008 1.06076163 1.79722665 1.59356826 2.21621311]] 

WARNING:absl:Compiling expm for args (ShapedArray(float32[5,5]),).
jax cpu:
 [[3.1659255  1.0993332  2.1430166  1.2365389  1.2481449 ]
 [0.76042014 1.8914452  1.2044998  1.5661573  1.6016538 ]
 [1.3193004  0.80527043 2.2174268  0.7847425  1.1660976 ]
 [1.472792   1.3608087  2.1090295  2.4065459  1.5074106 ]
 [1.3613399  1.0607615  1.7972265  1.5935681  2.216213  ]] 

x_gpu:  Ellesmere [Radeon RX 470/480/570/570X/580/580X/590]
WARNING:absl:Compiling expm for args (ShapedArray(float32[5,5]),).
jax gpu 1:
 [[3.1659257  1.0993333  2.1430168  1.236539   1.248145  ]
 [0.7604203  1.8914455  1.2045002  1.5661575  1.6016542 ]
 [1.3193004  0.80527055 2.217427   0.78474253 1.1660978 ]
 [1.472792   1.3608087  2.1090298  2.4065459  1.5074106 ]
 [1.3613399  1.0607615  1.7972267  1.5935681  2.2162132 ]]
jax gpu 2:
 [[ 15817581.   -18044676.    -2986114.2   -5782694.   -17373952.  ]
 [-23044934.    20692924.   -39814160.     7504455.5    5831740.  ]
 [  9685456.      393749.1   19312160.   -12428168.    10070875.  ]
 [ -9280900.    -3877790.5   -9662181.     -127421.97 -29113468.  ]
 [ 20856488.    16251458.    27534514.    24414352.    33953620.  ]]

... more of the same garbage ...

jax gpu 7:
 [[3.1659257  1.0993333  2.1430168  1.236539   1.248145  ]
 [0.7604203  1.8914455  1.2045002  1.5661575  1.6016542 ]
 [1.3193004  0.80527055 2.217427   0.78474253 1.1660978 ]
 [1.472792   1.3608087  2.1090298  2.4065459  1.5074106 ]
 [1.3613399  1.0607615  1.7972267  1.5935681  2.2162132 ]]
jax gpu 8:
 [[ 15817581.   -18044676.    -2986114.2   -5782694.   -17373952.  ]
 [-23044934.    20692924.   -39814160.     7504455.5    5831740.  ]
 [  9685456.      393749.1   19312160.   -12428168.    10070875.  ]
 [ -9280900.    -3877790.5   -9662181.     -127421.97 -29113468.  ]
 [ 20856488.    16251458.    27534514.    24414352.    33953620.  ]]

... more of the same garbage ...

Notice how the first call to expm on the GPU is correct, calls 2-6 are wrong, and 7 is again correct.

The bug happens with all True/False combinations of the test* variables defined at the beginning of my snippet (with and without jit, etc.; except with testx64=True, where the bug goes away).

Some time ago I did a bit more debugging (manually printing from the TF C++ code to VLOG...), and what I saw was that usually some of the calls to dot_general issued by _precise_dot (here: https://github.com/google/jax/blob/9ccfc9fd48b097c78c2d0fae515ff6bd52c0f681/jax/_src/scipy/linalg.py#L315) would return garbage, causing the wrong result; however, I have also seen it happen elsewhere (IIRC in one of the functions called for QR).

I was wondering if anybody else (possibly with a different GPU) can reproduce this.

roblem commented 3 years ago

On the Radeon VII, I am always getting the correct result:

jax gpu 6:
 [[3.1659257  1.0993332  2.143017   1.2365389  1.248145  ]
 [0.7604202  1.8914453  1.2045     1.5661573  1.6016538 ]
 [1.3193005  0.8052705  2.217427   0.78474253 1.1660978 ]
 [1.472792   1.3608086  2.1090295  2.4065459  1.5074105 ]
 [1.3613399  1.0607615  1.7972267  1.5935681  2.216213  ]]

inailuig commented 3 years ago

@roblem good to know that it works correctly on the newer cards.

I think this is ultimately related to https://github.com/RadeonOpenCompute/ROCm/issues/1265, and with AMD dropping support for gfx8xx this means I am probably out of luck...

I re-compiled it with ROCm 3.5.1 as suggested by the aforementioned issue (this needs reverting https://github.com/tensorflow/tensorflow/commit/d236afda36626f3dd6dfea234413b3b4d62fc9a0 and hardcoding gfx803 somewhere...), and I can confirm that the bug is not present with ROCm 3.5.1.

staticdev commented 3 years ago

@inailuig I can also confirm 3.5.1 works; it seems 3.9 works with some tricks, but I haven't tested that yet (https://github.com/xuhuisheng/rocm-build/blob/master/docs/gfx803.md)

proutrc commented 3 years ago

When attempting to build JAX on our AMD based system, with gfx908, I noticed this at the top of the bazel build output:

ROCm enabled: yes
ROCm toolkit path: /opt/rocm-4.1.0
ROCm amdgpu targets: gfx803,gfx900,gfx906,gfx1010

Does this mean gfx908 is not a currently supported target, only the ones on the list?

inailuig commented 3 years ago

When attempting to build JAX on our AMD based system, with gfx908, I noticed this at the top of the bazel build output:

ROCm enabled: yes
ROCm toolkit path: /opt/rocm-4.1.0
ROCm amdgpu targets: gfx803,gfx900,gfx906,gfx1010

Does this mean gfx908 is not a currently supported target, only the ones on the list?

@proutrc Try building with --rocm_amdgpu_targets gfx908; my impression is that it should be supported by now, although I have no way of verifying this.

Maybe ask @deven-amd and @reza-amd who should know more.

The list is just what I put in at the time; it should probably be updated to match what the folks at AMD officially support now.

proutrc commented 3 years ago

Thanks. My core issue seems to revolve around this error:

ERROR: /gpfs/alpine/stf007/scratch/rprout/jax-spck/jax/jaxlib/BUILD:308:17: in cc_binary rule //jaxlib:rocblas_kernels.so: target '@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status' is not visible from target '//jaxlib:rocblas_kernels.so'. Check the visibility declaration of the former target if you think the dependency is legitimate

My knowledge of Bazel is pretty limited, so I am not totally sure what might cause this. I pointed my --output_user_root to here:

/tmp/rprout/jax-build/

and can see that tensorflow is getting put there (here is the specific target it says it can't see - at least the header and source files):

[rprout@login1.spock jax]$ ls /tmp/rprout/jax-build/963a11611e42043c6fc2215e83411296/external/org_tensorflow/tensorflow/compiler/xla/service/custom_call_status
custom_call_status.cc               custom_call_status.h                custom_call_status_internal.h       custom_call_status_test.cc          custom_call_status_test_c_caller.c  custom_call_status_test_c_caller.h

Not totally sure what is causing the visibility error mentioned. Any insight?

reza-amd commented 3 years ago

The BUILD failure you see is due to #7480, which I will fix soon (+ the fix for rocm_amdgpu_targets). Meanwhile, you can check out the parent of this commit (git checkout 6249d664) to have a successful build.

proutrc commented 3 years ago

@reza-amd this worked - thank you

reza-amd commented 3 years ago

The fix for the ROCm build issue was submitted in PR #7633.

We have also released a Docker container (currently in a preview stage) for the ROCm build here. Nightly build containers will also be added soon.

proutrc commented 3 years ago

@reza-amd

I have run into further issues while testing this build, unfortunately. It does build successfully, but I am experiencing a runtime error that seems to revolve around JIT compilation on the target backend. I am curious whether you all have seen this before. It may be worth comparing your AMD test environment to our HPC environment.

I can see the GPU from JAX, as shown below, but trying to run the simple JAX test produces this error about not being able to find the LLVM linker:

>>> import os
>>> import jax
>>> jax.devices()[0].device_kind
'Arcturus GL-XL [AMD Instinct MI100]'
>>> import jax.numpy as jnp
>>> from jax import grad, jit, vmap
>>> from jax import random
>>> key = random.PRNGKey(0)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/gpfs/alpine/stf007/scratch/rprout/jax-spck/jax/jax/_src/random.py", line 75, in PRNGKey
    k1 = convert(lax.shift_right_logical(seed_arr, lax._const(seed_arr, 32)))
  File "/gpfs/alpine/stf007/scratch/rprout/jax-spck/jax/jax/_src/lax/lax.py", line 386, in shift_right_logical
    return shift_right_logical_p.bind(x, y)
  File "/gpfs/alpine/stf007/scratch/rprout/jax-spck/jax/jax/core.py", line 265, in bind
    out = top_trace.process_primitive(self, tracers, params)
  File "/gpfs/alpine/stf007/scratch/rprout/jax-spck/jax/jax/core.py", line 609, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/gpfs/alpine/stf007/scratch/rprout/jax-spck/jax/jax/interpreters/xla.py", line 273, in apply_primitive
    compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)
  File "/gpfs/alpine/stf007/scratch/rprout/jax-spck/jax/jax/_src/util.py", line 186, in wrapper
    return cached(config._trace_context(), *args, **kwargs)
  File "/gpfs/alpine/stf007/scratch/rprout/jax-spck/jax/jax/_src/util.py", line 179, in cached
    return f(*args, **kwargs)
  File "/gpfs/alpine/stf007/scratch/rprout/jax-spck/jax/jax/interpreters/xla.py", line 322, in xla_primitive_callable
    compiled = backend_compile(backend, built_c, options)
  File "/gpfs/alpine/stf007/scratch/rprout/jax-spck/jax/jax/interpreters/xla.py", line 385, in backend_compile
    return backend.compile(built_c, compile_options=options)
RuntimeError: Internal: unable to find ld.lld in PATH: No such file or directory
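
For reference, this error is XLA's ROCm backend failing to find the LLVM linker (ld.lld) that it invokes to link the compiled GPU code. A possible workaround, under the assumption that ld.lld ships with the ROCm toolkit under /opt/rocm-4.1.0/llvm/bin on this system (path and effectiveness unverified), is to put that directory on PATH before the first compilation:

# Hedged sketch only: make ld.lld discoverable before JAX triggers its first
# GPU compilation. The ROCm LLVM path below is an assumption about this system.
import os

rocm_llvm_bin = "/opt/rocm-4.1.0/llvm/bin"  # assumed location of ld.lld
os.environ["PATH"] = rocm_llvm_bin + os.pathsep + os.environ.get("PATH", "")

import jax  # import after PATH is set so the backend can find ld.lld
from jax import random

key = random.PRNGKey(0)  # the call that previously hit the linker error
print(key)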