:red_circle: :warning: Non-official experimental :warning: :red_circle:
This is a very thin fork of http://github.com/google/jax for Graphcore IPUs. This package is provided by Graphcore Research for experimentation purposes only and is not intended for production use (inference or training).
The following features are supported: pmap and (experimental) pjit.
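A minimal pmap sketch follows. It assumes that the IPU build accepts backend="ipu" in jax.pmap in the same way the quickstart below passes it to jax.jit, and that jax.lax.psum is supported across IPUs. The backend check and the CPU fallback are additions so that the snippet also runs on a stock JAX install. The names pick_backend and normalize are hypothetical.

```python
from functools import partial

import jax
import numpy as np


def pick_backend():
    """Return "ipu" if this JAX build exposes IPU devices, else "cpu"."""
    try:
        jax.devices("ipu")
        return "ipu"
    except RuntimeError:  # stock JAX: unknown backend "ipu"
        return "cpu"


BACKEND = pick_backend()
n = jax.device_count(BACKEND)  # number of devices pmap maps over


# Each device receives one element of x; psum adds the elements across devices.
@partial(jax.pmap, axis_name="devices", backend=BACKEND)
def normalize(x):
    return x / jax.lax.psum(x, axis_name="devices")


x = np.arange(1, n + 1, dtype=np.float32)  # leading axis size must equal n
out = normalize(x)
print(BACKEND, out)  # the n outputs sum to 1
```

With two IPUs visible, each device divides its own element by the global sum 1 + 2, so the outputs are 1/3 and 2/3.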
Known limitations of the project:
This is a research project, not an official Graphcore product. Expect bugs and sharp edges. Please help by trying it out, reporting bugs, and letting us know what you think!
The experimental JAX wheels require Ubuntu 20.04, Graphcore Poplar SDK 3.1 or 3.2, and Python 3.8. They can be installed as follows:
pip install jax==0.3.16+ipu jaxlib==0.3.15+ipu.sdk310 -f https://graphcore-research.github.io/jax-experimental/wheels.html
For SDK 3.2, change the jaxlib version to jaxlib==0.3.15+ipu.sdk320.
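After installation, a quick sanity check is to ask JAX which IPU devices it can see. This sketch assumes that the IPU wheels register a backend named "ipu", as the backend="ipu" argument in the quickstart suggests. The helper name ipu_devices is hypothetical. An empty list means that the "ipu" backend is not available in the current environment.

```python
import jax


def ipu_devices():
    """List IPU devices, or [] if this JAX build has no usable "ipu" backend."""
    try:
        return jax.devices("ipu")
    except RuntimeError:  # e.g. stock jaxlib without IPU support
        return []


devices = ipu_devices()
print("jax", jax.__version__)
print(len(devices), "IPU device(s):", devices)
```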
The following example can be run on a Graphcore IPU Paperspace instance (or on a non-IPU machine using the IPU model emulator):
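from functools import partial
import jax
import numpy as np

# Compile the function with XLA for the IPU backend.
@partial(jax.jit, backend="ipu")
def ipu_function(data):
    return data**2 + 1

data = np.array([1, -2, 3], np.float32)
output = ipu_function(data)
# Prints the result (values 2, 5, 10) and the IPU device holding it.
print(output, output.device())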
pmap on IPUs quickstart
pjit on IPUs quickstart

Additional JAX on IPU examples:
pmap
Useful JAX backend flags:
As is standard in JAX, these flags can be set on the config object imported with from jax.config import config.
Flag | Description |
---|---|
config.FLAGS.jax_platform_name = 'ipu'/'cpu' | Configure the default JAX backend. Useful for CPU initialization. |
config.FLAGS.jax_ipu_use_model = True | Use the IPU model emulator. |
config.FLAGS.jax_ipu_model_num_tiles = 8 | Set the number of tiles in the IPU model. |
config.FLAGS.jax_ipu_device_count = 2 | Set the number of IPUs visible in JAX. Can be any local IPU available. |
config.FLAGS.jax_ipu_visible_devices = '0,1' | Set the specific collection of local IPUs to be visible in JAX. |
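One reading of "Useful for CPU initialization" is this: make CPU the default backend so that initialization code runs on the host, then move the results to an IPU explicitly. Below is a minimal sketch of that pattern. It uses jax.config.update, JAX's standard config setter, in place of the config.FLAGS assignment in the table, on the assumption that both set the same option. If no IPU backend is present, it falls back to CPU. The variable names are illustrative.

```python
import jax
import numpy as np

# Default backend -> CPU (assumed equivalent to config.FLAGS.jax_platform_name = 'cpu').
jax.config.update("jax_platform_name", "cpu")

# Initialization now runs on the host CPU by default.
key = jax.random.PRNGKey(0)
params = jax.random.normal(key, (4, 4))

# Pick the first IPU if the "ipu" backend exists, otherwise stay on CPU.
try:
    target = jax.devices("ipu")[0]
except RuntimeError:
    target = jax.devices("cpu")[0]

# Explicitly transfer the initialized parameters to the target device.
params = jax.device_put(params, target)
print(target, params.shape)
```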
Alternatively, like other JAX flags, these can be set using environment variables (e.g. JAX_IPU_USE_MODEL, JAX_IPU_MODEL_NUM_TILES, ...).
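For example, to select the IPU model emulator without touching the rest of a script, the variables can be set at the top of the script before jax is imported. This sketch rests on two assumptions. First, the IPU flags read their environment defaults when jax is imported, as JAX flags usually do. Second, boolean values follow JAX's usual parsing, in which "true" or "1" enables a flag. A stock JAX install does not define these flags and should not act on them.

```python
import os

# Configure the IPU flags through the environment before JAX is imported,
# so the values are in place when the flags are defined.
os.environ["JAX_IPU_USE_MODEL"] = "true"     # run on the IPU model emulator
os.environ["JAX_IPU_MODEL_NUM_TILES"] = "8"  # tiles in the emulated IPU

import jax  # noqa: E402  (deliberately imported after setting the environment)

print(jax.devices())
```

The shell equivalent is to prefix the command, e.g. JAX_IPU_USE_MODEL=true JAX_IPU_MODEL_NUM_TILES=8 python script.py.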
Useful PopVision environment variables:
POPLAR_ENGINE_OPTIONS='{"autoReport.all":"true", "debug.allowOutOfMemory":"true"}'
PVTI_OPTIONS='{"enable":"true", "directory":"./reports"}'
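When these options are set from Python rather than from the shell, building the JSON values with json.dumps avoids quoting mistakes. The sketch below reproduces the two settings above. It assumes they must be in place before jax is imported, and therefore before any Poplar engine is created.

```python
import json
import os

# Same options as above, serialized to JSON instead of hand-quoted.
# Set these before importing jax.
os.environ["POPLAR_ENGINE_OPTIONS"] = json.dumps(
    {"autoReport.all": "true", "debug.allowOutOfMemory": "true"}
)
os.environ["PVTI_OPTIONS"] = json.dumps({"enable": "true", "directory": "./reports"})
```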
The project remains licensed under the Apache License 2.0, with the following files unchanged:
The additional dependencies introduced for Graphcore IPU support are: