iree-org / iree-jax

Apache License 2.0

"RuntimeError: Unknown backend iree" #77

Status: Open. WoongQ opened this issue 1 year ago.

WoongQ commented 1 year ago

Hello, I'm trying to install iree-jax to test GPT-2 on IREE. I ran python -m pip install -e '.[test,xla,cpu]' -f https://openxla.github.io/iree/pip-release-links.html and then built jaxlib from source. However, when I run lit -v tests/, four of the five tests fail with "RuntimeError: Unknown backend iree". The same error occurs when I run models/gpt2/test_jax.py. Did I miss something during the setup process? Any help would be greatly appreciated. The error log is below.
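Because this setup mixes a source-built jaxlib with IREE packages from the pip release index, it may help to include the exact installed versions in a report like this. Below is a minimal stdlib-only sketch; the distribution names in PACKAGES are assumptions about what the install step pulls in, and installed_versions is a hypothetical helper, not part of iree-jax.

```python
from importlib import metadata

# Assumed distribution names; adjust to whatever your environment installs.
PACKAGES = ("jax", "jaxlib", "iree-compiler", "iree-runtime")


def installed_versions(names=PACKAGES):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # not installed in this environment
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name:15} {version or 'not installed'}")
```

Running it inside the same virtualenv that lit uses (here /home/woongq/jax) reports the versions the tests actually see.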

Using pure python filecheck: /home/woongq/jax/bin/filecheck
-- Testing: 5 tests, 5 workers --
FAIL: IREE_JAX :: program/trivial_kernel.py (1 of 5)
******************** TEST 'IREE_JAX :: program/trivial_kernel.py' FAILED ********************
Script:
--
: 'RUN: at line 15';   /home/woongq/jax/bin/python /home/woongq/iree-jax/tests/program/trivial_kernel.py | /home/woongq/jax/bin/filecheck /home/woongq/iree-jax/tests/program/trivial_kernel.py
--
Exit Code: 2

Command Output (stdout):
--
$ ":" "RUN: at line 15"
$ "/home/woongq/jax/bin/python" "/home/woongq/iree-jax/tests/program/trivial_kernel.py"
# command stderr:
WARNING:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002429485321044922 sec
DEBUG:jax._src.xla_bridge:Initializing backend 'cpu'
DEBUG:jax._src.xla_bridge:Backend 'cpu' initialized
DEBUG:jax._src.xla_bridge:Initializing backend 'cuda'
INFO:jax._src.xla_bridge:Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
DEBUG:jax._src.xla_bridge:Initializing backend 'rocm'
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
DEBUG:jax._src.xla_bridge:Initializing backend 'tpu'
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
WARNING:jax._src.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
WARNING:jax._src.dispatch:Finished tracing + transforming jit(broadcast_in_dim) in 0.0002300739288330078 sec
DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
WARNING:jax._src.dispatch:Finished jaxpr to MLIR module conversion jit(broadcast_in_dim) in 0.001964092254638672 sec
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
WARNING:jax._src.dispatch:Finished XLA compilation of jit(broadcast_in_dim) in 0.012798309326171875 sec
WARNING:jax._src.dispatch:Finished tracing + transforming fn for pjit in 0.0004911422729492188 sec
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(float32[3,4]), ShapedArray(float32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
WARNING:jax._src.dispatch:Finished jaxpr to MLIR module conversion jit(fn) in 0.0016129016876220703 sec
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
WARNING:jax._src.dispatch:Finished XLA compilation of jit(fn) in 0.0081329345703125 sec
DEBUG:iree_jax:Create new Program subclass: trivial_kernel
DEBUG:root:DEFINE PY_ONLY: _linear = <Exportable Pure Func: <function TrivialKernel._linear at 0x7f91ee93ce50>>
DEBUG:iree_jax:def_global_tree: array _params$0=(3, 4):dtype('float32')
DEBUG:iree_jax:def_global_tree: array _params$1=(3, 4):dtype('float32')
DEBUG:iree_jax:def_global_tree: new tree=Params(x=ConcreteArray(ExportedGlobalArray(@_params$0 : tensor<3x4xf32>), dtype=float32), b=ConcreteArray(ExportedGlobalArray(@_params$1 : tensor<3x4xf32>), dtype=float32))
DEBUG:iree_jax:def_global_tree: array _x$0=(3, 4):dtype('float32')
DEBUG:iree_jax:def_global_tree: new tree=ExportedGlobalArray(@_params$0 : tensor<3x4xf32>)
Traceback (most recent call last):
  File "/home/woongq/iree-jax/tests/program/trivial_kernel.py", line 61, in <module>
    m = TrivialKernel()
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 459, in __new__
    export_function()
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 454, in export_function
    info.export_module.def_func(invoke_with_self,
  File "/home/woongq/iree-jax/iree/jax/exporter.py", line 206, in def_func
    return_py_value = f(*argument_py_tree)
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 452, in invoke_with_self
    return func_def.callable(self, *args, **kwargs)
  File "/home/woongq/iree-jax/tests/program/trivial_kernel.py", line 48, in run
    result = self._linear(multiplier, self._params.x, self._params.b)
  File "/home/woongq/iree-jax/iree/jax/tracing.py", line 54, in __call__
    return current_ir_trace().handle_call(self, args, kwargs)
  File "/home/woongq/iree-jax/iree/jax/tracing.py", line 114, in handle_call
    return target.resolve_call(self, *args, **kwargs)
  File "/home/woongq/iree-jax/iree/jax/builtins.py", line 60, in resolve_call
    lowered = self.jit_f.lower(*abstract_args)
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 338, in lower
    donate_argnums) = infer_params_fn(*args, **kwargs)
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/api.py", line 324, in infer_params
    return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 452, in common_infer_params
    in_shardings = out_shardings = _create_sharding_with_device_backend(
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 826, in _create_sharding_with_device_backend
    xb.get_backend(backend).get_default_device_assignment(1)[0])
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 710, in get_backend
    return _get_backend_uncached(platform)
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 699, in _get_backend_uncached
    raise RuntimeError(f"Unknown backend {platform}")
jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: Unknown backend iree

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/woongq/iree-jax/tests/program/trivial_kernel.py", line 61, in <module>
    m = TrivialKernel()
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 459, in __new__
    export_function()
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 454, in export_function
    info.export_module.def_func(invoke_with_self,
  File "/home/woongq/iree-jax/iree/jax/exporter.py", line 206, in def_func
    return_py_value = f(*argument_py_tree)
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 452, in invoke_with_self
    return func_def.callable(self, *args, **kwargs)
  File "/home/woongq/iree-jax/tests/program/trivial_kernel.py", line 48, in run
    result = self._linear(multiplier, self._params.x, self._params.b)
  File "/home/woongq/iree-jax/iree/jax/tracing.py", line 54, in __call__
    return current_ir_trace().handle_call(self, args, kwargs)
  File "/home/woongq/iree-jax/iree/jax/tracing.py", line 114, in handle_call
    return target.resolve_call(self, *args, **kwargs)
  File "/home/woongq/iree-jax/iree/jax/builtins.py", line 60, in resolve_call
    lowered = self.jit_f.lower(*abstract_args)
RuntimeError: Unknown backend iree
nanobind: leaked 66 instances!
nanobind: leaked 16 types!
 - leaked type "iree._runtime.VmVariantList"
 - leaked type "iree._runtime.HalBufferView"
 - leaked type "iree._runtime.BufferUsage"
 - leaked type "iree._runtime.VmContext"
 - leaked type "iree._runtime.MappedMemory"
 - leaked type "iree._runtime.ArgumentPacker"
 - leaked type "iree._runtime.HalElementType"
 - leaked type "iree._runtime.VmRef"
 - leaked type "iree._runtime.VmModule"
 - leaked type "iree._runtime.HalDevice"
 - leaked type "iree._runtime._InvokeStatics"
 - ... skipped remainder
nanobind: leaked 78 functions!
 - leaked function ""
 - leaked function "lookup_function"
 - leaked function "__eq__"
 - leaked function ""
 - leaked function "__iree_vm_type__"
 - leaked function "__or__"
 - leaked function "__init__"
 - leaked function "create_device_by_uri"
 - leaked function ""
 - leaked function "invoke"
 - leaked function "__init__"
 - ... skipped remainder
nanobind: this is likely caused by a reference counting issue in the binding code.

error: command failed with exit status: 1
$ "/home/woongq/jax/bin/filecheck" "/home/woongq/iree-jax/tests/program/trivial_kernel.py"
# command output:
CHECK: FileCheck error: '-' is empty.
FileCheck command line: /home/woongq/iree-jax/tests/program/trivial_kernel.py

error: command failed with exit status: 2

--

********************
FAIL: IREE_JAX :: program/fft.py (2 of 5)
******************** TEST 'IREE_JAX :: program/fft.py' FAILED ********************
Script:
--
: 'RUN: at line 15';   /home/woongq/jax/bin/python /home/woongq/iree-jax/tests/program/fft.py | /home/woongq/jax/bin/filecheck /home/woongq/iree-jax/tests/program/fft.py
--
Exit Code: 2

Command Output (stdout):
--
$ ":" "RUN: at line 15"
$ "/home/woongq/jax/bin/python" "/home/woongq/iree-jax/tests/program/fft.py"
# command stderr:
DEBUG:iree_jax:Create new Program subclass: f_f_t
DEBUG:root:DEFINE PY_ONLY: _fft = <Exportable Pure Func: <function FFT._fft at 0x7f92544a2290>>
DEBUG:jax._src.xla_bridge:Initializing backend 'cpu'
DEBUG:jax._src.xla_bridge:Backend 'cpu' initialized
DEBUG:jax._src.xla_bridge:Initializing backend 'cuda'
INFO:jax._src.xla_bridge:Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
DEBUG:jax._src.xla_bridge:Initializing backend 'rocm'
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
DEBUG:jax._src.xla_bridge:Initializing backend 'tpu'
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
WARNING:jax._src.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Traceback (most recent call last):
  File "/home/woongq/iree-jax/tests/program/fft.py", line 41, in <module>
    m = FFT()
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 459, in __new__
    export_function()
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 454, in export_function
    info.export_module.def_func(invoke_with_self,
  File "/home/woongq/iree-jax/iree/jax/exporter.py", line 206, in def_func
    return_py_value = f(*argument_py_tree)
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 452, in invoke_with_self
    return func_def.callable(self, *args, **kwargs)
  File "/home/woongq/iree-jax/tests/program/fft.py", line 33, in fft
    return self._fft(x)
  File "/home/woongq/iree-jax/iree/jax/tracing.py", line 54, in __call__
    return current_ir_trace().handle_call(self, args, kwargs)
  File "/home/woongq/iree-jax/iree/jax/tracing.py", line 114, in handle_call
    return target.resolve_call(self, *args, **kwargs)
  File "/home/woongq/iree-jax/iree/jax/builtins.py", line 60, in resolve_call
    lowered = self.jit_f.lower(*abstract_args)
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 338, in lower
    donate_argnums) = infer_params_fn(*args, **kwargs)
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/api.py", line 324, in infer_params
    return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 452, in common_infer_params
    in_shardings = out_shardings = _create_sharding_with_device_backend(
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 826, in _create_sharding_with_device_backend
    xb.get_backend(backend).get_default_device_assignment(1)[0])
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 710, in get_backend
    return _get_backend_uncached(platform)
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 699, in _get_backend_uncached
    raise RuntimeError(f"Unknown backend {platform}")
jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: Unknown backend iree

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/woongq/iree-jax/tests/program/fft.py", line 41, in <module>
    m = FFT()
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 459, in __new__
    export_function()
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 454, in export_function
    info.export_module.def_func(invoke_with_self,
  File "/home/woongq/iree-jax/iree/jax/exporter.py", line 206, in def_func
    return_py_value = f(*argument_py_tree)
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 452, in invoke_with_self
    return func_def.callable(self, *args, **kwargs)
  File "/home/woongq/iree-jax/tests/program/fft.py", line 33, in fft
    return self._fft(x)
  File "/home/woongq/iree-jax/iree/jax/tracing.py", line 54, in __call__
    return current_ir_trace().handle_call(self, args, kwargs)
  File "/home/woongq/iree-jax/iree/jax/tracing.py", line 114, in handle_call
    return target.resolve_call(self, *args, **kwargs)
  File "/home/woongq/iree-jax/iree/jax/builtins.py", line 60, in resolve_call
    lowered = self.jit_f.lower(*abstract_args)
RuntimeError: Unknown backend iree

error: command failed with exit status: 1
$ "/home/woongq/jax/bin/filecheck" "/home/woongq/iree-jax/tests/program/fft.py"
# command output:
CHECK: FileCheck error: '-' is empty.
FileCheck command line: /home/woongq/iree-jax/tests/program/fft.py

error: command failed with exit status: 2

--

********************
PASS: IREE_JAX :: program/trivial_globals.py (3 of 5)
FAIL: IREE_JAX :: program/duplicate_helper.py (4 of 5)
******************** TEST 'IREE_JAX :: program/duplicate_helper.py' FAILED ********************
Script:
--
: 'RUN: at line 1';   /home/woongq/jax/bin/python /home/woongq/iree-jax/tests/program/duplicate_helper.py
--
Exit Code: 1

Command Output (stdout):
--
$ ":" "RUN: at line 1"
$ "/home/woongq/jax/bin/python" "/home/woongq/iree-jax/tests/program/duplicate_helper.py"
# command stderr:
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Traceback (most recent call last):
  File "/home/woongq/iree-jax/tests/program/duplicate_helper.py", line 67, in <module>
    print(str(Program.get_mlir_module(module)))
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 377, in get_mlir_module
    info = Program.get_info(Program._get_instance(m))
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 372, in _get_instance
    m = m()
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 459, in __new__
    export_function()
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 454, in export_function
    info.export_module.def_func(invoke_with_self,
  File "/home/woongq/iree-jax/iree/jax/exporter.py", line 206, in def_func
    return_py_value = f(*argument_py_tree)
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 452, in invoke_with_self
    return func_def.callable(self, *args, **kwargs)
  File "/home/woongq/iree-jax/tests/program/duplicate_helper.py", line 50, in encode
    return mdl._encode(x, y)
  File "/home/woongq/iree-jax/iree/jax/tracing.py", line 54, in __call__
    return current_ir_trace().handle_call(self, args, kwargs)
  File "/home/woongq/iree-jax/iree/jax/tracing.py", line 114, in handle_call
    return target.resolve_call(self, *args, **kwargs)
  File "/home/woongq/iree-jax/iree/jax/builtins.py", line 60, in resolve_call
    lowered = self.jit_f.lower(*abstract_args)
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 338, in lower
    donate_argnums) = infer_params_fn(*args, **kwargs)
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/api.py", line 324, in infer_params
    return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 452, in common_infer_params
    in_shardings = out_shardings = _create_sharding_with_device_backend(
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 826, in _create_sharding_with_device_backend
    xb.get_backend(backend).get_default_device_assignment(1)[0])
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 710, in get_backend
    return _get_backend_uncached(platform)
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 699, in _get_backend_uncached
    raise RuntimeError(f"Unknown backend {platform}")
jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: Unknown backend iree

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/woongq/iree-jax/tests/program/duplicate_helper.py", line 67, in <module>
    print(str(Program.get_mlir_module(module)))
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 377, in get_mlir_module
    info = Program.get_info(Program._get_instance(m))
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 372, in _get_instance
    m = m()
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 459, in __new__
    export_function()
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 454, in export_function
    info.export_module.def_func(invoke_with_self,
  File "/home/woongq/iree-jax/iree/jax/exporter.py", line 206, in def_func
    return_py_value = f(*argument_py_tree)
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 452, in invoke_with_self
    return func_def.callable(self, *args, **kwargs)
  File "/home/woongq/iree-jax/tests/program/duplicate_helper.py", line 50, in encode
    return mdl._encode(x, y)
  File "/home/woongq/iree-jax/iree/jax/tracing.py", line 54, in __call__
    return current_ir_trace().handle_call(self, args, kwargs)
  File "/home/woongq/iree-jax/iree/jax/tracing.py", line 114, in handle_call
    return target.resolve_call(self, *args, **kwargs)
  File "/home/woongq/iree-jax/iree/jax/builtins.py", line 60, in resolve_call
    lowered = self.jit_f.lower(*abstract_args)
RuntimeError: Unknown backend iree

error: command failed with exit status: 1

--

********************
FAIL: IREE_JAX :: program/program_api_test.py (5 of 5)
******************** TEST 'IREE_JAX :: program/program_api_test.py' FAILED ********************
Script:
--
: 'RUN: at line 1';   /home/woongq/jax/bin/python /home/woongq/iree-jax/tests/program/program_api_test.py
--
Exit Code: 1

Command Output (stdout):
--
$ ":" "RUN: at line 1"
$ "/home/woongq/jax/bin/python" "/home/woongq/iree-jax/tests/program/program_api_test.py"
# command stderr:
.DEBUG:iree_jax:Create new Program subclass: hidden
.DEBUG:iree_jax:Create new Program subclass: nullary
DEBUG:iree_jax:Create new Program subclass: unary
.DEBUG:iree_jax:Create new Program subclass: Foobar
.DEBUG:iree_jax:Create new Program subclass: error
.DEBUG:iree_jax:Create new Program subclass: error
.DEBUG:iree_jax:Create new Program subclass: error
.DEBUG:iree_jax:Create new Program subclass: error
.DEBUG:iree_jax:Create new Program subclass: global
.DEBUG:iree_jax:Create new Program subclass: my_subclass
./home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py:288: DeprecationWarning: backend and device argument on jit is deprecated. You can use a `jax.sharding.Mesh` context manager or device_put the arguments before passing them to `jit`. Please see https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html for more information.
  warnings.warn(
DEBUG:iree_jax:Create new Program subclass: iree_jax
DEBUG:root:DEFINE PY_ONLY: _f = <Exportable Pure Func: <function ProgramApiTest.test_value_tracing_with_flax_frozen_dict.<locals>.IreeJaxProgram._f at 0x7f673b4e7760>>
DEBUG:jax._src.xla_bridge:Initializing backend 'cpu'
DEBUG:jax._src.xla_bridge:Backend 'cpu' initialized
DEBUG:jax._src.xla_bridge:Initializing backend 'cuda'
INFO:jax._src.xla_bridge:Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
DEBUG:jax._src.xla_bridge:Initializing backend 'rocm'
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
DEBUG:jax._src.xla_bridge:Initializing backend 'tpu'
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
WARNING:jax._src.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
EDEBUG:iree_jax:Create new Program subclass: iree_jax
DEBUG:root:DEFINE PY_ONLY: _f = <Exportable Pure Func: <function ProgramApiTest.test_value_tracing_with_list.<locals>.IreeJaxProgram._f at 0x7f673b5384c0>>
E
======================================================================
ERROR: test_value_tracing_with_flax_frozen_dict (__main__.ProgramApiTest)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/woongq/iree-jax/tests/program/program_api_test.py", line 163, in <module>
    unittest.main()
  File "/usr/lib/python3.10/unittest/main.py", line 101, in __init__
    self.runTests()
  File "/usr/lib/python3.10/unittest/main.py", line 271, in runTests
    self.result = testRunner.run(self.test)
  File "/usr/lib/python3.10/unittest/runner.py", line 184, in run
    test(result)
  File "/usr/lib/python3.10/unittest/suite.py", line 84, in __call__
    return self.run(*args, **kwds)
  File "/usr/lib/python3.10/unittest/suite.py", line 122, in run
    test(result)
  File "/usr/lib/python3.10/unittest/suite.py", line 84, in __call__
    return self.run(*args, **kwds)
  File "/usr/lib/python3.10/unittest/suite.py", line 122, in run
    test(result)
  File "/usr/lib/python3.10/unittest/case.py", line 650, in __call__
    return self.run(*args, **kwds)
  File "/usr/lib/python3.10/unittest/case.py", line 591, in run
    self._callTestMethod(testMethod)
  File "/usr/lib/python3.10/unittest/case.py", line 549, in _callTestMethod
    method()
  File "/home/woongq/iree-jax/tests/program/program_api_test.py", line 145, in test_value_tracing_with_flax_frozen_dict
    IreeJaxProgram()
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 459, in __new__
    export_function()
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 454, in export_function
    info.export_module.def_func(invoke_with_self,
  File "/home/woongq/iree-jax/iree/jax/exporter.py", line 206, in def_func
    return_py_value = f(*argument_py_tree)
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 452, in invoke_with_self
    return func_def.callable(self, *args, **kwargs)
  File "/home/woongq/iree-jax/tests/program/program_api_test.py", line 139, in f
    return self._f(x)
  File "/home/woongq/iree-jax/iree/jax/tracing.py", line 54, in __call__
    return current_ir_trace().handle_call(self, args, kwargs)
  File "/home/woongq/iree-jax/iree/jax/tracing.py", line 114, in handle_call
    return target.resolve_call(self, *args, **kwargs)
  File "/home/woongq/iree-jax/iree/jax/builtins.py", line 60, in resolve_call
    lowered = self.jit_f.lower(*abstract_args)
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 338, in lower
    donate_argnums) = infer_params_fn(*args, **kwargs)
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/api.py", line 324, in infer_params
    return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 452, in common_infer_params
    in_shardings = out_shardings = _create_sharding_with_device_backend(
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 826, in _create_sharding_with_device_backend
    xb.get_backend(backend).get_default_device_assignment(1)[0])
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 710, in get_backend
    return _get_backend_uncached(platform)
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 699, in _get_backend_uncached
    raise RuntimeError(f"Unknown backend {platform}")
jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: Unknown backend iree

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/woongq/iree-jax/tests/program/program_api_test.py", line 145, in test_value_tracing_with_flax_frozen_dict
    IreeJaxProgram()
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 459, in __new__
    export_function()
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 454, in export_function
    info.export_module.def_func(invoke_with_self,
  File "/home/woongq/iree-jax/iree/jax/exporter.py", line 206, in def_func
    return_py_value = f(*argument_py_tree)
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 452, in invoke_with_self
    return func_def.callable(self, *args, **kwargs)
  File "/home/woongq/iree-jax/tests/program/program_api_test.py", line 139, in f
    return self._f(x)
  File "/home/woongq/iree-jax/iree/jax/tracing.py", line 54, in __call__
    return current_ir_trace().handle_call(self, args, kwargs)
  File "/home/woongq/iree-jax/iree/jax/tracing.py", line 114, in handle_call
    return target.resolve_call(self, *args, **kwargs)
  File "/home/woongq/iree-jax/iree/jax/builtins.py", line 60, in resolve_call
    lowered = self.jit_f.lower(*abstract_args)
RuntimeError: Unknown backend iree

======================================================================
ERROR: test_value_tracing_with_list (__main__.ProgramApiTest)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/woongq/iree-jax/tests/program/program_api_test.py", line 163, in <module>
    unittest.main()
  File "/usr/lib/python3.10/unittest/main.py", line 101, in __init__
    self.runTests()
  File "/usr/lib/python3.10/unittest/main.py", line 271, in runTests
    self.result = testRunner.run(self.test)
  File "/usr/lib/python3.10/unittest/runner.py", line 184, in run
    test(result)
  File "/usr/lib/python3.10/unittest/suite.py", line 84, in __call__
    return self.run(*args, **kwds)
  File "/usr/lib/python3.10/unittest/suite.py", line 122, in run
    test(result)
  File "/usr/lib/python3.10/unittest/suite.py", line 84, in __call__
    return self.run(*args, **kwds)
  File "/usr/lib/python3.10/unittest/suite.py", line 122, in run
    test(result)
  File "/usr/lib/python3.10/unittest/case.py", line 650, in __call__
    return self.run(*args, **kwds)
  File "/usr/lib/python3.10/unittest/case.py", line 591, in run
    self._callTestMethod(testMethod)
  File "/usr/lib/python3.10/unittest/case.py", line 549, in _callTestMethod
    method()
  File "/home/woongq/iree-jax/tests/program/program_api_test.py", line 159, in test_value_tracing_with_list
    IreeJaxProgram()
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 459, in __new__
    export_function()
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 454, in export_function
    info.export_module.def_func(invoke_with_self,
  File "/home/woongq/iree-jax/iree/jax/exporter.py", line 206, in def_func
    return_py_value = f(*argument_py_tree)
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 452, in invoke_with_self
    return func_def.callable(self, *args, **kwargs)
  File "/home/woongq/iree-jax/tests/program/program_api_test.py", line 153, in f
    return self._f(x)
  File "/home/woongq/iree-jax/iree/jax/tracing.py", line 54, in __call__
    return current_ir_trace().handle_call(self, args, kwargs)
  File "/home/woongq/iree-jax/iree/jax/tracing.py", line 114, in handle_call
    return target.resolve_call(self, *args, **kwargs)
  File "/home/woongq/iree-jax/iree/jax/builtins.py", line 60, in resolve_call
    lowered = self.jit_f.lower(*abstract_args)
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 338, in lower
    donate_argnums) = infer_params_fn(*args, **kwargs)
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/api.py", line 324, in infer_params
    return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 452, in common_infer_params
    in_shardings = out_shardings = _create_sharding_with_device_backend(
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 826, in _create_sharding_with_device_backend
    xb.get_backend(backend).get_default_device_assignment(1)[0])
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 710, in get_backend
    return _get_backend_uncached(platform)
  File "/home/woongq/jax/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 699, in _get_backend_uncached
    raise RuntimeError(f"Unknown backend {platform}")
jax._src.traceback_util.UnfilteredStackTrace: RuntimeError: Unknown backend iree

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/woongq/iree-jax/tests/program/program_api_test.py", line 159, in test_value_tracing_with_list
    IreeJaxProgram()
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 459, in __new__
    export_function()
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 454, in export_function
    info.export_module.def_func(invoke_with_self,
  File "/home/woongq/iree-jax/iree/jax/exporter.py", line 206, in def_func
    return_py_value = f(*argument_py_tree)
  File "/home/woongq/iree-jax/iree/jax/program_api.py", line 452, in invoke_with_self
    return func_def.callable(self, *args, **kwargs)
  File "/home/woongq/iree-jax/tests/program/program_api_test.py", line 153, in f
    return self._f(x)
  File "/home/woongq/iree-jax/iree/jax/tracing.py", line 54, in __call__
    return current_ir_trace().handle_call(self, args, kwargs)
  File "/home/woongq/iree-jax/iree/jax/tracing.py", line 114, in handle_call
    return target.resolve_call(self, *args, **kwargs)
  File "/home/woongq/iree-jax/iree/jax/builtins.py", line 60, in resolve_call
    lowered = self.jit_f.lower(*abstract_args)
RuntimeError: Unknown backend iree

----------------------------------------------------------------------
Ran 12 tests in 0.035s

FAILED (errors=2)

error: command failed with exit status: 1

--

********************
********************
Failed Tests (4):
  IREE_JAX :: program/duplicate_helper.py
  IREE_JAX :: program/fft.py
  IREE_JAX :: program/program_api_test.py
  IREE_JAX :: program/trivial_kernel.py

Testing Time: 0.73s
  Passed: 1
  Failed: 4
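Every failure ends in jax/_src/xla_bridge.py raising RuntimeError(f"Unknown backend {platform}") with platform set to "iree", while the debug lines show that only the cpu backend initialized. One way to confirm this outside lit is to ask JAX directly which backend names it can resolve. This is a hedged sketch: backend_available is a hypothetical helper, and it relies only on jax.devices(name) raising RuntimeError for a name JAX cannot resolve, which is the behavior the traceback shows.

```python
import jax


def backend_available(name: str) -> bool:
    """Return True if JAX can resolve a backend with this platform name."""
    try:
        jax.devices(name)  # raises RuntimeError for unknown or failed backends
        return True
    except RuntimeError:
        return False


if __name__ == "__main__":
    print("default backend:", jax.default_backend())
    for name in ("cpu", "iree"):
        print(f"{name}: {backend_available(name)}")
```

In the environment from this log one would expect cpu to resolve and iree not to, meaning nothing registered a backend under the name "iree" before iree-jax asked for it.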

okkwon commented 1 year ago

https://github.com/openxla/openxla-pjrt-plugin is the right way to use JAX+IREE.

ScottTodd commented 1 year ago

> https://github.com/openxla/openxla-pjrt-plugin is the right way to use JAX+IREE.

The PJRT plugin is one way to use JAX+IREE, mostly for just-in-time (JIT) scenarios from Python. This repository is another way, with a focus on ahead-of-time (AOT) scenarios where the compiled program runs outside of Python. See https://openxla.github.io/iree/guides/ml-frameworks/jax/
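To make the AOT side concrete with plain JAX (this is not iree-jax's Program API), the sketch below lowers a function once from abstract shapes into an MLIR module, much as the traceback shows iree-jax calling self.jit_f.lower(*abstract_args) in builtins.py. The linear function is hypothetical, loosely modeled on the (3, 4) float32 parameters in trivial_kernel.py. Handing the resulting text to the IREE compiler is not shown, since the exact tool flags depend on the IREE release.

```python
import jax
import jax.numpy as jnp


# Hypothetical kernel: a scalar multiplier applied to a (3, 4) array plus a bias.
def linear(multiplier, x, b):
    return multiplier * x + b


# Abstract inputs carry only shape and dtype, no concrete data.
scalar = jax.ShapeDtypeStruct((), jnp.float32)
spec = jax.ShapeDtypeStruct((3, 4), jnp.float32)

# Lower ahead of time with the default backend; no "iree" backend is involved.
lowered = jax.jit(linear).lower(scalar, spec, spec)
mlir_text = lowered.as_text()  # textual MLIR module (StableHLO in recent JAX)
print(mlir_text)
```

The design point is that lowering needs only shapes and dtypes, so the exported module can be compiled and deployed later without the Python tracing environment.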

> Did I miss something during the setup process?

Possibly. Compare your setup against what https://github.com/iree-org/iree-jax/blob/main/.github/workflows/test_gpt2_model.yaml does; that workflow runs nightly at https://github.com/iree-org/iree-jax/actions.