iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

The PJRT plugin's API version is lower than required by `jax>0.4.20` #19223

Open PragmaTwice opened 5 hours ago

PragmaTwice commented 5 hours ago

What happened?

(jax) ➜  iree JAX_PLATFORMS=iree_cpu python -c "import jax; a = jax.numpy.asarray([1, 2, 3, 4, 5, 6, 7, 8, 9]); print(a + a);"
WARNING:jax._src.xla_bridge:Platform 'iree_cpu' is experimental and not all JAX functionality may be correctly supported!
[IREE-PJRT] DEBUG: Using IREE compiler binary: /home/twice/miniconda3/envs/jax/lib/python3.12/site-packages/iree/compiler/_mlir_libs/libIREECompiler.so
[IREE-PJRT] DEBUG: Compiler Version: 3.0.0rc20241118 @ 29c451b00ecc9f9e5466e9d1079e0d69147da700 (API version 1.4)
[IREE-PJRT] DEBUG: Partitioner was not enabled. The partitioner can be enabled by setting the 'PARTITIONER_LIB_PATH' config var ('IREE_PJRT_PARTITIONER_LIB_PATH' env var)
[IREE-PJRT] DEBUG: CPU driver created
F1120 14:20:14.636158   43187 pjrt_c_api_helpers.cc:241] Unexpected error status /home/twice/projects/iree/integrations/pjrt/src/iree_pjrt/common/stubs.inc:5: UNIMPLEMENTED; PJRT_Plugin_Attributes
*** Check failure stack trace: ***
    @     0x7fa413cd6fa4  absl::lts_20230802::log_internal::LogMessage::SendToLog()
    @     0x7fa413cd6ea4  absl::lts_20230802::log_internal::LogMessage::Flush()
    @     0x7fa413cd7349  absl::lts_20230802::log_internal::LogMessageFatal::~LogMessageFatal()
    @     0x7fa40c7d3898  pjrt::LogFatalIfPjrtError()
    @     0x7fa40c7b74ab  xla::PjRtCApiClient::InitAttributes()
    @     0x7fa40c7b5f25  xla::PjRtCApiClient::PjRtCApiClient()
    @     0x7fa40c7c88f8  xla::GetCApiClient()
    @     0x7fa40c678fb6  nanobind::detail::func_create<>()::{lambda()#1}::__invoke()
    @     0x7fa4124c41d1  nanobind::detail::nb_func_vectorcall_complex()
    @           0x53e131  PyObject_Vectorcall
[1]    43187 IOT instruction (core dumped)  JAX_PLATFORMS=iree_cpu python -c
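In the trace above, `xla::PjRtCApiClient::InitAttributes()` calls `PJRT_Plugin_Attributes`. The IREE plugin only stubs that entry point out (`stubs.inc:5: UNIMPLEMENTED`), so `pjrt::LogFatalIfPjrtError()` aborts the process. The title reports this for JAX releases newer than 0.4.20. Below is a minimal sketch of a hypothetical helper (`is_reported_affected` is not part of JAX or IREE) that checks whether a JAX version string falls in that reported range. It assumes the 0.4.20 threshold from the title is accurate and that version strings begin with dotted integers. It is a simplified parser, not full PEP 440; `packaging.version` would be the more robust choice.

```python
import re

# Threshold taken from this issue's title (`jax>0.4.20`); an assumption, not verified here.
REPORTED_LAST_UNAFFECTED = (0, 4, 20)


def parse_release(version: str) -> tuple[int, ...]:
    """Return the leading dotted-integer part, e.g. '0.4.36.dev1' -> (0, 4, 36)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def is_reported_affected(version: str) -> bool:
    """Hypothetical helper: True if `version` is newer than 0.4.20, the range the title reports."""
    return parse_release(version) > REPORTED_LAST_UNAFFECTED


if __name__ == "__main__":
    for v in ("0.4.20", "0.4.35", "0.4.36.dev20241118"):
        print(v, is_reported_affected(v))
```

Comparing integer tuples avoids the string-ordering trap where "0.4.9" would sort after "0.4.20".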

I tried many JAX versions and found that:

Steps to reproduce your issue

Follow the README of the PJRT plugin:

What component(s) does this issue relate to?

Other

Version information

The latest commit on the main branch

Additional context

No response

PragmaTwice commented 2 hours ago

I'll work on it soon : )