elixir-nx / xla

Pre-compiled XLA extension
Apache License 2.0

Support Apple Silicon #8

Closed (dannote closed this issue 2 years ago)

dannote commented 2 years ago

Since XLA moved to its own Hex package, I suggest continuing the previous discussion (elixir-nx/nx#217) here.

According to Apple, TensorFlow has supported M1 since v2.5 via PluggableDevice.

Caveats

Building XLA for Apple Silicon

I haven't succeeded in building XLA yet. Here is the farthest I've been able to get so far.

I've added a new build target and bumped TensorFlow to 2.6.0:

diff --git a/Makefile b/Makefile
index 0d70b97..e277e87 100644
--- a/Makefile
+++ b/Makefile
@@ -13,10 +13,10 @@ TENSORFLOW_GIT_REPO ?= https://github.com/tensorflow/tensorflow.git

 # TODO: Should this instead be a stable version source?
 # e.g. instead use wget to download tagged releases
-TENSORFLOW_GIT_REV ?= 54dee6dd8d47b6e597f4d3f85b6fb43fd5f50f82
+TENSORFLOW_GIT_REV ?= 919f693420e35d00c8d0a42100837ae3718f7927

 # Private configuration
-BAZEL_FLAGS = --define "framework_shared_object=false" -c $(BUILD_MODE)
+BAZEL_FLAGS = --define "framework_shared_object=false" -c $(BUILD_MODE) --incompatible_restrict_string

 TENSORFLOW_NS = tf-$(TENSORFLOW_GIT_REV)
 TENSORFLOW_DIR = $(BUILD_CACHE)/$(TENSORFLOW_NS)
diff --git a/lib/xla.ex b/lib/xla.ex
index 9937e11..9b58338 100644
--- a/lib/xla.ex
+++ b/lib/xla.ex
@@ -51,7 +51,7 @@ defmodule XLA do
   defp xla_target() do
     target = System.get_env("XLA_TARGET", "cpu")

-    supported_xla_targets = ["cpu", "cuda", "rocm", "tpu", "cuda102", "cuda110", "cuda111"]
+    supported_xla_targets = ["cpu", "cuda", "rocm", "tpu", "cuda102", "cuda110", "cuda111", "macos"]

     unless target in supported_xla_targets do
       listing = supported_xla_targets |> Enum.map(&inspect/1) |> Enum.join(", ")
@@ -68,6 +68,7 @@ defmodule XLA do
         "cuda" <> _ -> "--config=cuda"
         "rocm" <> _ -> "--config=rocm --action_env=HIP_PLATFORM=hcc"
         "tpu" <> _ -> "--config=tpu"
+        "macos" <> _ -> "--config=macos_arm64"
         _ -> ""
       end

Installed the latest OpenJDK and numpy:

brew install openjdk@17 make
sudo ln -sfn /opt/homebrew/opt/openjdk/libexec/openjdk.jdk /Library/Java/JavaVirtualMachines/openjdk.jdk
echo 'export PATH="/opt/homebrew/opt/openjdk/bin:$PATH"' >> ~/.zshrc
source ~/.zshrc
pip3 install numpy

Installed the latest stable Bazel for aarch64:

curl -fLO https://github.com/bazelbuild/bazel/releases/download/4.2.1/bazel_nojdk-4.2.1-darwin-arm64
chmod +x bazel_nojdk-4.2.1-darwin-arm64
sudo mv bazel_nojdk-4.2.1-darwin-arm64 /usr/local/bin/bazel

Tried to build XLA with the new target:

XLA_TARGET=macos XLA_BUILD=true mix compile

But here it crashes with java.lang.reflect.InaccessibleObjectException. It looks like Bazel doesn't take JSR 376 into account and so doesn't support OpenJDK 17 yet.

It's even worse with OpenJDK 11 on Monterey: I got SIGSEGV somewhere in java.base.

I haven't tried to build XLA on Big Sur yet. On Monterey, I will have to wait for OpenJDK 11 to be fixed or Bazel to be ported to OpenJDK 17.

I'll track my progress here if you don't mind.

jonatanklosko commented 2 years ago

Hey @dannote! Yeah, GH Actions may take a while to support M1, and cloud support isn't perfect either (I found one provider, but they require provisioning instances for at least 24h), which is why I didn't push it at this point. Thanks for looking into this, feel free to post updates here :)

We shouldn't need the macos target, because it's a characteristic of the target machine that we can detect ourselves. That's mostly a note to myself, don't worry about this.
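For illustration, here is a minimal sketch of how the extra Bazel flag could be inferred from the host rather than from `XLA_TARGET`. The helper name `bazel_config_for` is hypothetical and not part of this repository; the only flag it emits is the `--config=macos_arm64` from the diff above, and it assumes `uname -s` and `uname -m` report `Darwin` and `arm64` on Apple Silicon.

```shell
#!/usr/bin/env bash
# Hypothetical helper (not part of xla): map an OS name and CPU architecture,
# as printed by `uname -s` and `uname -m`, to an extra Bazel flag.
bazel_config_for() {
  local os="$1" arch="$2"
  case "${os}/${arch}" in
    Darwin/arm64) echo "--config=macos_arm64" ;;  # flag taken from the diff above
    *)            echo "" ;;                      # no extra flag for other hosts
  esac
}

# Usage: inspect the current host instead of asking the user for a target.
bazel_config_for "$(uname -s)" "$(uname -m)"
```

Taking the OS and architecture as arguments keeps the mapping easy to check on any machine, which matters while M1 CI runners are hard to come by.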

dannote commented 2 years ago

I found one provider, but they require provisioning instances for at least 24h

Unfortunately, that's for legal reasons and has been the case since Big Sur: Apple added this requirement in section 3A (ii) of the macOS EULA.

dannote commented 2 years ago

Good news!

I managed to build it.

Firstly, I upgraded to macOS Monterey 12 Beta 7. I found that TensorFlow is pinned to Bazel 3.7.2, but I came up with this simple workaround:

brew install bazelisk
XLA_TARGET=macos XLA_BUILD=true USE_BAZEL_VERSION=4.2.1 mix compile

That's it! No local JDK is required.

Target //tensorflow/compiler/xla/extension:xla_extension up-to-date:
  bazel-bin/tensorflow/compiler/xla/extension/xla_extension.tar.gz
INFO: Elapsed time: 1762.020s, Critical Path: 141.25s
INFO: 9549 processes: 1867 internal, 7682 local.
INFO: Build completed successfully, 9549 total actions
INFO: Build completed successfully, 9549 total actions
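The workaround relies on Bazelisk letting the `USE_BAZEL_VERSION` environment variable take precedence over the version pinned in the workspace. The sketch below is a simplified illustration of that precedence, not Bazelisk's actual code. It assumes the pin lives in a `.bazelversion` file at the workspace root and ignores other sources such as `.bazeliskrc`; the function name `resolve_bazel_version` is made up for this example.

```shell
#!/usr/bin/env bash
# Simplified illustration of the version precedence the workaround relies on.
# Assumption: the pinned version is stored in <workspace>/.bazelversion.
resolve_bazel_version() {
  local workspace="$1"
  if [ -n "${USE_BAZEL_VERSION:-}" ]; then
    # An explicit environment override wins over the pinned version.
    echo "$USE_BAZEL_VERSION"
  elif [ -f "$workspace/.bazelversion" ]; then
    # Otherwise fall back to the version pinned by the workspace.
    head -n 1 "$workspace/.bazelversion"
  else
    echo "latest"
  fi
}

# Example: USE_BAZEL_VERSION=4.2.1 resolve_bazel_version /path/to/tensorflow
```

With the variable unset, a workspace pinned to 3.7.2 would resolve to 3.7.2; setting it to 4.2.1 for a single command overrides the pin without editing the TensorFlow checkout.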

Is there anything you would like to test?

jonatanklosko commented 2 years ago

Great to hear!

@seanmor5 is there anything holding us back from bumping to tensorflow 2.6 that you're already aware of, or is it worth testing?

seanmor5 commented 2 years ago

Worth testing, it's probably for the better anyway as I think it bumps CUDA support to 11.4. I can deal with fixing breaks in EXLA

jonatanklosko commented 2 years ago

Awesome, make a PR and test then!

jonatanklosko commented 2 years ago

@dannote was there any error that --incompatible_restrict_string fixed for you?

dannote commented 2 years ago

@jonatanklosko Yes, I forgot to mention, there were a bunch of errors without this flag:

ERROR: /Users/dannote/.cache/xla_extension/tf-919f693420e35d00c8d0a42100837ae3718f7927/tensorflow/compiler/xla/extension/BUILD:101:20: invalid escape sequence: \/. You can enable unknown escape sequences by passing the flag --incompatible_restrict_string_escapes=false
ERROR: /Users/dannote/.cache/xla_extension/tf-919f693420e35d00c8d0a42100837ae3718f7927/tensorflow/compiler/xla/extension/BUILD:101:29: invalid escape sequence: \/. You can enable unknown escape sequences by passing the flag --incompatible_restrict_string_escapes=false
ERROR: /Users/dannote/.cache/xla_extension/tf-919f693420e35d00c8d0a42100837ae3718f7927/tensorflow/compiler/xla/extension/BUILD:102:20: invalid escape sequence: \/. You can enable unknown escape sequences by passing the flag --incompatible_restrict_string_escapes=false
ERROR: /Users/dannote/.cache/xla_extension/tf-919f693420e35d00c8d0a42100837ae3718f7927/tensorflow/compiler/xla/extension/BUILD:102:29: invalid escape sequence: \/. You can enable unknown escape sequences by passing the flag --incompatible_restrict_string_escapes=false
ERROR: /Users/dannote/.cache/xla_extension/tf-919f693420e35d00c8d0a42100837ae3718f7927/tensorflow/compiler/xla/extension/BUILD:104:19: invalid escape sequence: \/. You can enable unknown escape sequences by passing the flag --incompatible_restrict_string_escapes=false
ERROR: /Users/dannote/.cache/xla_extension/tf-919f693420e35d00c8d0a42100837ae3718f7927/tensorflow/compiler/xla/extension/BUILD:106:23: invalid escape sequence: \/. You can enable unknown escape sequences by passing the flag --incompatible_restrict_string_escapes=false
ERROR: Skipping '//tensorflow/compiler/xla/extension:xla_extension': no such target '//tensorflow/compiler/xla/extension:xla_extension': target 'xla_extension' not declared in package 'tensorflow/compiler/xla/extension' defined by /Users/dannote/.cache/xla_extension/tf-919f693420e35d00c8d0a42100837ae3718f7927/tensorflow/compiler/xla/extension/BUILD

jonatanklosko commented 2 years ago

It's interesting because --incompatible_restrict_string should actually enable the error, while --incompatible_restrict_string=false should help bypass it.

seanmor5 commented 2 years ago

Aren't those escape sequences coming from the genrule which rewrites some directories? Will that flag stop the genrule from running? It's important because the genrule maps header file paths correctly
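To make the concern concrete, here is a sketch of what the remap step of that genrule does, written as plain bash (single `$`, no BUILD-file escaping). The function name `remap_header_dir` and the sample paths are made up for illustration; the four substitutions mirror the ones in the extension's BUILD file. Inside `${var/pattern/replacement}` a literal slash in the pattern has to be written as `\/`, and that is the sequence Bazel rejects as an unknown string escape when it parses the BUILD file. Writing it as `\\/` in the BUILD file should hand the same `\/` to bash, though that has not been tried in this thread.

```shell
#!/usr/bin/env bash
# Hypothetical helper reproducing the genrule's four path remaps in plain bash.
remap_header_dir() {
  local d="$1"
  # Remap llvm paths; `\/` is a literal slash inside the substitution pattern.
  d="${d/llvm\/include\/llvm/llvm}"
  d="${d/llvm\/include\/llvm-c/llvm-c}"
  # Remap google path
  d="${d/src\/google/google}"
  # Remap grpc paths
  d="${d/include\/grpc/grpc}"
  echo "$d"
}

remap_header_dir "llvm/include/llvm/ADT"   # prints llvm/ADT
remap_header_dir "src/google/protobuf"     # prints google/protobuf
```

If these substitutions were skipped or mangled, the headers would be copied under the unmapped directories and includes such as `llvm/ADT/...` would no longer resolve.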

dannote commented 2 years ago

Aren't those escape sequences coming from the genrule which rewrites some directories?

Yes, they are

less -N /Users/dannote/.cache/xla_extension/tf-919f693420e35d00c8d0a42100837ae3718f7927/tensorflow/compiler/xla/extension/BUILD

     74 # This is the genrule used by TF install headers to correctly
     75 # map headers into a directory structure
     76 genrule(
     77   name = "xla_extension_headers",
     78   srcs = [
     79     ":xla_extension_dep_headers",
     80   ],
     81   outs = ["include"],
     82   cmd = """
     83     mkdir $@
     84     for f in $(SRCS); do
     85       d="$${f%/*}"
     86       d="$${d#bazel-out/*/genfiles/}"
     87       d="$${d#bazel-out/*/bin/}"
     88       if [[ $${d} == *local_config_* ]]; then
     89         continue
     90       fi
     91       if [[ $${d} == external* ]]; then
     92         extname="$${d#*external/}"
     93         extname="$${extname%%/*}"
     94         if [[ $${TF_SYSTEM_LIBS:-} == *$${extname}* ]]; then
     95           continue
     96         fi
     97         d="$${d#*external/farmhash_archive/src}"
     98         d="$${d#*external/$${extname}/}"
     99       fi
    100       # Remap llvm paths
    101       d="$${d/llvm\/include\/llvm/llvm}"
    102       d="$${d/llvm\/include\/llvm-c/llvm-c}"
    103       # Remap google path
    104       d="$${d/src\/google/google}"
    105       # Remap grpc paths
    106       d="$${d/include\/grpc/grpc}"
    107       mkdir -p "$@/$${d}"
    108       cp "$${f}" "$@/$${d}/"
    109     done
    110     """,
    111   )

dannote commented 2 years ago

I feel like I will be able to build it from 54dee6dd8d47b6e597f4d3f85b6fb43fd5f50f82

dannote commented 2 years ago

Nope, I got this error at 54dee6dd8d47b6e597f4d3f85b6fb43fd5f50f82 (same as elixir-nx/nx#441):

ERROR: /private/var/tmp/_bazel_dannote/20416a7c8310bcebbb467e0a4d7b41a4/external/llvm-project/llvm/BUILD:816:11: Compiling llvm/lib/Target/AArch64/GISel/AArch64O0PreLegalizerCombiner.cpp failed: (Aborted): wrapped_clang failed: error executing command external/local_config_cc/wrapped_clang '-D_FORTIFY_SOURCE=1' -fstack-protector -fcolor-diagnostics -Wall -Wthread-safety -Wself-assign -fno-omit-frame-pointer -g0 -O2 -DNDEBUG '-DNS_BLOCK_ASSERTIONS=1' ... (remaining 59 argument(s) skipped)
external/llvm-project/llvm/lib/Target/AArch64/GISel/AArch64O0PreLegalizerCombiner.cpp:45:10: fatal error: 'AArch64GenO0PreLegalizeGICombiner.inc' file not found
#include "AArch64GenO0PreLegalizeGICombiner.inc"
         ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1 error generated.
Error in child process '/usr/bin/xcrun'. 1
Target //tensorflow/compiler/xla/extension:xla_extension failed to build

jonatanklosko commented 2 years ago

@dannote thanks for verifying, could you try over on #9?

dannote commented 2 years ago

@jonatanklosko Yes, trying it now

jonatanklosko commented 2 years ago

Resolved in #9.