graphcore-research / jax-experimental

JAX for Graphcore IPU (experimental)
https://github.com/graphcore-research/jax-experimental#readme
Apache License 2.0
ai deep-learning graphcore ipu jax machine-learning

🔴 Non-official experimental 🔴 JAX on Graphcore IPU



🔴 ⚠️ Non-official experimental ⚠️ 🔴

This is a very thin fork of http://github.com/google/jax that adds support for Graphcore IPUs. It is provided by Graphcore Research for experimentation purposes only. It is not intended for production use (inference or training).

Features and limitations of experimental JAX on IPUs

The following features are supported:

Known limitations of the project:

This is a research project, not an official Graphcore product. Expect bugs and sharp edges. Please help by trying it out, reporting bugs, and letting us know what you think!

Installation

The experimental JAX wheels require Ubuntu 20.04, Graphcore Poplar SDK 3.1 or 3.2, and Python 3.8. For SDK 3.1, they can be installed as follows:

pip install jax==0.3.16+ipu jaxlib==0.3.15+ipu.sdk310 -f https://graphcore-research.github.io/jax-experimental/wheels.html

For SDK 3.2, use jaxlib==0.3.15+ipu.sdk320 instead.
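The SDK-to-wheel mapping above can be captured in a small helper. This is a minimal sketch, and `pip_install_command` is a hypothetical name rather than part of the project. The version strings and wheel index URL are copied from the install instructions, and only SDK 3.1 and 3.2 are covered because those are the only versions listed.

```python
# Hypothetical helper (not part of jax-experimental): build the pip command
# for a given Poplar SDK version, using the versions from the install guide.
JAX_VERSION = "0.3.16+ipu"
JAXLIB_VERSIONS = {
    "3.1": "0.3.15+ipu.sdk310",
    "3.2": "0.3.15+ipu.sdk320",
}
WHEELS_URL = "https://graphcore-research.github.io/jax-experimental/wheels.html"


def pip_install_command(sdk_version: str) -> str:
    """Return the pip install command for a Poplar SDK version such as "3.2"."""
    try:
        jaxlib_version = JAXLIB_VERSIONS[sdk_version]
    except KeyError:
        raise ValueError(
            f"Unsupported Poplar SDK {sdk_version!r}; "
            f"expected one of {sorted(JAXLIB_VERSIONS)}"
        ) from None
    return f"pip install jax=={JAX_VERSION} jaxlib=={jaxlib_version} -f {WHEELS_URL}"


if __name__ == "__main__":
    print(pip_install_command("3.2"))
```

Raising on unknown SDK versions is deliberate. Guessing a wheel tag for an unlisted SDK would most likely produce a pip resolution error later.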

Minimal example

The following example can be run on a Graphcore IPU machine on Paperspace, or on a non-IPU machine using the IPU model emulator:

from functools import partial
import jax
import numpy as np

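# Compile the function for the IPU backend.
@partial(jax.jit, backend="ipu")
def ipu_function(data):
    return data**2 + 1

data = np.array([1, -2, 3], np.float32)
# The first call compiles the function; the result is stored on the IPU.
output = ipu_function(data)
# Print the result and the device that holds it.
print(output, output.device())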
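When testing on real hardware or on the emulator, a simple sanity check is to compare the IPU result against the same computation on the host. The sketch below uses plain NumPy only, so it runs without the IPU wheels. The name `reference_function` is illustrative, not part of the project.

```python
import numpy as np


def reference_function(data: np.ndarray) -> np.ndarray:
    # Same computation as ipu_function in the example above, evaluated on the host.
    return data**2 + 1


data = np.array([1, -2, 3], np.float32)
expected = reference_function(data)
print(expected)  # [ 2.  5. 10.]

# With the IPU wheels installed, the device result could be checked with:
# np.testing.assert_allclose(np.asarray(ipu_function(data)), expected)
```

The host reference keeps the float32 dtype of the input, so it is compared against the IPU output at the same precision.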

JAX on IPU Paperspace notebooks

Additional JAX on IPU examples:

Useful JAX backend flags:

As is standard in JAX, these flags can be set on the config object imported with from jax.config import config.

Flag | Description
config.FLAGS.jax_platform_name = 'ipu' / 'cpu' | Set the default JAX backend. Useful for initializing on CPU.
config.FLAGS.jax_ipu_use_model = True | Use the IPU model emulator.
config.FLAGS.jax_ipu_model_num_tiles = 8 | Set the number of tiles in the IPU model.
config.FLAGS.jax_ipu_device_count = 2 | Set the number of IPUs visible to JAX (up to the number of local IPUs available).
config.FLAGS.jax_ipu_visible_devices = '0,1' | Select the specific set of local IPUs visible to JAX.

Alternatively, as with other JAX flags, these can be set through environment variables (e.g. JAX_IPU_USE_MODEL, JAX_IPU_MODEL_NUM_TILES, ...).
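The environment-variable route can be scripted when launching IPU jobs. This is a minimal sketch under two assumptions:

- Each flag maps to its upper-cased name (e.g. jax_ipu_use_model to JAX_IPU_USE_MODEL), following the two examples given above.
- Booleans are accepted as "true"/"false", as standard JAX boolean flags are.

The helper `flags_to_env` and the script name `minimal_example.py` are hypothetical.

```python
def flags_to_env(flags: dict) -> dict:
    """Map JAX flag names to environment variables (hypothetical helper).

    Assumes the upper-cased naming convention, e.g.
    jax_ipu_use_model -> JAX_IPU_USE_MODEL. Booleans are written as
    "true"/"false", which is an assumption about accepted values.
    """
    env = {}
    for name, value in flags.items():
        # Check bool before anything else: bool is a subclass of int.
        if isinstance(value, bool):
            value = "true" if value else "false"
        env[name.upper()] = str(value)
    return env


if __name__ == "__main__":
    ipu_env = flags_to_env({
        "jax_ipu_use_model": True,     # run on the IPU model emulator
        "jax_ipu_model_num_tiles": 8,  # number of tiles in the IPU model
    })
    # Print a shell-style command line for a hypothetical script name.
    print(" ".join(f"{k}={v}" for k, v in ipu_env.items()), "python minimal_example.py")
```

Flag values read from the environment are typically picked up when JAX is imported. For that reason, the variables are best set in the launching shell or passed to a fresh process (e.g. subprocess.run(..., env={**os.environ, **ipu_env})) rather than changed after import.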

Useful PopVision environment variables:

Documentation

License

The project remains licensed under the Apache License 2.0, with the following files unchanged:

The additional dependencies introduced for Graphcore IPU support are: