NOTE: This repository is currently under construction. Please stand by for this message to be removed before considering it usable.
The IREE JAX API provides a compiler and runtime bridge between JAX and IREE. It allows programs to be extracted from JAX and compiled for deployment with IREE, without a Python or JAX dependency.
In order to understand how this works, it is important to have a working knowledge of what JAX and IREE programs are (which we discuss after the example).
In order to introduce the concepts, we present a sample program which directly compiles to IREE. It doesn't do anything interesting but is a place to write comments and explanations.
```python
# Everything that is needed for basic use can be accessed from the Program
# class. Some additional imports may be needed for more advanced cases.
from iree.jax import Program
import jax.numpy as jnp
from collections import namedtuple

# Host-side arrays and trees of arrays can be mirrored into an IREE Program.
# All structure and any aliasing between trees is retained in the program.
x = jnp.ones((3, 4), jnp.float32) * 4.0
b = jnp.ones((3, 4), jnp.float32)
Params = namedtuple("Params", "x,b")
params = Params(x, b)


class TrivialKernel(Program):
    """By subclassing `iree.jax.Program`, you create an IREE program which can
    be compiled (and executed in the host process) by instantiating it.

    The program's name is by default the snake_case equivalent of the class
    name with any "Program" suffix removed. It can be specified explicitly via:

        class MyProgram(Program, export_name="foobar"):
    """

    # Globals are created by giving names in the class to arrays and trees of
    # arrays from the host Python session. When accessed on an instance, they
    # retain their structure but access values in the exported program.
    _params = params

    # The above is the default, sugared form of the more general mechanism
    # for exporting a global. Using the general form allows you to create
    # uninitialized globals or annotate them as immutable (the compiler will
    # generally figure this out for you).
    # _params = Program.export_global(params, initialize=True, mutable=True)

    # Sometimes it is useful to give a name to an aliased part of a tree. This
    # is fully allowed. The first statement in the class which exports any
    # particular leaf will define its characteristics
    # (initialization/mutability/name). Subsequent exports will create aliases,
    # just like in the original Python session.
    _x = params.x

    # Any function defined without annotation becomes a public function in the
    # IREE program with an input signature derived from the declaration and
    # outputs derived by tracing. We call these "Program Procedures".
    # This function just provides an accessor to get the tree of _params.
    def get_params(self):
        return self._params

    # Here we see how to specify an input signature. The `Program.like` helper
    # just produces a tree of `jax.core.AbstractValue` representing some
    # reference from the host program. This is an easy way to represent such
    # details, but users are also free to construct AbstractValue trees
    # themselves.
    def run(self, multiplier=Program.like(x)):
        # When tracing a public function, the primary thing you can do is call
        # immutable kernel functions using combinations of function arguments
        # and program globals.
        result = self._linear(multiplier, self._params.x, self._params.b)
        # Program globals can be updated directly via assignment. This is a
        # sugared form of the more general `Program.store_global()` helper.
        self._x = result
        return result

    # Public functions can also accept arbitrary trees.
    def set_params(self, new_params=Program.like(params)):
        # And trees of globals can be updated in one assignment.
        self._params = new_params

    # "Kernel" functions are basically equivalent to `jax.jit` but specially
    # wrapped so that they can exist as members of a Program. They act like
    # staticmethods (i.e. they do not take a `self`) and they can only operate
    # on arguments or host process state that is legal for `jax.jit`. Think of
    # them as private static functions of the program, and by convention we
    # name them with a leading underscore.
    @Program.kernel
    def _linear(m, x, b):
        return m * x + b


# Instantiating a program will trace it and invoke the `ireec` compiler on it.
# Keyword arguments control compilation behavior. The defaults should be
# reasonable for running directly on the host.
m = TrivialKernel()

# You can inspect the MLIR module which was extracted.
print(Program.get_mlir_module(m))

# While the primary purpose of extracting a program is to run it *elsewhere*,
# you can't spell "ML" without "debugging", and instantiated modules are fully
# usable interactively. Under the hood, the compiled artifact is loaded into
# the IREE runtime and wrappers are created to dispatch to named public
# functions.
print("Initial params:", m.get_params())

# Stateful updates can be made.
update = jnp.ones_like(x)
print("Run:", m.run(update))

# And inspected.
print("Updated params:", m.get_params())

# You can save off the compiled artifact to run elsewhere.
Program.get_compiled_artifact(m).save("/tmp/some_file.vmfb")
```
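The comments in the example state that structure and aliasing are retained when trees are exported, and that the first export of a leaf defines it while later exports of the same leaf become aliases. As a rough mental model only, the stdlib sketch below shows that bookkeeping. `export_tree` and `FakeArray` are hypothetical names, not part of `iree.jax`, and it assumes that object identity (`id`) is what makes two exports aliases.

```python
from collections import namedtuple


class FakeArray:
    """Stand-in for a host array; only object identity matters here."""

    def __init__(self, value):
        self.value = value


def export_tree(name, tree, registry):
    """Return the global names for the leaves of `tree`, in flattening order.

    Hypothetical sketch: `registry` maps id(leaf) -> (global_name, leaf). The
    first export of a leaf object defines its name; any later export of the
    same object reuses that name, i.e. it becomes an alias.
    """
    if isinstance(tree, tuple) and hasattr(tree, "_fields"):  # namedtuple
        children = [(field, getattr(tree, field)) for field in tree._fields]
    elif isinstance(tree, (tuple, list)):
        children = list(enumerate(tree))
    elif isinstance(tree, dict):
        children = sorted(tree.items())  # assumes mutually comparable keys
    else:  # anything else is treated as a leaf
        if id(tree) not in registry:
            # Keep a reference to the leaf so its id() stays unique.
            registry[id(tree)] = (name, tree)
        return [registry[id(tree)][0]]

    names = []
    for key, child in children:
        names.extend(export_tree(f"{name}.{key}", child, registry))
    return names


Params = namedtuple("Params", "x,b")
params = Params(FakeArray(4.0), FakeArray(1.0))

registry = {}
print(export_tree("_params", params, registry))  # ['_params.x', '_params.b']
print(export_tree("_x", params.x, registry))     # ['_params.x'] (an alias)
```

Mirroring the example, exporting `params.x` a second time as `_x` resolves to the name created by the first export instead of creating a new global.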
The challenge with talking about extracting a JAX program is that in all generality, a JAX program is bounded only by what can run in the host Python interpreter. What we are looking for is a simpler definition that allows a useful set of standalone programs to be constructed and that meshes well with the programming model employed by typical JAX developers.
The components that are the most interesting towards this end are:

* `jax.jit` functions: stateless functions mapping inputs to outputs, natively represented as JAXPR IR and universally convertible to MHLO IR.

There are of course many more details than this, but these are the components we will assemble.
An IREE program is:

* A module compiled by the `ireec` compiler, targeting a number of architectures and devices.
* Serialized as a `vmfb` (VM Flatbuffer) for direct execution by the IREE runtime.

As mentioned above, "Program Procedures" are the public functions that can be invoked on a compiled Program. Today, they are implemented with a limited tracer that defines each procedure by running it (define-by-run). A Python-based compiler is under development to allow more sophisticated procedures to be represented.
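Define-by-run tracing can be pictured with a toy recorder: the procedure's Python code runs once with symbolic inputs, and a stand-in `self` records global loads, kernel calls, and global stores instead of computing anything. This is a hypothetical stdlib-only sketch, not the `iree.jax` tracer; `Traced`, `TracingSelf`, and the op tuples are invented for illustration, and tree navigation is omitted by flattening `_params` into `_x` and `_b`.

```python
# Hypothetical toy tracer: NOT the iree.jax implementation. It records the
# operations a procedure performs when run once with symbolic inputs.


class Traced:
    """Symbolic placeholder for a value produced while tracing."""

    def __init__(self, name):
        self.name = name


class TracingSelf:
    """Stand-in for `self` while a procedure is traced.

    Reading a global emits a load, calling a kernel emits a call, and
    assigning to a global emits a store. Anything else is rejected.
    """

    def __init__(self, globals_, kernels):
        # Bypass our own __setattr__, which only accepts program globals.
        object.__setattr__(self, "_globals", set(globals_))
        object.__setattr__(self, "_kernels", set(kernels))
        object.__setattr__(self, "ops", [])

    def _emit(self, op, target, args):
        result = Traced(f"%{len(self.ops)}")
        self.ops.append((result.name, op, target, [a.name for a in args]))
        return result

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails.
        if name in self._globals:
            return self._emit("load_global", name, [])
        if name in self._kernels:
            return lambda *args: self._emit("call", name, args)
        raise AttributeError(f"{name!r} is not a program global or kernel")

    def __setattr__(self, name, value):
        if name not in self._globals:
            raise AttributeError(f"cannot store to non-global {name!r}")
        self.ops.append((None, "store_global", name, [value.name]))


# A procedure shaped like `run` in the example (globals flattened).
def run(self, multiplier):
    result = self._linear(multiplier, self._x, self._b)
    self._x = result
    return result


tracer = TracingSelf(globals_=["_x", "_b"], kernels=["_linear"])
run(tracer, Traced("%arg0"))
for op in tracer.ops:
    print(op)
# ('%0', 'load_global', '_x', [])
# ('%1', 'load_global', '_b', [])
# ('%2', 'call', '_linear', ['%arg0', '%0', '%1'])
# (None, 'store_global', '_x', ['%2'])
```

Because the procedure's Python code simply runs, any Python `if` or loop inside it is resolved at trace time, and only the path actually taken ends up in the recording.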
The following are allowed in a traced procedure (this is restricted in order to maintain a small surface area for a future compiler-based approach; in reality, there are many ways to hack the current system to do more):

* Accessing program globals and kernels through `self`.
* Navigating trees via attribute access (e.g. `parent.child`) or indexing (e.g. `parent[0]`).
* Updating globals by assignment or with `Program.store_global`.
* `Program.print` (TODO).

These are WIP instructions for getting a functioning development setup. We aim to improve this and make it more turnkey over time.
Pip-installable releases are not yet available. However, this project is pure Python and can be installed locally for development (note that this pulls from IREE pre-release snapshots):

```shell
python -m pip install -e '.[test,xla,cpu]' -f https://openxla.github.io/iree/pip-release-links.html
```
Note that, in order to function, the versions of MLIR and MHLO used in the installed jaxlib must be syntax-compatible with the versions used in IREE. For releases we synchronize these, but during development they can drift and cause errors.

The easiest way to ensure this is to pin the TensorFlow version that JAX builds against to the version IREE was compiled with, and then follow the JAX instructions to build jaxlib locally. See:

You may need to manually `pip uninstall` the automatically installed jaxlib.
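Before and after rebuilding jaxlib, it can help to see which of the relevant distributions are installed and at which versions. The hypothetical helper below uses only `importlib.metadata`; the distribution names are taken from this README and may need adjusting for your setup. Matching version strings alone do not prove MLIR/MHLO syntax compatibility; they only show what is installed.

```python
from importlib import metadata

# Distribution names mentioned in this README; adjust if your setup differs.
PACKAGES = ["jax", "jaxlib", "iree-compiler-snapshot", "iree-runtime-snapshot"]


def installed_versions(names=PACKAGES):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```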
To run the tests:

```shell
lit tests/
```
Sometimes you will want to build with a local IREE. From IREE's build directory:

```shell
source .env && export PYTHONPATH
```
You may need to `pip uninstall` the automatically installed `iree-compiler-snapshot` and `iree-runtime-snapshot` packages.
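After removing the snapshot packages, one way to confirm that Python resolves `iree` from your local build (via `PYTHONPATH`) rather than from site-packages is to ask the import system where it would load the package from. `module_locations` is a hypothetical helper built on `importlib.util.find_spec`; it is a diagnostic sketch, not part of this project.

```python
import importlib.util


def module_locations(name):
    """Return where Python would import `name` from, or None if not found.

    For a package (including namespace packages), this is the list of
    directories searched for submodules; for a plain module it is a
    one-element list with its origin. Looking up a dotted name such as
    "iree.compiler" imports the parent package first.
    """
    try:
        spec = importlib.util.find_spec(name)
    except ModuleNotFoundError:  # a parent package in a dotted name is missing
        return None
    if spec is None:
        return None
    if spec.submodule_search_locations is not None:
        return list(spec.submodule_search_locations)
    return [spec.origin]


if __name__ == "__main__":
    # If the local build is picked up, these paths are expected to be under
    # IREE's build directory rather than site-packages.
    print(module_locations("iree"))
```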
For IDE integration, you may just want to copy IREE's `.env` file to the root of this repo if working in this mode.