google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0
29.76k stars, 2.72k forks

Hardware acceleration on Apple Silicon with Metal plugin #8074

Closed: dannote closed this issue 1 year ago

dannote commented 2 years ago

Hello!

I'm looking for a way to get XLA hardware acceleration on Apple's M1 (elixir-nx/nx#490). Apple provides a TensorFlow PluggableDevice plugin for the METAL platform, but it doesn't include an XLA backend yet.

Do you have any plans to target M1's GPU?

hawkinsp commented 2 years ago

We have no plans ourselves at present, and I'm not aware of any from the XLA folks themselves. CUDA and ROCm are their focus.

However, it's not out of the question that there might eventually be a compilation path from HLO (actually, the mHLO MLIR dialect, which is a close cousin of XLA HLO) to Metal shaders from IREE (https://github.com/google/iree). (I believe IREE targets SPIR-V at the moment, but not yet Metal.)

nicholasjng commented 2 years ago

Just out of curiosity, how would one go about implementing a GPU backend for Metal in XLA? I looked through the XLA source a little, and it looks fairly complicated (though I also haven't worked with C++ in a while). Do you think it would be a feasible project for someone to implement some GPU-accelerated XLA operations for the M1 platform? (I've been wanting to get into GPU programming for a while now, which is the reason for my interest.)

If the hypothetical path from IREE to Metal you mentioned is much more realistic, though, then it might not make sense to go ahead on this. I would be happy about a quick expert opinion :)

hawkinsp commented 2 years ago

Well, the path to get something working is perhaps shorter than you might think:

Using jax from Github head:

pip install --upgrade jaxlib
pip install iree-compiler-snapshot iree-runtime-snapshot -f https://github.com/google/iree/releases
JAX_ENABLE_MLIR=1 JAX_PLATFORMS=iree python your_jax_script.py

So the JAX->IREE path exists, at least in an early form.

So if I wanted to try this, I might try looking at that IREE issue.

nicholasjng commented 2 years ago

Thanks for the hint! I need to take a closer look at IREE, then. It seems they want it rewritten in Objective-C for better performance and smaller size, so I will first look at how they implement their other C backends.

benjaminpope commented 2 years ago

This would be awesome!

nicholasjng commented 2 years ago

I looked at IREE a little bit, and it seems out of my league (at least for the moment and given my free time budget), so I wanted to ask some follow-up questions about developing a new XLA backend in the case of Metal (if that's ok):

Apologies if this is obvious, I'm not really experienced with compilers (to my great regret, it's a very interesting subject to me). I want to progress on this, though, but right now the barrier is a little too steep. Thank you for your consideration!

stellaraccident commented 2 years ago

FYI - news from one of IREE's users/contributors relevant to this: https://nod.ai/pytorch-m1-max-gpu/

They were successful at adapting and tuning IREE to the case. This is still quite early work but promising.

powderluv commented 2 years ago

Yes, JAX+IREE (and nod.ai-tuned SHARK) on Apple Silicon is a high priority for us at Nod.ai. We hope to have the JAX pipeline fleshed out once a few more upstream pieces land, but we're happy to help with anything if anyone else is trying it.

hawkinsp commented 2 years ago

For some definition of "works", JAX-on-Metal via IREE/Vulkan and MoltenVK seems to work right now on Mac:

Rough steps:

Profit:

JAX_PLATFORMS=iree JAX_IREE_BACKEND=vulkan python ...

now runs ops on GPU.

However, do not expect the result to be faster than the current CPU support or particularly complete yet! There has been zero performance tuning done so far and there are many known bugs.

powderluv commented 2 years ago

@hawkinsp the build-from-source step shouldn't be required any longer: the nightly builds have Vulkan turned on. Please let me know if that is not the case; we have already upstreamed all the Apple M1 pieces.

You will need to pass "extra_args" to IREE to set the right Apple M1 target triple: https://github.com/nod-ai/SHARK/blob/1186d7c58e6046aea6a6115c608dbd77728e7aca/shark/iree_utils.py#L93-L96

If you have a nice JAX model you would like optimized for the M1 Ultra, please send me a link and we can tweak the backend codegen for it.

hawkinsp commented 2 years ago

Thanks! You're right, everything needed is in an iree nightly build. And passing the right target triple meant I got meaningfully better performance for the f32[1000, 1000] x f32[1000, 1000] matmul I was benchmarking via the Vulkan path, so that's pretty promising.

If you're interested in performance tuning, I'm sure we can cook up a transformer model or something that might be fun to optimize!
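The benchmark script itself isn't shown in the thread. A minimal sketch of timing a jitted f32[1000, 1000] x f32[1000, 1000] matmul might look like the following; the function name `bench_matmul`, the repeat count and the best-of-N timing are assumptions, and it runs on whatever backend the environment selects (e.g. `JAX_PLATFORMS=iree JAX_IREE_BACKEND=vulkan` as above, or plain CPU with no variables set).

```python
import time

import jax
import jax.numpy as jnp


def bench_matmul(n=1000, repeats=10):
    """Return the best wall-clock time (seconds) for an f32[n, n] x f32[n, n] matmul."""
    kx, ky = jax.random.split(jax.random.PRNGKey(0))
    x = jax.random.normal(kx, (n, n), dtype=jnp.float32)
    y = jax.random.normal(ky, (n, n), dtype=jnp.float32)

    matmul = jax.jit(jnp.matmul)
    # The first call triggers compilation; keep it out of the timing.
    matmul(x, y).block_until_ready()

    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        # JAX dispatches asynchronously, so wait for the result before stopping the clock.
        matmul(x, y).block_until_ready()
        best = min(best, time.perf_counter() - start)
    return best


if __name__ == "__main__":
    n = 1000
    seconds = bench_matmul(n)
    # A dense n x n matmul does roughly 2 * n**3 floating-point operations.
    gflops = 2 * n**3 / seconds / 1e9
    print(f"{seconds * 1e3:.2f} ms, {gflops:.1f} GFLOP/s on {jax.default_backend()}")
```

Taking the best of several runs after a warm-up call keeps compile time and one-off dispatch noise out of the comparison between backends.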

powderluv commented 2 years ago

Yes would love to show a Jax transformer on the M1 Max/Ultra. Please let me know if you have a sample somewhere and we will get to work 😀

stellaraccident commented 2 years ago

It'd be really nice to shave any rough edges so this is good-by-default. Minimally, it would be nice to have a way to auto-detect/set host optimized compiler flags.

stellaraccident commented 2 years ago

> If you're interested in performance tuning, I'm sure we can cook up a transformer model or something that might be fun to optimize!

It'd be really nice to have a blessed model to iterate on (for both a training and an inference workload). Hyperfocusing like that is often step one on a new platform, and getting some scripts and such in place so that the folks working at the lower level don't need to think about the "ML part" can really accelerate the work.

Birch-san commented 2 years ago

thanks for everyone's effort so far in getting jax to work with M1 GPU.

I tried this out today on the dalle-playground jax model, but I got an error:
failed to legalize operation 'mhlo.scatter' that was explicitly marked illegal (full details here).

is this currently unimplemented? is there an option I can use to lower the operations differently?
I didn't see any other options in Flow/Transforms/Passes.cpp or MHLO/Passes.cpp that looked capable of fixing this.

I believe I installed everything correctly. I installed MoltenVK, rebooted, then:

# on commit 345cc19949273cc414d94e6f13d0620b780af465
git clone https://github.com/google/jax.git
cd jax
# get jaxlib deps
pip install numpy wheel
# build jaxlib wheel
/Users/birch/anaconda3/envs/torch-nightly/bin/python ./build/build.py \
--noenable_mkl_dnn \
--noenable_cuda \
--noenable_tpu \
--noenable_nccl
# install jaxlib wheel
pip install /Users/birch/git/jax/dist/jaxlib-0.3.11-cp39-none-macosx_11_0_arm64.whl
# install jax
pip install -e .
# install today's iree release candidate
pip install iree_compiler iree_runtime -f https://github.com/google/iree/releases/tag/candidate-20220606.161

_to confirm sanity, I have also tried running the model on-CPU with the jax/jaxlib I built (i.e. by removing the JAX_PLATFORMS and JAX_IREE_BACKEND environment variables), and it does indeed still work fine in legacy mode. so I think I at least built+installed jax/jaxlib correctly._

@powderluv if you're looking for a model to optimize the backend against, this dalle-playground is pretty popular right now (powers the dalle-mini demo, but also has capability to run dalle-mega, which is not yet featured in the demo).

powderluv commented 2 years ago

oh this looks exciting. Can you please share your entire command history after installing iree_compiler/iree_runtime? Can I run it from the command line instead of VSCode?

Birch-san commented 2 years ago

yeah, definitely possible to run without VSCode.

my command history for setting up dalle-playground was pretty messy. starting with a machine that's never done Python development, my steps today have been more or less this:

Setup Anaconda

Install Anaconda via command-line arm64 installer, then:

(optional) install PyTorch nightly (to get M1 GPU support)

conda update -n base -c defaults conda
conda create -n torch-nightly python=3.9
conda activate torch-nightly
pip install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu
conda config --set auto_activate_base false
echo 'conda activate torch-nightly' >> ~/.zshrc
# now open a new shell, or source your zshrc in this shell

Get MoltenVK, new jax+jaxlib+iree

Run the commands from my previous comment, to:

We'll also need to pass IREE the extra flags mentioned in @powderluv's example.

In your local jax repository, modify jax/jax/_src/iree.py to pass some extra flags:

    extra_args = []
    # extra_args=["--mlir-print-ir-after-all"]
    if platform.system() == "Darwin" and platform.machine() == "arm64":
      extra_args += ["--iree-llvm-target-triple=arm64-apple-darwin21.5.0"]
+     # my hardcoded additions to iree.py:
+     extra_args += ["--iree-flow-demote-i64-to-i32",
+     "--iree-vulkan-target-triple=m1-moltenvk-macos",
+     "--iree-llvm-target-cpu-features=host",
+     "--iree-mhlo-demote-i64-to-i32=false"]

Obviously this is a terrible way to pass these flags. If you can figure out how to pass the flags down from the dalle-playground code, then tell me how 🙃.
dalle-playground relates to your local jax via an "egg-link" (~/anaconda3/envs/torch-nightly/lib/python3.9/site-packages/jax.egg-link), so you can make changes like this to the jax source, no reinstall necessary.
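One way to avoid hardcoding the flags would be to read them from an environment variable inside that `extra_args` block. The sketch below is hypothetical: `JAX_IREE_EXTRA_ARGS` is a made-up variable name that neither JAX nor IREE reads, and `iree_extra_args` only mirrors the snippet above, not the real contents of `iree.py`.

```python
import os
import platform
import shlex


def iree_extra_args(env=None):
    """Build extra IREE compiler flags, appending any from a hypothetical env var."""
    env = os.environ if env is None else env
    extra_args = []
    if platform.system() == "Darwin" and platform.machine() == "arm64":
        extra_args += ["--iree-llvm-target-triple=arm64-apple-darwin21.5.0"]
    # JAX_IREE_EXTRA_ARGS is hypothetical. Split it the way a shell would,
    # so several flags can be passed in one variable.
    extra_args += shlex.split(env.get("JAX_IREE_EXTRA_ARGS", ""))
    return extra_args
```

With such a patch, the flags could then be chosen per run from the shell, e.g. `JAX_IREE_EXTRA_ARGS="--iree-flow-demote-i64-to-i32 --iree-vulkan-target-triple=m1-moltenvk-macos" JAX_PLATFORMS=iree JAX_IREE_BACKEND=vulkan python app.py 8080`, without touching the dalle-playground code.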

Setup dalle-playground

git clone https://github.com/saharmor/dalle-playground.git
cd dalle-playground

Backend

cd backend

Now install & run:

Install

python3 -m venv dalle
source ./dalle/bin/activate
pip install --upgrade pip
# warning: first ensure you have on your PATH: cargo + a beta release of rustc
pip install -r requirements.txt

Beware: the pip install -r requirements.txt was a bit challenging; one of its dependencies tried to compile from source.
It wanted rustc and cargo to be on the PATH.
Stable rustc wasn't new enough; I had to install a beta. These worked:

rustc 1.62.0-beta.3 (a5cf77ca6 2022-06-02)
cargo 1.62.0-beta.3 (4751950cc 2022-05-27)

If you're using macOS with nix as your package manager, do tell and I can share the config I used to install Rust. It was horrible.

Run

Start the backend (which exposes a web interface on port 8080):

# if you want to try the known-good CPU-only mode, then remove these env vars
JAX_PLATFORMS=iree JAX_IREE_BACKEND=vulkan python app.py 8080

A successful launch will print something like this:

--> Starting DALL-E Server. This might take up to two minutes.
wandb: Currently logged in as: anony-moose-278165. Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.12.17
wandb: Run data is saved locally in /Users/birch/git/dalle-playground/wandb/run-20220606_232449-1pjbrycr
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run cerulean-glitter-21
wandb: ⭐️ View project at [redacted]
wandb: 🚀 View run at [redacted]
wandb: WARNING Do NOT share these links with anyone. They can be used to claim your runs.
wandb: Downloading large artifact wzoooa1c:latest, 1672.79MB. 7 files... Done. 0:0:0
wandb: Downloading large artifact wzoooa1c:latest, 1672.79MB. 7 files... Done. 0:0:0
--> DALL-E Server is up and running!
 * Serving Flask app 'app' (lazy loading)
 * Environment: production
   WARNING: This is a development server. Do not use it in a production deployment.
   Use a production WSGI server instead.
 * Debug mode: off
INFO:werkzeug: * Running on all addresses (0.0.0.0)
   WARNING: This is a development server. Do not use it in a production deployment.
 * Running on http://127.0.0.1:8080
 * Running on http://[redacted]:8080 (Press CTRL+C to quit)

Frontend (easy-peasy, NodeJS stuff)

# from the dalle-playground repo root, in a new shell (the backend keeps running in the first one)
cd frontend
npm i
# deploys frontend on port 3000
npm start

Ensure your backend is ready (i.e. wait for it to have printed "--> DALL-E Server is up and running!"), then navigate to:
http://localhost:3000/dalle-playground?backendUrl=http://127.0.0.1:8080

Type a query (this one succeeded because I ran without GPU):

Birch-san commented 2 years ago

I've filed an issue with iree to improve support for the mhlo.scatter operation used in this jax model.

in the meantime: is there a way to utilise more of the CPU? I think I'm noticing the same thing as https://github.com/google/jax/issues/5022: when I ask the jax model to generate 1 image on-CPU, it takes a while (2 mins with dalle-mini, 27 mins with dalle-mega) but my CPU usage is only 117%. moreover it seems to be utilising efficiency cores without maxing out performance cores.
is there a way to set thread affinity so that it prefers to run on performance cores? and a way to tell it to use all of them?

hawkinsp commented 2 years ago

@Birch-san Yes, the MHLO op support in IREE isn't complete and has some known deficiencies. see for example the hotlist of issues: https://github.com/google/iree/issues?q=is%3Aissue+is%3Aopen+label%3Aintegrations%2Fmhlo many of which come from running the JAX tests. But it's early days yet! I'm sure the IREE folks would appreciate you filing issues as you discover them, assuming they aren't already known.

For using more CPU cores: XLA doesn't do a whole lot of inter-operation parallelism on CPU at the moment. You can add explicit parallelism at the JAX level (e.g., using pmap or pjit). I'm not sure what IREE-on-CPU's parallelism status is. CPU has been a bit neglected historically in XLA but there is a bunch of work in progress from a number of compiler folks that gives me hope that it will start to make more rapid progress soon.
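As a rough illustration of explicit parallelism at the JAX level: on CPU, XLA can be asked to expose several host devices with the `--xla_force_host_platform_device_count` flag, after which `jax.pmap` runs one shard of the batch per device. This is a sketch under those assumptions (the toy `predict` function, the device count of 4 and the shapes are made up), not a recipe for dalle-playground. The flag must be set before JAX initializes its backends.

```python
import os

# Must be set before JAX initializes the CPU backend.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=4"
).strip()

import jax
import jax.numpy as jnp


def predict(w, x):
    # Toy stand-in for a model: one dense layer with tanh.
    return jnp.tanh(x @ w)


n_devices = jax.local_device_count()
w = jnp.ones((8, 8))
# The leading axis of the batch is split across devices: one shard per device.
xs = jnp.ones((n_devices, 16, 8))

# in_axes=(None, 0): broadcast the weights, map over the batch shards.
ys = jax.pmap(predict, in_axes=(None, 0))(w, xs)
print(n_devices, ys.shape)
```

The catch, as noted just below, is that this requires restructuring the model's batch dimension rather than being a switch you can flip.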

Birch-san commented 2 years ago

yeah, happy to file any issues I find 🙂 but I can't get past that first one, so I think that's all I can report for now on the IREE side.

okay, I suppose pursuing parallelism on the JAX level would require me to get deep into rewriting the model (not my specialty).
I have spare CPU cores, so I can run multiple tasks concurrently. with dalle-mega, I can generate 1 image on-CPU in 27 mins (117% CPU usage), or submit 4 concurrent 1-image tasks (450% CPU) to get 4 images in 43 mins (11 mins per pic).
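The gain from running tasks concurrently follows directly from these numbers (27 min for one image run alone, 43 min for four concurrent 1-image tasks):

```latex
t_{\text{per image}} = \frac{43\ \text{min}}{4} = 10.75\ \text{min} \approx 11\ \text{min},
\qquad
\frac{27\ \text{min}}{10.75\ \text{min}} \approx 2.5\times \text{ throughput}
```

So roughly 3.8x the CPU usage (450% vs 117%) bought about 2.5x the throughput.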
a good improvement, but would be so nice to get the GPU involved too!

stellaraccident commented 2 years ago

I'm traveling and a bit slow to respond on details, but I think there may be an M1-specific bug in IREE with respect to parallelism when running on the CPU (most of this thread seems to be about GPU, so I'm having a bit of a hard time following exactly what is being attempted). We had a lot of trouble with cpuid on M1, since it relies too much on compile-time rather than runtime heuristics for reporting its information, and universal builds break all kinds of assumptions. I think the result may be that the CPU thread pool is being set up by default as if there were only one core. There is a flag to override this, but I don't have access to my M1 while traveling and can't look into this right now. In general, IREE on CPU should be threading heavily for both intra- and inter-op parallelism.

Would you mind filing an issue on the IREE side for the CPU parallelism and referencing this one? There are others beyond me who may be able to look at this, but they aren't monitoring JAX issues.

stellaraccident commented 2 years ago

I am pretty sure that M1 universal builds of IREE are taking the fallback path for CPU topology detection: https://github.com/google/iree/blob/be167a62f8872597eac1b72e26b4c62e291bfd5c/runtime/src/iree/task/topology_cpuinfo.c#L32

This is defaulting to one physical core. This was done to bring up M1 at all and we need to fix this.

Birch-san commented 2 years ago

sure, I'll file an issue about CPU parallelism too. thanks for your theories. 🙂

stellaraccident commented 2 years ago

Since you are using the python API, you can set runtime flags like this: https://github.com/google/iree/blob/be167a62f8872597eac1b72e26b4c62e291bfd5c/runtime/bindings/python/tests/flags_test.py#L20

The specific flag is --task_topology_group_count=8 (https://github.com/google/iree/blob/b95c520c4c18c24abee62bc86288eeec37e53ae2/runtime/src/iree/task/api.c#L50)
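Rather than hardcoding 8, the group count could be derived from the machine. The sketch below assumes that `sysctl -n hw.perflevel0.physicalcpu` reports the performance-core count on Apple Silicon (macOS 12 and later) and falls back to `os.cpu_count()` everywhere else. Whether IREE's task system performs best with one group per performance core is also an assumption, not something established in this thread.

```python
import os
import subprocess


def performance_core_count():
    """Best-effort count of performance cores; falls back to os.cpu_count()."""
    try:
        # On Apple Silicon macOS, perflevel0 is the performance-core cluster.
        out = subprocess.run(
            ["sysctl", "-n", "hw.perflevel0.physicalcpu"],
            capture_output=True, text=True, check=True,
        ).stdout
        return int(out.strip())
    except (OSError, subprocess.CalledProcessError, ValueError):
        # No sysctl binary, unknown key (Linux, Intel Macs), or unparsable output.
        return os.cpu_count() or 1


def task_topology_flag(group_count=None):
    """Format the IREE runtime flag, defaulting to the performance-core count."""
    n = group_count if group_count is not None else performance_core_count()
    return f"--task_topology_group_count={n}"


if __name__ == "__main__":
    flag = task_topology_flag()
    print(flag)
    # With the IREE Python runtime, as in the linked flags_test.py:
    #   from iree import runtime as rt
    #   rt.flags.parse_flags(flag)
```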

Birch-san commented 2 years ago

@stellaraccident I've raised https://github.com/google/iree/issues/9368 to report the "1 CPU core" problem.

thanks for explaining how to pass flags.

I tried placing this at the start of app.py:

from iree import runtime as rt
rt.flags.parse_flags("--task_topology_group_count=8")

but it didn't make any difference; CPU usage for generating 1 image in dalle-mini was 125% either way, and took 139secs either way.

is there a breakpoint I can place anywhere in the iree python to determine what group count it ends up using? or a Python expression I can add to look it up?

or is there a distribution of iree with debug symbols, so I could check what's returned by the functions in the native code?

hawkinsp commented 2 years ago

@Birch-san Can you confirm that you're actually running the IREE backend (JAX_PLATFORMS=iree or similar) when measuring performance? Note you will probably hit the mhlo.scatter problem you reported above.
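A quick way to check which backend a script is really using is to print it at startup. This is a minimal sketch using `jax.default_backend()` and `jax.devices()`; the thread doesn't show what the IREE backend of that era reported as its platform name, so don't rely on any particular output string.

```python
import os

import jax


def describe_backend():
    """Return (requested platforms, backend JAX actually uses, device list)."""
    requested = os.environ.get("JAX_PLATFORMS", "<unset>")
    actual = jax.default_backend()
    devices = [str(d) for d in jax.devices()]
    return requested, actual, devices


if __name__ == "__main__":
    requested, actual, devices = describe_backend()
    print(f"JAX_PLATFORMS={requested} -> default backend: {actual}, devices: {devices}")
```

Printing this before any benchmark makes a silent fallback to the legacy CPU path easy to spot.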

Birch-san commented 2 years ago

@hawkinsp oh, that's embarrassing.
you're right -- I was not using JAX_PLATFORMS=iree for my on-CPU test. and as you say: if I do turn it on, it explodes with the mhlo.scatter complaint.

presumably that explains why --task_topology_group_count=8 had no effect.

@stellaraccident I guess I'll need to update the iree issue. it sounds like it's still an issue that ought to exist, but I think my repro via dalle-playground is not valid, and the detail about task_topology_group_count not working is also invalid.

stellaraccident commented 2 years ago

We're happy to take the issue: this should be more automatic.

But this is what we need to focus on fixing to unblock: https://github.com/google/iree/issues/9361

Birch-san commented 2 years ago

I tried the proposed patch for https://github.com/google/iree/issues/9361. that gets me onto the next error with using the IREE backend to run dalle-playground: https://github.com/google/jax/issues/11166 (MLIR translation rule for primitive 'remat_call' not found for platform 'iree')

Robokan commented 1 year ago

Curious what the current status of this is. If I install JAX on a Mac M1 via:

pip install jax

Is it running on the GPU by default? Or do I need to do something to enable it?

leonard-gleyzer commented 1 year ago

Am also curious about this. JAX M1-series GPU support would be incredibly helpful.

stellaraccident commented 1 year ago

We're working on a better layered plugin architecture for Jax, and plugging iree in more directly is pushing that forward. This will make it easier for third parties to extend real support in this direction.

Despite early results, I do suspect that a real/solid/performant backend for the M1 GPU is quite a bit of work, and my goal is to enable the community to service this area of need based on some starting points that work. It would also be ideal to involve Apple themselves.

Stand by for a better plugin.

nicholasjng commented 1 year ago

I think the decision to pull XLA out of the TensorFlow repo might spark some development in that regard. I have been trying to get into IREE and XLA after work for some time now, but the complexity is very high, so there is a fairly substantial barrier to entry for a layperson like me.

There was some documentation in TF XLA on how to implement a new backend, but it was rather abstract. Maybe there could be a more in-depth documentation walkthrough on how to add new backends (or even frontends) in the future? I saw that MLIR has done some work in that regard with its Toy dialect tutorial. I think it would be a valuable thing to add, and I would love to get involved in it as well if that's feasible!

stellaraccident commented 1 year ago

Good feedback, thank you. We're rolling up the sidewalks here at the Google offices for the year, so to set expectations, this is likely a 2023 discussion. If there is demand, it would be good to get this on the OpenXLA calendar to talk through. That could be a good forum to connect more of the dots you are seeking and get more things fleshed out publicly.

stellaraccident commented 1 year ago

@skye is there any chance of releasing more of the documentation/discussion about the plugin work you are doing? Maybe connect with Thea on the OpenXLA side to create a space for this kind of discussion?

philipturner commented 1 year ago

Just to note, the CPU on Apple silicon is very powerful. The AMX blocks provide up to 2 TFLOPS of single-precision matmul power on the M1 Max, and 1 TFLOPS on the base M1. Compare that to 8 TFLOPS for the M1 Max GPU and 2 TFLOPS for the base M1 GPU. If you don't fully optimize matrix multiplications for the GPU, it will be as slow as or slower than the CPU, and then there is no point in GPU acceleration.

The Apple GPU architecture, by itself, is bad at matrix multiplications. A generic matmul library called DLPrimitives reached 25% ALU utilization, and pre-A14 MPSMatrixMultiplication reached about the same. Compare that to A14/M1 and later, where MPSMatrixMultiplication jumped to 80% ALU utilization because Apple added a special intrinsic called simdgroup_matrix, which allows fast matrix multiplications. We would need to expose this intrinsic to the SPIR-V backend compiler. Otherwise, 25% times 2.6 TFLOPS = 0.65 TFLOPS, which is less than the 1.0 TFLOPS from the M1 CPU.
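The break-even argument reduces to a short calculation using the comment's own base-M1 numbers (1.0 TFLOPS from the CPU's AMX units, 2.6 TFLOPS peak FP32 on the GPU). It also gives the roughly 38% utilization threshold that comes up later in the thread:

```latex
0.25 \times 2.6\ \text{TFLOPS} = 0.65\ \text{TFLOPS} < 1.0\ \text{TFLOPS},
\qquad
0.80 \times 2.6\ \text{TFLOPS} = 2.08\ \text{TFLOPS} > 1.0\ \text{TFLOPS},
\qquad
u_{\text{break-even}} = \frac{1.0\ \text{TFLOPS}}{2.6\ \text{TFLOPS}} \approx 0.385
```

In other words, a GPU matmul has to sustain roughly 38% ALU utilization just to tie the CPU.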

> It would also be ideal to involve Apple themselves.

From personal experience, I would not count on Apple helping us out. I also think this effort is mostly wasted anyway, because all the backend porting will become obsolete soon. It might be good to read over Modular AI's website, then wait to integrate their platform into JAX in the future.

stellaraccident commented 1 year ago

This is what I was alluding to: without exposing the intrinsics and using them, there is really no path to great performance. I haven't checked recently, but I generally assume that such things are inaccessible unless the vendor has taken steps to make them accessible or has worked jointly on the solution.

Well aware of modular. If they produce something better on the dimensions we care about, you certainly won't see me pushing back on leveraging it. I'm reserving my own judgment until there is more than marketing to evaluate.

philipturner commented 1 year ago

Your opinion about Modular does make sense. Also, once you expose the intrinsics, it will be difficult to ensure they're always used properly to reach >38% ALU utilization (1.0/2.6). You'd basically be rewriting MPSGraph from scratch, including Winograd convolutions and the like. That's a massive time investment.

stellaraccident commented 1 year ago

It's sometimes not that bad, but it is very hard to chart a path to great on that stuff without the hardware maker's involvement. These things are always hard to get right, and that is when you have a great datasheet, docs, etc.

philipturner commented 1 year ago

> on that stuff without the hardware maker's involvement.

And this is exactly why so many AI and HPC applications are vendor-locked to NVIDIA. Their employees actively assist with optimizing several frameworks for CUDA. In contrast, with AMD and Apple it's somewhat more "you're on your own" (HIPify and SYCLomatic seem designed to address this). The solution isn't to get 4x as many engineers (from NVIDIA + AMD + Intel + Apple) performing 4x the optimization.

stellaraccident commented 1 year ago

We've had substantially better experiences than "4x" with community support for AMD. There have been more developments than this which will hopefully be written up soon: https://nod.ai/shark-rdna3-sd/

It took about 6 weeks for AMD and a contractor to adapt code generation solutions to get to quite compelling performance (more is being actively discussed on Discord that hasn't made it into official write-ups yet). The best experiences do come from vendor engagement, but ideally that involves last-mile porting rather than redevelopment. As an example, AMD seems to have made extensions to their driver to enable cutting-edge support to get there: https://www.amd.com/en/support/kb/release-notes/rn-rad-win-22-11-1-mlir-iree (i.e., this always seems to work better as a partnership with the hardware vendor, to ensure that the needed facilities are exposed).

philipturner commented 1 year ago

AMD has recently been catching up to NVIDIA. Has the same effort also been put forth for Intel's Arc GPUs and ARM's Mali GPUs? As a more contrived example, there are AI processor startups with very small teams. They can only afford to invest time optimizing their hardware for one of (1) IREE, (2) PyTorch, (3) TensorFlow, or (4) another less popular framework (e.g. S4TF). If I were them, I would choose PyTorch and invest the remaining person-hours into maintaining the PyTorch backend. That leaves the other frameworks untouched.

> We've had substantially better experiences than "4x" with community support for AMD

I was not thinking 4x quality, rather 4x the amount of person-hours. You shouldn’t need to hire a contractor just to make your company’s hardware run optimally, on the newest popular frontend or IR.

stellaraccident commented 1 year ago

I think this all may be further along on these axes than you are thinking, but I can really only point to publicly released statements (or go further off topic and reassemble discord threads and contributions). These classes of parts are very much in the space of enhancements to one applying to all, especially those in the same broad API family -- which covers a lot at this point.

In my mind, this really is about API family + access to the key intrinsics/extensions necessary for perf. Many of these parts are quite close in that category and get a lot for free. The further a part is on its own there, the more expensive it will be and the higher the likelihood that only the vendor can do the work to get it there. Which brings us back to the original topic: I suspect that Apple GPUs will be harder than expected to make great and will probably require movement from Apple on programmability to get all the way. I may be overestimating that, but it's my default assumption for a closed platform like this.

philipturner commented 1 year ago

I will agree to disagree; your comment does have sound reasoning. I can cite evidence that Apple’s Metal team is extremely understaffed, and can barely maintain their own documentation or respond to developers. Choosing to work on IREE may not be in their interests.

Besides Apple, I may be one of the few people with the capability to optimize this for the Apple GPU. Just check out the “Lumen & Nanite on MacOS” thread on the Unreal Engine forums, or the failed AlphaFold 2 port, to see proof of this capability. However, I have other priorities, so that leaves no one optimizing for Apple.

powderluv commented 1 year ago

This gives me a déjà vu of the discussion at

https://github.com/pytorch/pytorch/issues/47702#issuecomment-1018668822

But just to keep it to real code checked in (and not marketing hype or second-guessing companies' staffing, etc.): we just contributed Winograd convs to IREE. It took us about 6 weeks.

I hope we keep the discussion to technical implementations of the issue at hand.

philipturner commented 1 year ago

@powderluv I don't see how the comment you specifically linked is completely relevant to this thread. I do take it personally when someone exposes my own past mistakes, seemingly in order to discredit my argument. Would you mind editing the link to point to something more technical, such as this comment, or remove it entirely? I can also refrain from expressing my viewpoint any further.

I don't assume you meant anything bad, but that PyTorch thread was extremely heated and a magnet for questionable behavior. There's a lot of hidden context that isn't particularly constructive to resurface. Best regards 😄

philipturner commented 1 year ago

> second guessing companies staffing etc)

Perhaps the proper context should be given: https://developer.apple.com/forums/thread/698037?answerId=731108022#731108022

> @philipturner I think your frustration is justified. Apple's developer relations are unfortunately a mess. And it's not a criticism to the fine people who do all the hard work in the background and occasionally help us on these forums, but the obvious lack of structure and ownership in these matters. Lack of updates to the Metal Feature Set tables is just one symptom of a wide systemic problem. For example, the Metal Shading Language is very difficult to use as a reference tool due to subpar formatting and lack of hyperlinks. The API documentation is also lacklustre, incomplete and difficult to navigate. Forum communication is almost non-existent. It would be great if Apple considered creating a role dedicated to improving these aspects because it seems like this is something nobody really feels responsible for.
>
> Posted 2 months ago by jcookie

Another person noted some limitations faced in Apple-to-developer relations. I could have worded previous comments more politely, but I'm not trying to discredit Apple developers who work on Metal. My point was, these amazing people face real-world constraints. Enhancing IREE requires investing nonzero time and money, which could be spent elsewhere.

That's where my statement about Modular comes in. I am not an employee of theirs, and they do not endorse my viewpoint. They are in a very early stage, and they may not produce anything usable for several years. I suggested that Apple might view Modular, not IREE, as the solution to fragmentation in the AI industry. If they believed that, they might see contributing to IREE as effort spent in vain.

powderluv commented 1 year ago

Didn't mean to dig into anything from the past. Just want to say let's do what it takes to make JAX run well on Apple silicon. Thanks for being involved in the process.

philipturner commented 1 year ago

No worries. Thanks for being considerate.

arpan-dhatt commented 1 year ago

Instead of directly trying to produce Metal code for operations in XLA, how viable would it be to translate a jitted JAX call into a Metal Performance Shaders Graph (MPSGraph) invocation? MPSGraph has a number of XLA's features, although I'm not sure to what exact extent they overlap. This approach would have the benefit of using whichever intrinsics Apple's developers use internally.

As I understand it, there is currently no PluggableDevice support in JAX like there is in TF, which would significantly increase the work needed to see this task through.

Would this be a viable approach to get good GPU performance from JAX on Mac GPUs, ignoring how not using XLA might reduce the maximum possible performance? Are there any technical or design choices in JAX that would make such a task impossible (things like vmap, I'm not entirely sure)?
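For a sense of what such a translator would have to consume: a jitted function can be lowered without compiling it, and the resulting MLIR module (mHLO in JAX releases of this era, StableHLO in later ones) lists the ops an MPSGraph translation would need to cover. On the vmap question, `vmap` is applied while tracing, so by the time a backend sees the program it has already been rewritten into ordinary batched ops. This is a sketch assuming a JAX version that provides `jax.jit(...).lower(...)` and `Lowered.as_text()`; the toy `layer` function is made up.

```python
import jax
import jax.numpy as jnp


def layer(w, x):
    # One dense layer with a ReLU.
    return jax.nn.relu(x @ w)


w = jnp.ones((4, 4), dtype=jnp.float32)
xs = jnp.ones((3, 2, 4), dtype=jnp.float32)  # a batch of 3 inputs of shape (2, 4)

# vmap is resolved at trace time, so the lowered module contains ordinary
# batched ops rather than a vmap primitive a backend would need to support.
batched = jax.jit(jax.vmap(layer, in_axes=(None, 0)))

# Lower (but do not compile) and dump the MLIR module a backend would receive.
text = batched.lower(w, xs).as_text()
print(text)
```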