elixir-nx / nx

Multi-dimensional arrays (tensors) and numerical definitions for Elixir

(Experimental) Integrate Metal PjRt plugin #1504

Open jonatanklosko opened 5 months ago

jonatanklosko commented 5 months ago

This integrates the PjRt plugin from jax-metal for running on the Apple GPU. To test it, set client: :mps on the EXLA backend or compiler. Since the plugin is loaded as a separate dynamic library, it can be tested without any changes to XLA (just make sure to remove the cache/ directory first).
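For reference, selecting the client could look roughly like the sketch below. It assumes EXLA is built from this branch on an Apple Silicon machine, so that the :mps client is available as described above. Nx.default_backend/1 and Nx.Defn.default_options/1 are the usual Nx entry points; the tensor values are only an illustration.

```elixir
# Sketch only: assumes EXLA from this PR branch, where the :mps client exists.

# Use the Metal client for eager operations through the EXLA backend.
Nx.default_backend({EXLA.Backend, client: :mps})

# Use the Metal client for defn compilation as well.
Nx.Defn.default_options(compiler: EXLA, client: :mps)

# Illustrative computation; which operations work depends on the
# current state of the plugin.
t = Nx.tensor([1.0, 2.0, 3.0])
t |> Nx.add(1.0) |> IO.inspect()
```

Both options are process-local defaults, so they can be tried in an IEx session without touching application config.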

Certain computations can already be run, but the plugin is still very much incomplete. This PR is a place for experimentation and is meant to track the plugin's progress. I reported a number of issues upstream; comments in the code point to them. In a few places I applied workarounds, either as temporary solutions or just to avoid VM crashes (segfaults); these are marked with a TODO.

Issues

For tracking purposes, here is a list of the Metal plugin issues reported upstream:

Crucial

Not implemented

Edge cases

All issues: link.


Note: this PR targets the jk-s32 branch, which changes the default integer precision to 32 bits. That change is planned (#1491), but it has not been merged yet to avoid conflicts with other work in progress.
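Code that relies on 64-bit integers can pin the type explicitly so that it behaves the same with either default. A minimal sketch with illustrative values:

```elixir
# Integer tensors created without an explicit type use the default integer
# precision, which is what the jk-s32 branch changes.
default = Nx.tensor([1, 2, 3])
IO.inspect(Nx.type(default))

# Passing :type pins the precision regardless of the default.
pinned = Nx.tensor([1, 2, 3], type: {:s, 64})
IO.inspect(Nx.type(pinned))
# prints {:s, 64}
```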