Hello, I installed iree-jax and tried the example shown in the README, but it fails with the error below:
$ git clone https://github.com/google/iree-jax.git
$ cd iree-jax
$ python -m pip install -e .[test,xla,cpu] -f https://github.com/google/iree/releases
$ python --version
Python 3.9.7
$ vi tiny.py
I copied the example from the README into it and added the following two statements:
import jax.numpy as jnp
from collections import namedtuple
$ python tiny.py
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Traceback (most recent call last):
  File "/home/u89062/iree-jax/examples/tiny.py", line 85, in <module>
    m = TrivialKernel()
  File "/home/u89062/iree-jax/iree/jax/program_api.py", line 559, in __new__
    export_function()
  File "/home/u89062/iree-jax/iree/jax/program_api.py", line 554, in export_function
    info.export_module.def_func(invoke_with_self,
  File "/home/u89062/iree-jax/iree/jax/exporter.py", line 208, in def_func
    return_py_value = f(*argument_py_tree)
  File "/home/u89062/iree-jax/iree/jax/program_api.py", line 552, in invoke_with_self
    return func_def.callable(self, *args, **kwargs)
  File "/home/u89062/iree-jax/examples/tiny.py", line 61, in run
    result = self._linear(multiplier, self._params.x, self._params.b)
  File "/home/u89062/iree-jax/iree/jax/tracing.py", line 55, in __call__
    return current_ir_trace().handle_call(self, args, kwargs)
  File "/home/u89062/iree-jax/iree/jax/tracing.py", line 115, in handle_call
    return target.resolve_call(self, *args, **kwargs)
  File "/home/u89062/iree-jax/iree/jax/builtins.py", line 64, in resolve_call
    imported_main_symbol_name = jax_utils.import_main_function(
  File "/home/u89062/iree-jax/iree/jax/jax_utils.py", line 116, in import_main_function
    source_module = import_module(context, source_module)
  File "/home/u89062/iree-jax/iree/jax/jax_utils.py", line 95, in import_module
    raise ValueError(
ValueError: Attempted to import a non-module (did you enable MLIR in JAX?). Got module @jit__linear.2 {
  func public @main(%arg0: tensor<3x4xf32>, %arg1: tensor<3x4xf32>, %arg2: tensor<3x4xf32>) -> tensor<3x4xf32> {
    %0 = mhlo.multiply %arg0, %arg1 : tensor<3x4xf32>
    %1 = mhlo.add %0, %arg2 : tensor<3x4xf32>
    return %1 : tensor<3x4xf32>
  }
}
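Since the error concerns the MLIR module that JAX hands to iree-jax, the exact versions of the installed packages are probably relevant to diagnosing it. A small script like the following could collect them. It is only a sketch: the distribution names in `PACKAGES` (in particular `iree-compiler`, `iree-runtime`, and `iree-jax`) are assumptions about what pip installed and may need adjusting.

```python
"""Collect installed versions of the packages involved in this error.

Sketch only: the distribution names in PACKAGES are assumptions and
may differ from the names pip actually used when installing.
"""
import platform
from importlib.metadata import PackageNotFoundError, version

# Hypothetical list of distributions relevant to the traceback.
PACKAGES = ["jax", "jaxlib", "iree-compiler", "iree-runtime", "iree-jax"]


def collect_versions(packages):
    """Return {name: version string, or 'not installed'} for each package."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            # The distribution is missing or installed under a different name.
            result[name] = "not installed"
    return result


if __name__ == "__main__":
    print("python", platform.python_version())
    for name, ver in collect_versions(PACKAGES).items():
        print(f"{name}: {ver}")
```

Running it with the same interpreter that ran `tiny.py` prints one line per package, which can be pasted into the report alongside the traceback.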
Thanks!