iree-org / iree-jax

Apache License 2.0

models/gpt2/test_jax.py failed #49

Open wangkuiyi opened 1 year ago


Problem

I tried to run the test with the following command:

python iree-jax/models/gpt2/test_jax.py

The error message is attached at the end of this issue.

Reproduce

  1. Build the IREE compiler and runtime with Python bindings: https://iree-org.github.io/iree/building-from-source/python-bindings-and-importers/#building-python-bindings
  2. Install the IREE compiler and runtime Python bindings, and iree.jax, from source: https://github.com/iree-org/iree-jax/issues/47#issuecomment-1405876566
  3. Install the dependencies of iree-jax/models/gpt2:
    conda install absl-py transformers h5py
  4. Run the test:
    python iree-jax/models/gpt2/test_jax.py
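
To make the report easier to triage, it may help to attach the exact versions of the packages involved. Below is a minimal sketch using the standard library's importlib.metadata. The distribution names in PACKAGES are assumptions: source builds of IREE and iree.jax may register under different names, so adjust the list to match what pip or conda shows in the environment.

```python
import platform
import sys
from importlib import metadata

# Assumed distribution names; source builds of IREE / iree.jax may differ.
PACKAGES = ["jax", "jaxlib", "iree-compiler", "iree-runtime", "iree-jax",
            "transformers", "absl-py", "h5py"]


def collect_versions(packages=PACKAGES):
    """Return {distribution name: version string or 'not installed'}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = "not installed"
    return versions


if __name__ == "__main__":
    print("Python", sys.version.split()[0], "on", platform.platform())
    for name, version in collect_versions().items():
        print(f"{name:15} {version}")
```

Running it with the same interpreter as the test (here /Users/y/miniforge3/envs/iree-jax/bin/python) prints one line per package, including any that are missing.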

0:49 $ python iree-jax/models/gpt2/test_jax.py
Running tests under Python 3.10.8: /Users/y/miniforge3/envs/iree-jax/bin/python
[ RUN      ] GPT2RealWeightsTest.test_batch_one0 ('cpu')
I0126 20:50:00.234236 8235580032 xla_bridge.py:170] Remote TPU is not linked into jax; skipping remote TPU.
I0126 20:50:00.234334 8235580032 xla_bridge.py:355] Unable to initialize backend 'tpu_driver': Could not initialize backend 'tpu_driver'
I0126 20:50:00.234373 8235580032 xla_bridge.py:355] Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I0126 20:50:00.234399 8235580032 xla_bridge.py:355] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I0126 20:50:00.234518 8235580032 xla_bridge.py:355] Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
I0126 20:50:00.234560 8235580032 xla_bridge.py:355] Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.
/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/_src/nn/functions.py:376: DeprecationWarning: jax.nn.normalize will be deprecated. Use jax.nn.standardize instead.
  warnings.warn("jax.nn.normalize will be deprecated. Use jax.nn.standardize instead.", DeprecationWarning)
[       OK ] GPT2RealWeightsTest.test_batch_one0 ('cpu')
[ RUN      ] GPT2RealWeightsTest.test_batch_one1 ('iree')
/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/_src/nn/functions.py:376: DeprecationWarning: jax.nn.normalize will be deprecated. Use jax.nn.standardize instead.
  warnings.warn("jax.nn.normalize will be deprecated. Use jax.nn.standardize instead.", DeprecationWarning)
I0126 20:50:07.604442 8235580032 binaries.py:182] Invoke IREE Tool: /Users/y/w/iree-ios/build/compiler/compiler/bindings/python/iree/compiler/tools/../_mlir_libs/iree-compile - --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=llvm-cpu --iree-llvm-embedded-linker-path=/Users/y/w/iree-ios/build/compiler/compiler/bindings/python/iree/compiler/tools/../_mlir_libs/iree-lld --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvm-target-triple=arm64-apple-darwin21.5.0
[  FAILED  ] GPT2RealWeightsTest.test_batch_one1 ('iree')
======================================================================
ERROR: test_batch_one1 ('iree') (__main__.GPT2RealWeightsTest)
GPT2RealWeightsTest.test_batch_one1 ('iree')
test_batch_one('iree')
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/y/w/iree-ios/iree-jax/models/gpt2/test_jax.py", line 72, in <module>
    absltest.main()
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/absl/testing/absltest.py", line 2060, in main
    _run_in_app(run_tests, args, kwargs)
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/absl/testing/absltest.py", line 2165, in _run_in_app
    app.run(main=main_function)
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/absl/testing/absltest.py", line 2163, in main_function
    function(argv, args, kwargs)
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/absl/testing/absltest.py", line 2561, in run_tests
    result = _run_and_get_tests_result(
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/absl/testing/absltest.py", line 2527, in _run_and_get_tests_result
    test_program = unittest.TestProgram(*args, **kwargs)
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/unittest/main.py", line 101, in __init__
    self.runTests()
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/unittest/main.py", line 271, in runTests
    self.result = testRunner.run(self.test)
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/absl/testing/_pretty_print_reporter.py", line 82, in run
    return super(TextTestRunner, self).run(test)
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/unittest/runner.py", line 184, in run
    test(result)
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/unittest/suite.py", line 84, in __call__
    return self.run(*args, **kwds)
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/unittest/suite.py", line 122, in run
    test(result)
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/unittest/suite.py", line 84, in __call__
    return self.run(*args, **kwds)
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/unittest/suite.py", line 122, in run
    test(result)
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/unittest/case.py", line 650, in __call__
    return self.run(*args, **kwds)
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/unittest/case.py", line 591, in run
    self._callTestMethod(testMethod)
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/unittest/case.py", line 549, in _callTestMethod
    method()
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/absl/testing/parameterized.py", line 320, in bound_param_test
    return test_method(self, *testcase_params)
  File "/Users/y/w/iree-ios/iree-jax/models/gpt2/test_jax.py", line 64, in test_batch_one
    kv, x0 = encode(params, kv, prompt, 0, t)
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/_src/api.py", line 565, in cache_miss
    out_flat = call_bind_continuation(execute(*args_flat))
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/interpreters/pxla.py", line 2108, in __call__
    input_bufs = self.in_handler(args)
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/interpreters/pxla.py", line 1888, in __call__
    return self.handler(input_buffers)
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/interpreters/pxla.py", line 413, in shard_args
    return [_shard_arg(arg, devices, indices[i]) for i, arg in enumerate(args)]
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/interpreters/pxla.py", line 413, in <listcomp>
    return [_shard_arg(arg, devices, indices[i]) for i, arg in enumerate(args)]
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/interpreters/pxla.py", line 392, in _shard_arg
    return shard_arg_handlers[type(arg)](arg, devices, arg_indices)
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/_src/array.py", line 645, in _array_shard_arg
    return [buf if buf.device() == d else buf.copy_to_device(d)
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/_src/array.py", line 645, in <listcomp>
    return [buf if buf.device() == d else buf.copy_to_device(d)
jax._src.traceback_util.UnfilteredStackTrace: TypeError: (): incompatible function arguments. The following argument types are supported:
    1. (self: xla::PyBuffer::pyobject, arg0: jaxlib.xla_extension.Device) -> StatusOr[object]

Invoked with: DeviceArray([[-0.11010301, -0.03926672,  0.03310751, ..., -0.1363697 ,
               0.01506208,  0.04531523],
             [ 0.04034033, -0.04861503,  0.04624869, ...,  0.08605453,
               0.00253983,  0.04318958],
             [-0.12746179,  0.04793796,  0.18410145, ...,  0.08991534,
              -0.12972379, -0.08785918],
             ...,
             [-0.04453601, -0.05483596,  0.01225674, ...,  0.10435229,
               0.09783269, -0.06952604],
             [ 0.1860082 ,  0.01665728,  0.04611587, ..., -0.09625227,
               0.07847701, -0.02245961],
             [ 0.05135201, -0.02768905,  0.0499369 , ...,  0.00704835,
               0.15519823,  0.12067825]], dtype=float32), <jax._src.iree.IreeDevice object at 0x126f37ee0>

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/absl/testing/parameterized.py", line 320, in bound_param_test
    return test_method(self, *testcase_params)
  File "/Users/y/w/iree-ios/iree-jax/models/gpt2/test_jax.py", line 64, in test_batch_one
    kv, x0 = encode(params, kv, prompt, 0, t)
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/_src/array.py", line 645, in _array_shard_arg
    return [buf if buf.device() == d else buf.copy_to_device(d)
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/_src/array.py", line 645, in <listcomp>
    return [buf if buf.device() == d else buf.copy_to_device(d)
TypeError: (): incompatible function arguments. The following argument types are supported:
    1. (self: xla::PyBuffer::pyobject, arg0: jaxlib.xla_extension.Device) -> StatusOr[object]

Invoked with: DeviceArray([[-0.11010301, -0.03926672,  0.03310751, ..., -0.1363697 ,
               0.01506208,  0.04531523],
             [ 0.04034033, -0.04861503,  0.04624869, ...,  0.08605453,
               0.00253983,  0.04318958],
             [-0.12746179,  0.04793796,  0.18410145, ...,  0.08991534,
              -0.12972379, -0.08785918],
             ...,
             [-0.04453601, -0.05483596,  0.01225674, ...,  0.10435229,
               0.09783269, -0.06952604],
             [ 0.1860082 ,  0.01665728,  0.04611587, ..., -0.09625227,
               0.07847701, -0.02245961],
             [ 0.05135201, -0.02768905,  0.0499369 , ...,  0.00704835,
               0.15519823,  0.12067825]], dtype=float32), <jax._src.iree.IreeDevice object at 0x126f37ee0>

----------------------------------------------------------------------
Ran 2 tests in 9.993s

FAILED (errors=1)
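
One reading of the last frames, offered as an assumption rather than a confirmed diagnosis: the argument is a DeviceArray whose buffer sits on an XLA device. Because buf.device() == d is false, JAX calls buf.copy_to_device(d) with d being an IreeDevice, and the binding's only overload accepts a jaxlib.xla_extension.Device. The pure-Python model below mirrors the shape of that list comprehension from jax/_src/array.py. Every class and function in it (XlaDevice, IreeLikeDevice, Buffer, shard_arg) is a hypothetical stand-in, not the real JAX or IREE type.

```python
class XlaDevice:
    """Hypothetical stand-in for jaxlib.xla_extension.Device."""

    def __init__(self, name):
        self.name = name


class IreeLikeDevice:
    """Hypothetical stand-in for jax._src.iree.IreeDevice.

    It deliberately does not subclass XlaDevice, which models the type
    mismatch the traceback reports.
    """

    def __init__(self, name):
        self.name = name


class Buffer:
    """Hypothetical stand-in for an XLA buffer living on one device."""

    def __init__(self, device):
        self._device = device

    def device(self):
        return self._device

    def copy_to_device(self, dst):
        # Mimic a binding whose only overload takes an XlaDevice:
        # any other argument type is rejected with a TypeError.
        if not isinstance(dst, XlaDevice):
            raise TypeError(
                "copy_to_device(): incompatible function arguments; "
                f"expected XlaDevice, got {type(dst).__name__}")
        return Buffer(dst)


def shard_arg(bufs, devices):
    # Same shape as the failing comprehension: reuse a buffer that is
    # already on the target device, otherwise copy it there.
    return [buf if buf.device() == d else buf.copy_to_device(d)
            for buf, d in zip(bufs, devices)]


if __name__ == "__main__":
    cpu = XlaDevice("cpu")
    buf = Buffer(cpu)
    print(shard_arg([buf], [cpu])[0] is buf)  # True: no copy needed
    try:
        shard_arg([buf], [IreeLikeDevice("iree")])
    except TypeError as err:
        print("TypeError:", err)
```

In this model the error depends only on the type of the target device, not on the array contents, which is why the DeviceArray values printed in the log are not themselves the problem.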