iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

[JAX + IREE] - failed to legalize operation 'mhlo.triangular_solve' #11018


lausena commented 2 years ago

What happened?

Using jax + iree on a Mac M1 Max, a call to jnp.linalg.inv() on a jax DeviceArray fails.

Steps to reproduce your issue:

>>> import jax.numpy as jnp
>>>
>>> a=jnp.array([[1., 2.], [3., 4.]])
>>> a
DeviceArray([[1., 2.],
             [3., 4.]], dtype=float32)
>>> jnp.linalg.inv(a)
......
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/lausena/miniforge3/envs/tfpy310/lib/python3.10/site-packages/jax/_src/iree.py", line 186, in compile
    iree_binary = iree.compiler.compile_str(
  File "/Users/lausena/miniforge3/envs/tfpy310/lib/python3.10/site-packages/iree/compiler/tools/core.py", line 278, in compile_str
    result = invoke_immediate(cl, immediate_input=input_bytes)
  File "/Users/lausena/miniforge3/envs/tfpy310/lib/python3.10/site-packages/iree/compiler/tools/binaries.py", line 196, in invoke_immediate
    raise CompilerToolError(process)
iree.compiler.tools.binaries.CompilerToolError: Error invoking IREE compiler tool iree-compile
Diagnostics:
<stdin>:1:0: error: failed to legalize operation 'mhlo.triangular_solve' that was explicitly marked illegal
<stdin>:1:0: note: called from
compilation failed

What component(s) does this issue relate to?

No response

Version information

M1 Max | Darwin 22.1.0

JAX was compiled with the following extra_args:
extra_args += ["--iree-llvm-target-triple=arm64-apple-darwin22.1.0", "--iree-flow-demote-i64-to-i32", "--iree-vulkan-target-triple=m1-moltenvk-macos", "--iree-llvm-target-cpu-features=host", "--iree-mhlo-demote-i64-to-i32=true"]

IREE compiler/runtime built from: https://github.com/iree-org/iree/releases/tag/candidate-20221102.315

Additional context

A previous, now-resolved issue dealing with JAX methods can be found here: https://github.com/iree-org/iree/issues/10938

As a side note: I am trying to run the following project and keep hitting issues with jax+iree as I step through it; currently the jnp.linalg.inv method is the culprit. https://github.com/google-research/google-research/tree/master/light_field_neural_rendering Specifically, https://github.com/google-research/google-research/blob/master/light_field_neural_rendering/src/models/projector.py#L214 (but the error also occurs with the simple example given above).
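In case it helps others hitting the same legalization error: for small fixed-size matrices like the 2x2 example above, one possible workaround (until mhlo.triangular_solve has a lowering) is to compute the inverse in closed form via the adjugate, which avoids the LU/triangular-solve path that jnp.linalg.inv takes. A minimal sketch — `inv2x2` is a hypothetical helper of my own, not part of JAX:

```python
import jax.numpy as jnp

def inv2x2(a):
    # Hypothetical workaround: closed-form inverse of a 2x2 matrix via the
    # adjugate, using only elementwise ops that lower without triangular_solve.
    det = a[0, 0] * a[1, 1] - a[0, 1] * a[1, 0]
    adj = jnp.array([[ a[1, 1], -a[0, 1]],
                     [-a[1, 0],  a[0, 0]]])
    return adj / det

a = jnp.array([[1., 2.], [3., 4.]])
print(inv2x2(a))
```

This obviously does not generalize to the arbitrary-size inverses the linked project needs, but it unblocks simple cases.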

antiagainst commented 2 years ago

Routing to @jpienaar given that this is missing lowering for mhlo.triangular_solve.

rsuderman commented 2 years ago

Passed onto the OpenXLA folks to take a look.

jpienaar commented 1 year ago

This is unfortunately not among their near-term priorities; we're still discussing with one group there. Otherwise we'll take a look. Is the above representative of the kind of model you are looking at? (And, purely out of curiosity, what application domain are you working in?)

kulinseth commented 1 year ago

cc. @hawkinsp . This is the issue affecting us as well.

Thanks @jpienaar for taking a look.

(And, purely out of curiosity, what application domain are you working in?)

This was used in a data processing pipeline for a vision application.

jpienaar commented 1 year ago

Thanks for flagging, this was not yet prioritized. We may start by just offloading to library calls for some of these HLOs. What device backends are of most interest, to help order the effort?

hawkinsp commented 1 year ago

@kulinseth Clarify something for me: is this precise issue actually the issue that you have?

This particular issue says that openxla/iree doesn't have a lowering for mhlo.triangular_solve, but are you actually using IREE at the moment? Is it that you (independently) need a lowering for mhlo.triangular_solve also? Or is it that mhlo.triangular_solve forms part of a larger computation for which you have a direct lowering (e.g., matrix inversion)? Or something else?

(For IREE, the quickest way to implement this is almost certainly to call a library implementation of a triangular solve kernel in the short term, although "which one" is going to be platform specific. Happily this particular kernel is one of the standard BLAS kernels, and it probably has at least a CPU implementation on most platforms.)
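For reference, the kernel in question is mathematically simple: solving L x = b for lower-triangular L is plain forward substitution, which is what a BLAS-style library call (the trsv/trsm family mentioned above) would compute. A NumPy sketch of the math only — this is not IREE's implementation, and a real offload would of course call the vendor kernel rather than a Python loop:

```python
import numpy as np

def lower_triangular_solve(L, b):
    # Forward substitution: solve L x = b for lower-triangular L.
    # Illustrates the computation a BLAS trsv-style library call performs.
    n = L.shape[0]
    x = np.zeros_like(b, dtype=float)
    for i in range(n):
        x[i] = (b[i] - L[i, :i] @ x[:i]) / L[i, i]
    return x

L = np.array([[2., 0.],
              [1., 3.]])
b = np.array([2., 7.])
print(lower_triangular_solve(L, b))  # → [1. 2.]
```

The upper-triangular case is the same loop run backwards (back substitution).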