Open PragmaTwice opened 5 hours ago
I'll work on it soon : )

What happened?

`0.38`, but in the latest version of JAX it's around `0.5x`. `PJRT_Plugin_Attributes` is not implemented in the IREE PJRT plugin, which can lead to crashes in the latest versions of JAX:

```
(jax) ➜ iree JAX_PLATFORMS=iree_cpu python -c "import jax; a = jax.numpy.asarray([1, 2, 3, 4, 5, 6, 7, 8, 9]); print(a + a);"
WARNING:jax._src.xla_bridge:Platform 'iree_cpu' is experimental and not all JAX functionality may be correctly supported!
[IREE-PJRT] DEBUG: Using IREE compiler binary: /home/twice/miniconda3/envs/jax/lib/python3.12/site-packages/iree/compiler/_mlir_libs/libIREECompiler.so
[IREE-PJRT] DEBUG: Compiler Version: 3.0.0rc20241118 @ 29c451b00ecc9f9e5466e9d1079e0d69147da700 (API version 1.4)
[IREE-PJRT] DEBUG: Partitioner was not enabled. The partitioner can be enabled by setting the 'PARTITIONER_LIB_PATH' config var ('IREE_PJRT_PARTITIONER_LIB_PATH' env var)
[IREE-PJRT] DEBUG: CPU driver created
F1120 14:20:14.636158 43187 pjrt_c_api_helpers.cc:241] Unexpected error status /home/twice/projects/iree/integrations/pjrt/src/iree_pjrt/common/stubs.inc:5: UNIMPLEMENTED; PJRT_Plugin_Attributes
*** Check failure stack trace: ***
    @ 0x7fa413cd6fa4 absl::lts_20230802::log_internal::LogMessage::SendToLog()
    @ 0x7fa413cd6ea4 absl::lts_20230802::log_internal::LogMessage::Flush()
    @ 0x7fa413cd7349 absl::lts_20230802::log_internal::LogMessageFatal::~LogMessageFatal()
    @ 0x7fa40c7d3898 pjrt::LogFatalIfPjrtError()
    @ 0x7fa40c7b74ab xla::PjRtCApiClient::InitAttributes()
    @ 0x7fa40c7b5f25 xla::PjRtCApiClient::PjRtCApiClient()
    @ 0x7fa40c7c88f8 xla::GetCApiClient()
    @ 0x7fa40c678fb6 nanobind::detail::func_create<>()::{lambda()#1}::__invoke()
    @ 0x7fa4124c41d1 nanobind::detail::nb_func_vectorcall_complex()
    @ 0x53e131 PyObject_Vectorcall
[1] 43187 IOT instruction (core dumped) JAX_PLATFORMS=iree_cpu python -c
```

I tried lots of JAX versions, and found that:
- `jax==0.4.20` (and maybe some versions lower than that)
- `jax>0.4.20` (from `0.4.20` to the latest `0.4.35`), which can be split into two cases:

Steps to reproduce your issue
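The crash can be triggered with the one-liner from the log above. Below is a minimal sketch of a script that wraps it. It assumes the IREE PJRT plugin is installed as described in its README, so that `JAX_PLATFORMS=iree_cpu` selects it. The helper names (`parse_version`, `version_group`, `run_repro`, `aborted`) are hypothetical.

The script runs the one-liner in a child process, so the fatal abort does not take the script down. It reads the installed JAX version with `importlib.metadata` and buckets it into the two version ranges listed above. It does not encode which range is expected to work. Detecting SIGABRT through a negative return code is POSIX-only.

```python
"""Sketch: reproduce the IREE PJRT `PJRT_Plugin_Attributes` crash for the installed JAX.

Assumes the IREE PJRT plugin is installed so that JAX_PLATFORMS=iree_cpu selects it.
All helper names below are hypothetical, not part of JAX or IREE.
"""
from __future__ import annotations

import importlib.metadata
import os
import signal
import subprocess
import sys

# The one-liner from the crash log above.
REPRO = (
    "import jax; "
    "a = jax.numpy.asarray([1, 2, 3, 4, 5, 6, 7, 8, 9]); "
    "print(a + a);"
)


def parse_version(version: str) -> tuple[int, ...]:
    """Keep the leading numeric components, e.g. '0.4.36.dev20241120' -> (0, 4, 36)."""
    parts = []
    for piece in version.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def version_group(version: str) -> str:
    """Bucket a JAX version into the two ranges compared in this issue."""
    return "jax<=0.4.20" if parse_version(version) <= (0, 4, 20) else "jax>0.4.20"


def run_repro(timeout: float = 600.0) -> subprocess.CompletedProcess:
    """Run the one-liner in a child process so a fatal abort cannot kill this script."""
    env = dict(os.environ, JAX_PLATFORMS="iree_cpu")
    return subprocess.run(
        [sys.executable, "-c", REPRO],
        env=env, capture_output=True, text=True, timeout=timeout,
    )


def aborted(proc: subprocess.CompletedProcess) -> bool:
    """True if the child was killed by SIGABRT, as the absl fatal log above does (POSIX)."""
    return proc.returncode == -signal.SIGABRT


def main() -> None:
    try:
        jax_version = importlib.metadata.version("jax")
    except importlib.metadata.PackageNotFoundError:
        print("jax is not installed")
        return
    proc = run_repro()
    print(f"jax {jax_version} ({version_group(jax_version)})")
    if aborted(proc):
        print("aborted with SIGABRT; last lines of stderr:")
        print("\n".join(proc.stderr.splitlines()[-5:]))
    else:
        print(f"exit code {proc.returncode}")
        print(proc.stdout or proc.stderr)


if __name__ == "__main__":
    main()
```

Running it once per installed JAX version gives one line per version to compare. It only reproduces the crash and does not change the plugin.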
Follow the README of the PJRT plugin, using the `iree_cpu` backend.

What component(s) does this issue relate to?
Other
Version information
The latest commit in the `main` branch
Additional context
No response