iree-org / iree-jax

Apache License 2.0

ValueError: Attempted to import a non-module #5

Status: Open. ronghongbo opened this issue 2 years ago.

ronghongbo commented 2 years ago

Hello, I installed iree-jax and tried the example shown in the README, but it fails with the error below:

$ git clone https://github.com/google/iree-jax.git
$ cd iree-jax
$ python -m pip install -e .[test,xla,cpu] -f https://github.com/google/iree/releases
$ python --version
Python 3.9.7
$ vi tiny.py
Copy the example from the README, then add these two statements:
     import jax.numpy as jnp
     from collections import namedtuple
$ python tiny.py
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Traceback (most recent call last):
  File "/home/u89062/iree-jax/examples/tiny.py", line 85, in <module>
    m = TrivialKernel()
  File "/home/u89062/iree-jax/iree/jax/program_api.py", line 559, in __new__
    export_function()
  File "/home/u89062/iree-jax/iree/jax/program_api.py", line 554, in export_function
    info.export_module.def_func(invoke_with_self,
  File "/home/u89062/iree-jax/iree/jax/exporter.py", line 208, in def_func
    return_py_value = f(*argument_py_tree)
  File "/home/u89062/iree-jax/iree/jax/program_api.py", line 552, in invoke_with_self
    return func_def.callable(self, *args, **kwargs)
  File "/home/u89062/iree-jax/examples/tiny.py", line 61, in run
    result = self._linear(multiplier, self._params.x, self._params.b)
  File "/home/u89062/iree-jax/iree/jax/tracing.py", line 55, in __call__
    return current_ir_trace().handle_call(self, args, kwargs)
  File "/home/u89062/iree-jax/iree/jax/tracing.py", line 115, in handle_call
    return target.resolve_call(self, *args, **kwargs)
  File "/home/u89062/iree-jax/iree/jax/builtins.py", line 64, in resolve_call
    imported_main_symbol_name = jax_utils.import_main_function(
  File "/home/u89062/iree-jax/iree/jax/jax_utils.py", line 116, in import_main_function
    source_module = import_module(context, source_module)
  File "/home/u89062/iree-jax/iree/jax/jax_utils.py", line 95, in import_module
    raise ValueError(
ValueError: Attempted to import a non-module (did you enable MLIR in JAX?). Got module @jit__linear.2 {
  func public @main(%arg0: tensor<3x4xf32>, %arg1: tensor<3x4xf32>, %arg2: tensor<3x4xf32>) -> tensor<3x4xf32> {
    %0 = mhlo.multiply %arg0, %arg1 : tensor<3x4xf32>
    %1 = mhlo.add %0, %arg2 : tensor<3x4xf32>
    return %1 : tensor<3x4xf32>
  }
}
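One plausible reading of this failure, which has not been confirmed in this thread: the value printed after "Got" is already a valid MLIR module, so the hint about enabling MLIR in JAX does not seem to apply. The guard in `iree/jax/jax_utils.py` (line 95 in the traceback) may instead be rejecting the object because it is not an instance of the exact Module class it expects. This can happen when JAX and the IREE compiler each ship their own copy of the MLIR Python bindings. The pure-Python sketch below simulates that situation. `JaxlibModule`, `IreeModule`, `import_module_strict`, and `import_module_via_text` are all hypothetical stand-ins. They mirror only the behavior visible in the traceback, not the actual iree-jax source.

```python
# Hypothetical simulation: a value that *prints* as an MLIR module can still be
# rejected by an isinstance-based guard if it comes from a different binding
# package. These classes are stand-ins, NOT the real jaxlib/IREE Module types.


class JaxlibModule:
    """Assumed stand-in for the Module type produced by JAX's lowering."""

    def __init__(self, asm: str):
        self.asm = asm

    def __str__(self) -> str:
        return self.asm


class IreeModule:
    """Assumed stand-in for the Module type the IREE importer expects."""

    def __init__(self, asm: str):
        self.asm = asm

    @classmethod
    def parse(cls, asm: str) -> "IreeModule":
        # Real bindings would parse and verify the IR text; here we only wrap it.
        return cls(asm)

    def __str__(self) -> str:
        return self.asm


def import_module_strict(source):
    """Accept only the expected Module type, like the check that raised above."""
    if not isinstance(source, IreeModule):
        raise ValueError(
            "Attempted to import a non-module (did you enable MLIR in JAX?). "
            f"Got {source}"
        )
    return source


def import_module_via_text(source):
    """Hypothetical workaround: re-parse a foreign module from its textual IR."""
    if isinstance(source, IreeModule):
        return source
    return IreeModule.parse(str(source))


# Placeholder IR text; the real module is the one printed in the traceback.
asm = "module @jit__linear.2 {}"
foreign = JaxlibModule(asm)

try:
    import_module_strict(foreign)
except ValueError as err:
    print("strict import failed:", err)

converted = import_module_via_text(foreign)
print(type(converted).__name__, str(converted) == asm)
```

If this is indeed the cause, the kind of fix to look for is either a text round-trip like `import_module_via_text` or JAX and IREE builds that share compatible MLIR bindings. Neither has been verified against iree-jax here.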

Thanks!