iree-org / iree-jax

Apache License 2.0

models/gpt2/test_jax.py failed #54

Closed: wangkuiyi closed this issue 1 year ago

wangkuiyi commented 1 year ago

When I ran models/gpt2/test_jax.py with the backend set to "iree", I got the following error. (The test passed with backend="cpu".)

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/absl/testing/parameterized.py", line 320, in bound_param_test
    return test_method(self, *testcase_params)
  File "/Users/y/w/iree-ios/iree-jax/models/gpt2/test_jax.py", line 64, in test_batch_one
    kv, x0 = encode(params, kv, prompt, 0, t)
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/_src/array.py", line 645, in _array_shard_arg
    return [buf if buf.device() == d else buf.copy_to_device(d)
  File "/Users/y/miniforge3/envs/iree-jax/lib/python3.10/site-packages/jax/_src/array.py", line 645, in <listcomp>
    return [buf if buf.device() == d else buf.copy_to_device(d)
TypeError: (): incompatible function arguments. The following argument types are supported:
    1. (self: xla::PyBuffer::pyobject, arg0: jaxlib.xla_extension.Device) -> StatusOr[object]

Invoked with: DeviceArray([[-0.11010301, -0.03926672,  0.03310751, ..., -0.1363697 ,
               0.01506208,  0.04531523],
             [ 0.04034033, -0.04861503,  0.04624869, ...,  0.08605453,
               0.00253983,  0.04318958],
             [-0.12746179,  0.04793796,  0.18410145, ...,  0.08991534,
              -0.12972379, -0.08785918],
             ...,
             [-0.04453601, -0.05483596,  0.01225674, ...,  0.10435229,
               0.09783269, -0.06952604],
             [ 0.1860082 ,  0.01665728,  0.04611587, ..., -0.09625227,
               0.07847701, -0.02245961],
             [ 0.05135201, -0.02768905,  0.0499369 , ...,  0.00704835,
               0.15519823,  0.12067825]], dtype=float32), <jax._src.iree.IreeDevice object at 0x126adbee0>

For the full stdout and stderr output, see https://gist.github.com/wangkuiyi/90859c3a33af4ebaddd357e527add33d.
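The failing frame is the list comprehension in `_array_shard_arg`. If a buffer already lives on the target device, it is reused. Otherwise `buf.copy_to_device(d)` is called, and according to the error message that binding only accepts a `jaxlib.xla_extension.Device`, so the `IreeDevice` object is rejected with a `TypeError`. The sketch below is a simplified, hypothetical model of that logic. It is not JAX's code: `NativeDevice`, `PluginDevice`, `Buffer` and `shard_arg` are made-up names used only for illustration.

```python
# Hypothetical, simplified model of the failing line in jax/_src/array.py.
# None of these classes exist in JAX; they only illustrate the type check.


class NativeDevice:
    """Stand-in for jaxlib.xla_extension.Device (the type the binding accepts)."""

    def __init__(self, name):
        self.name = name


class PluginDevice:
    """Stand-in for a device object from another backend (e.g. IreeDevice)."""

    def __init__(self, name):
        self.name = name


class Buffer:
    """Stand-in for a buffer whose copy method only accepts NativeDevice."""

    def __init__(self, device):
        self._device = device

    def device(self):
        return self._device

    def copy_to_device(self, d):
        # Mimics the pybind11 signature check: any other device type is rejected.
        if not isinstance(d, NativeDevice):
            raise TypeError("incompatible function arguments")
        return Buffer(d)


def shard_arg(bufs, devices):
    # Same shape as the comprehension in the traceback: keep a buffer that is
    # already on the target device, otherwise copy it to that device.
    return [buf if buf.device() == d else buf.copy_to_device(d)
            for buf, d in zip(bufs, devices)]


cpu0 = NativeDevice("cpu:0")
cpu1 = NativeDevice("cpu:1")
iree0 = PluginDevice("iree:0")

moved = shard_arg([Buffer(cpu0)], [cpu1])
print(moved[0].device().name)  # cpu:1

try:
    shard_arg([Buffer(cpu0)], [iree0])
except TypeError as e:
    print("TypeError:", e)  # TypeError: incompatible function arguments
```

In this model, copying between two devices of the accepted type works, while a target device from another backend fails exactly at the `copy_to_device` call, which matches where the traceback stops.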

I didn't know that IREE could work as a backend for "jax.jit" until I read the following code:

https://github.com/iree-org/iree-jax/blob/26006ef5842a604e28ea71e65e9224ad20f028e9/models/gpt2/test_jax.py#L56-L57

I'm not sure whether the error above occurred because I installed JAX and jaxlib with pip instead of building them from source.
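When asking whether pip-installed wheels are the cause, it helps to report exactly which `jax` and `jaxlib` versions are installed. The following is a minimal sketch using the standard library's `importlib.metadata`. The helper name `installed_versions` is hypothetical, not part of JAX or iree-jax.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Return {distribution name: version string, or None if not installed}."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            # The distribution is not installed in the current environment.
            result[name] = None
    return result


if __name__ == "__main__":
    # PyPI distribution names; adjust the list for your environment.
    for name, ver in installed_versions(["jax", "jaxlib"]).items():
        print(f"{name}: {ver if ver is not None else 'not installed'}")
```

Pasting this output into an issue makes it easier to tell a version mismatch apart from a bug in the backend itself.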

wangkuiyi commented 1 year ago

I can run python/export.py to generate the MLIR and .vmfb files. Do I need to worry about the error above if all I want is to run the .vmfb with the IREE runtime?

jpienaar commented 1 year ago

No. The error above was due to a missing transfer/API change. I think that has been resolved, but either way it doesn't affect anything on the VMFB side.