This integrates the PjRt plugin from jax-metal for running on the Apple GPU. To test it, set `client: :mps` on the EXLA backend/compiler. Since the plugin is loaded as a separate dynamic library, it can be tested without any changes to XLA (just make sure to remove the `cache/` directory).

Certain computations can already be run, but the plugin is still very much incomplete. This PR is a space for experimentation and is meant to track the plugin's progress. I reported a number of issues upstream; comments in the code point to them. In a few places I applied workarounds, either as temporary solutions or simply to avoid VM crashes (segfaults); these are marked with a TODO.
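A minimal sketch of selecting the plugin from an IEx session, assuming EXLA's usual `:client` option on `EXLA.Backend` and on the `EXLA` compiler, with `:mps` being the client name introduced by this PR. Both settings apply to the current process only.

```elixir
# Tensors created from now on (in this process) are allocated through the :mps client.
Nx.default_backend({EXLA.Backend, client: :mps})

# defn/jit computations in this process are compiled by EXLA for the :mps client.
Nx.Defn.default_options(compiler: EXLA, client: :mps)

# Quick smoke test: a small jitted computation running on the Apple GPU.
add_one = Nx.Defn.jit(fn x -> Nx.add(x, 1) end)
add_one.(Nx.tensor([1.0, 2.0, 3.0]))
```

Given the state of the plugin, a failure in a computation like this is more likely to come from one of the upstream issues listed below than from the configuration itself.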
Issues
For tracking purposes, here is a list of the Metal plugin issues reported upstream:
Crucial
Not implemented
Edge cases
All issues: link.
Note: this PR is against the jk-s32 branch, which changes the default integer precision to 32 bits. This is a planned change (#1491), but it has not been integrated yet to avoid conflicts with other work in progress.