lausena opened this issue 2 years ago.
Routing to @jpienaar, given that this is missing a lowering for mhlo.triangular_solve.
Passed on to the OpenXLA folks to take a look.
Unfortunately, this is not among their near-term priorities; we are still discussing it with one group there. Otherwise, we'll take a look ourselves. Is the above representative of the kind of model you are looking at? (And, purely out of curiosity, what application domain are you looking at?)
cc @hawkinsp. This issue is affecting us as well.
Thanks, @jpienaar, for taking a look.
To answer "what application domain are you looking at": this was used in a data-processing pipeline for a vision application.
Thanks for flagging; this was not yet prioritized. We may start by simply offloading some of these related HLOs to library calls. Which device backends are of most interest, so that we can order the effort?
@kulinseth Can you clarify something for me: is this precise issue actually the one you have?
This particular issue says that openxla/iree doesn't have a lowering for mhlo.triangular_solve, but are you actually using IREE at the moment? Do you (independently) need a lowering for mhlo.triangular_solve as well? Or does mhlo.triangular_solve form part of a larger computation (e.g., matrix inversion) for which you have a direct lowering? Or something else?
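A sketch of how matrix inversion can end up depending on triangular solves, assuming the usual LU route: factor A with partial pivoting as PA = LU, then, for each column of the identity, run one forward substitution (with L) and one back substitution (with U). Our understanding, which is an assumption here rather than something verified against the JAX source, is that jnp.linalg.inv follows this pattern by solving A X = I through an LU factorization followed by triangular solves; that would explain why a linalg.inv call surfaces as a missing mhlo.triangular_solve lowering. The helper names lu_decompose and inverse are hypothetical, and this is plain Python, not JAX's or IREE's implementation.

```python
# Illustrative sketch only: lu_decompose and inverse are hypothetical helpers,
# not a JAX or IREE API. Matrices are plain lists of rows.


def lu_decompose(A):
    """LU factorization with partial pivoting.

    Returns (perm, L, U) such that row i of L @ U equals row perm[i] of A,
    with L unit lower triangular and U upper triangular.
    """
    n = len(A)
    U = [[float(v) for v in row] for row in A]
    L = [[0.0] * n for _ in range(n)]
    perm = list(range(n))
    for k in range(n):
        # Partial pivoting: move the row with the largest |U[r][k]| into position k.
        p = max(range(k, n), key=lambda r: abs(U[r][k]))
        if U[p][k] == 0.0:
            raise ValueError("matrix is singular")
        U[k], U[p] = U[p], U[k]
        L[k], L[p] = L[p], L[k]
        perm[k], perm[p] = perm[p], perm[k]
        L[k][k] = 1.0
        # Eliminate entries below the pivot, recording the multipliers in L.
        for r in range(k + 1, n):
            factor = U[r][k] / U[k][k]
            L[r][k] = factor
            for c in range(k, n):
                U[r][c] -= factor * U[k][c]
    return perm, L, U


def inverse(A):
    """Invert A by solving A X = I one column at a time using P A = L U."""
    n = len(A)
    perm, L, U = lu_decompose(A)
    cols = []
    for j in range(n):
        # Column j of P I: entry i is 1 exactly when original row perm[i] is row j.
        b = [1.0 if perm[i] == j else 0.0 for i in range(n)]
        # Triangular solve 1 (forward substitution): L y = b, L has a unit diagonal.
        y = [0.0] * n
        for i in range(n):
            y[i] = b[i] - sum(L[i][k] * y[k] for k in range(i))
        # Triangular solve 2 (back substitution): U x = y.
        x = [0.0] * n
        for i in reversed(range(n)):
            x[i] = (y[i] - sum(U[i][k] * x[k] for k in range(i + 1, n))) / U[i][i]
        cols.append(x)
    # cols[j] is column j of inv(A); transpose into a list of rows.
    return [[cols[j][i] for j in range(n)] for i in range(n)]


print(inverse([[4, 7], [2, 6]]))  # approximately [[0.6, -0.7], [-0.2, 0.4]]
```

Partial pivoting is used because an unpivoted LU fails on matrices such as [[0, 1], [1, 0]], whose leading entry is zero even though the matrix is invertible.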
(For IREE, the quickest short-term way to implement this is almost certainly to call a library implementation of a triangular-solve kernel, although "which one" is going to be platform-specific. Happily, this particular kernel is one of the standard BLAS kernels, and it probably has at least a CPU implementation on most platforms.)
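For reference, the operation such a kernel performs (the BLAS triangular-solve routines: trsv for a single right-hand side, trsm for several) is forward substitution for a lower-triangular system and back substitution for an upper-triangular one. The minimal pure-Python sketch below shows the single-right-hand-side case; the function names are hypothetical, and it is meant to illustrate the computation, not to stand in for an optimized library kernel.

```python
# Illustrative sketch only: hypothetical helpers showing what a triangular
# solve computes for one right-hand-side vector. Matrices are lists of rows.


def solve_lower_triangular(L, b):
    """Solve L x = b for x, where L is a square lower-triangular matrix."""
    n = len(L)
    x = [0.0] * n
    for i in range(n):
        if L[i][i] == 0:
            raise ZeroDivisionError("singular triangular matrix")
        # Forward substitution: x[0..i-1] are already known, so subtract their contribution.
        s = b[i] - sum(L[i][j] * x[j] for j in range(i))
        x[i] = s / L[i][i]
    return x


def solve_upper_triangular(U, b):
    """Solve U x = b for x, where U is a square upper-triangular matrix."""
    n = len(U)
    x = [0.0] * n
    for i in reversed(range(n)):
        if U[i][i] == 0:
            raise ZeroDivisionError("singular triangular matrix")
        # Back substitution: x[i+1..n-1] are already known.
        s = b[i] - sum(U[i][j] * x[j] for j in range(i + 1, n))
        x[i] = s / U[i][i]
    return x


print(solve_lower_triangular([[2, 0], [1, 4]], [4, 10]))  # [2.0, 2.0]
print(solve_upper_triangular([[2, 1], [0, 4]], [5, 8]))   # [1.5, 2.0]
```

Each solve costs O(n^2) operations per right-hand side, which is why offloading it to a tuned BLAS routine is attractive compared with writing a new code-generation path.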
What happened?
Using JAX + IREE on a Mac M1 Max, calling linalg.inv() on a JAX DeviceArray fails.
Steps to reproduce your issue:
Version information
M1 Max | Darwin 22.1.0
JAX compiled with the following extra_args:
extra_args += ["--iree-llvm-target-triple=arm64-apple-darwin22.1.0", "--iree-flow-demote-i64-to-i32", "--iree-vulkan-target-triple=m1-moltenvk-macos", "--iree-llvm-target-cpu-features=host", "--iree-mhlo-demote-i64-to-i32=true"]
IREE compiler/runtime built from: https://github.com/iree-org/iree/releases/tag/candidate-20221102.315
Additional context
A previous, now-resolved issue dealing with JAX methods can be found here: https://github.com/iree-org/iree/issues/10938
As a side note, I am trying to run the following project and keep running into issues with JAX + IREE as I step through it; currently the jnp.linalg.inv method is the culprit. The project is https://github.com/google-research/google-research/tree/master/light_field_neural_rendering and the failing call is specifically at https://github.com/google-research/google-research/blob/master/light_field_neural_rendering/src/models/projector.py#L214 (the simple example given above also errors).