0:49 $ python iree-jax/models/gpt2/test_jax.py
Running tests under Python 3.10.8: /Users/y/miniforge3/envs/iree-jax/bin/python
[ RUN ] GPT2RealWeightsTest.test_batch_one0 ('cpu')
I0126 20:50:00.234236 8235580032 xla_bridge.py:170] Remote TPU is not linked into jax; skipping remote TPU.
I0126 20:50:00.234334 8235580032 xla_bridge.py:355] Unable to initialize backend 'tpu_driver': Could not initialize backend 'tpu_driver'
I0126 20:50:00.234373 8235580032 xla_bridge.py:355] Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I0126 20:50:00.234399 8235580032 xla_bridge.py:355] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I0126 20:50:00.234518 8235580032 xla_bridge.py:355] Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
I0126 20:50:00.234560 8235580032 xla_bridge.py:355] Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.
/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/_src/nn/functions.py:376: DeprecationWarning: jax.nn.normalize will be deprecated. Use jax.nn.standardize instead.
warnings.warn("jax.nn.normalize will be deprecated. Use jax.nn.standardize instead.", DeprecationWarning)
[ OK ] GPT2RealWeightsTest.test_batch_one0 ('cpu')
[ RUN ] GPT2RealWeightsTest.test_batch_one1 ('iree')
/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/_src/nn/functions.py:376: DeprecationWarning: jax.nn.normalize will be deprecated. Use jax.nn.standardize instead.
warnings.warn("jax.nn.normalize will be deprecated. Use jax.nn.standardize instead.", DeprecationWarning)
I0126 20:50:07.604442 8235580032 binaries.py:182] Invoke IREE Tool: /Users/y/w/iree-ios/build/compiler/compiler/bindings/python/iree/compiler/tools/../_mlir_libs/iree-compile - --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=llvm-cpu --iree-llvm-embedded-linker-path=/Users/y/w/iree-ios/build/compiler/compiler/bindings/python/iree/compiler/tools/../_mlir_libs/iree-lld --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvm-target-triple=arm64-apple-darwin21.5.0
[ FAILED ] GPT2RealWeightsTest.test_batch_one1 ('iree')
======================================================================
ERROR: test_batch_one1 ('iree') (__main__.GPT2RealWeightsTest)
GPT2RealWeightsTest.test_batch_one1 ('iree')
test_batch_one('iree')
----------------------------------------------------------------------
Traceback (most recent call last):
File "/Users/y/w/iree-ios/iree-jax/models/gpt2/test_jax.py", line 72, in <module>
absltest.main()
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/absl/testing/absltest.py", line 2060, in main
_run_in_app(run_tests, args, kwargs)
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/absl/testing/absltest.py", line 2165, in _run_in_app
app.run(main=main_function)
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/absl/testing/absltest.py", line 2163, in main_function
function(argv, args, kwargs)
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/absl/testing/absltest.py", line 2561, in run_tests
result = _run_and_get_tests_result(
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/absl/testing/absltest.py", line 2527, in _run_and_get_tests_result
test_program = unittest.TestProgram(*args, **kwargs)
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/unittest/main.py", line 101, in __init__
self.runTests()
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/unittest/main.py", line 271, in runTests
self.result = testRunner.run(self.test)
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/absl/testing/_pretty_print_reporter.py", line 82, in run
return super(TextTestRunner, self).run(test)
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/unittest/runner.py", line 184, in run
test(result)
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/unittest/suite.py", line 84, in __call__
return self.run(*args, **kwds)
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/unittest/suite.py", line 122, in run
test(result)
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/unittest/suite.py", line 84, in __call__
return self.run(*args, **kwds)
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/unittest/suite.py", line 122, in run
test(result)
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/unittest/case.py", line 650, in __call__
return self.run(*args, **kwds)
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/unittest/case.py", line 591, in run
self._callTestMethod(testMethod)
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/unittest/case.py", line 549, in _callTestMethod
method()
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/absl/testing/parameterized.py", line 320, in bound_param_test
return test_method(self, *testcase_params)
File "/Users/y/w/iree-ios/iree-jax/models/gpt2/test_jax.py", line 64, in test_batch_one
kv, x0 = encode(params, kv, prompt, 0, t)
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/_src/api.py", line 565, in cache_miss
out_flat = call_bind_continuation(execute(*args_flat))
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/interpreters/pxla.py", line 2108, in __call__
input_bufs = self.in_handler(args)
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/interpreters/pxla.py", line 1888, in __call__
return self.handler(input_buffers)
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/interpreters/pxla.py", line 413, in shard_args
return [_shard_arg(arg, devices, indices[i]) for i, arg in enumerate(args)]
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/interpreters/pxla.py", line 413, in <listcomp>
return [_shard_arg(arg, devices, indices[i]) for i, arg in enumerate(args)]
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/interpreters/pxla.py", line 392, in _shard_arg
return shard_arg_handlers[type(arg)](arg, devices, arg_indices)
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/_src/array.py", line 645, in _array_shard_arg
return [buf if buf.device() == d else buf.copy_to_device(d)
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/_src/array.py", line 645, in <listcomp>
return [buf if buf.device() == d else buf.copy_to_device(d)
jax._src.traceback_util.UnfilteredStackTrace: TypeError: (): incompatible function arguments. The following argument types are supported:
1. (self: xla::PyBuffer::pyobject, arg0: jaxlib.xla_extension.Device) -> StatusOr[object]
Invoked with: DeviceArray([[-0.11010301, -0.03926672, 0.03310751, ..., -0.1363697 ,
0.01506208, 0.04531523],
[ 0.04034033, -0.04861503, 0.04624869, ..., 0.08605453,
0.00253983, 0.04318958],
[-0.12746179, 0.04793796, 0.18410145, ..., 0.08991534,
-0.12972379, -0.08785918],
...,
[-0.04453601, -0.05483596, 0.01225674, ..., 0.10435229,
0.09783269, -0.06952604],
[ 0.1860082 , 0.01665728, 0.04611587, ..., -0.09625227,
0.07847701, -0.02245961],
[ 0.05135201, -0.02768905, 0.0499369 , ..., 0.00704835,
0.15519823, 0.12067825]], dtype=float32), <jax._src.iree.IreeDevice object at 0x126f37ee0>
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/absl/testing/parameterized.py", line 320, in bound_param_test
return test_method(self, *testcase_params)
File "/Users/y/w/iree-ios/iree-jax/models/gpt2/test_jax.py", line 64, in test_batch_one
kv, x0 = encode(params, kv, prompt, 0, t)
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/_src/array.py", line 645, in _array_shard_arg
return [buf if buf.device() == d else buf.copy_to_device(d)
File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/_src/array.py", line 645, in <listcomp>
return [buf if buf.device() == d else buf.copy_to_device(d)
TypeError: (): incompatible function arguments. The following argument types are supported:
1. (self: xla::PyBuffer::pyobject, arg0: jaxlib.xla_extension.Device) -> StatusOr[object]
Invoked with: DeviceArray([[-0.11010301, -0.03926672, 0.03310751, ..., -0.1363697 ,
0.01506208, 0.04531523],
[ 0.04034033, -0.04861503, 0.04624869, ..., 0.08605453,
0.00253983, 0.04318958],
[-0.12746179, 0.04793796, 0.18410145, ..., 0.08991534,
-0.12972379, -0.08785918],
...,
[-0.04453601, -0.05483596, 0.01225674, ..., 0.10435229,
0.09783269, -0.06952604],
[ 0.1860082 , 0.01665728, 0.04611587, ..., -0.09625227,
0.07847701, -0.02245961],
[ 0.05135201, -0.02768905, 0.0499369 , ..., 0.00704835,
0.15519823, 0.12067825]], dtype=float32), <jax._src.iree.IreeDevice object at 0x126f37ee0>
----------------------------------------------------------------------
Ran 2 tests in 9.993s
FAILED (errors=1)
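Problem
I tried to run the GPT-2 test with the command shown on the first line of the log above (python iree-jax/models/gpt2/test_jax.py).
The 'cpu' variant of the test passes, but the 'iree' variant fails with a TypeError raised from buf.copy_to_device(d). The full output, including both tracebacks, is in the log above.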
Reproduce
Build iree.jax from source code, following https://github.com/iree-org/iree-jax/issues/47#issuecomment-1405876566, then run the test in iree-jax/models/gpt2.
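To make the traceback easier to read, here is a minimal pure-Python sketch of the failing line in jax/_src/array.py (line 645 in the log). It does not use JAX or IREE. The classes XlaDevice, IreeDevice, and Buffer are hypothetical stand-ins, and the assumption that the IREE device object is not an instance of jaxlib.xla_extension.Device is inferred only from the "incompatible function arguments" message, not confirmed against the jaxlib source.

```python
class XlaDevice:
    """Stand-in for jaxlib.xla_extension.Device (hypothetical)."""


class IreeDevice:
    """Stand-in for jax._src.iree.IreeDevice.

    Assumption: it is NOT a subclass of the XLA device type.
    """


class Buffer:
    """Stand-in for the buffer type whose copy_to_device is a bound C++ method."""

    def __init__(self, device):
        self._device = device

    def device(self):
        return self._device

    def copy_to_device(self, d):
        # A pybind11-bound method rejects arguments of an unexpected type
        # with a TypeError; this isinstance check models that behavior.
        if not isinstance(d, XlaDevice):
            raise TypeError("copy_to_device(): incompatible function arguments")
        return Buffer(d)


def shard(buf, devices):
    # Same shape as the list comprehension shown in the traceback.
    return [buf if buf.device() == d else buf.copy_to_device(d) for d in devices]


if __name__ == "__main__":
    cpu, iree = XlaDevice(), IreeDevice()
    buf = Buffer(cpu)
    shard(buf, [cpu])       # same device: the buffer is reused, no copy
    try:
        shard(buf, [iree])  # different device of another type: TypeError
    except TypeError as e:
        print("TypeError:", e)
```

If this model is right, the failure happens only when an array that lives on one backend is passed to a computation targeting the IREE device, which would match the 'cpu' test passing and the 'iree' test failing on the first call to encode.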