iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0

failed to legalize operation 'mhlo.cholesky' that was explicitly marked illegal #10816

Open JDAI27 opened 1 year ago

JDAI27 commented 1 year ago

What happened?

Hi,

I followed this post (https://github.com/google/jax/issues/8074#issuecomment-1148012748) to install jax and iree, then tried to run this example (https://github.com/google/jax/blob/fd2f590b3ba404f60b16fd4d58339194de1421c1/examples/gaussian_process_regression.py) with JAX_PLATFORMS=iree JAX_IREE_BACKEND=vulkan python gaussian_process_regression.py.

I then got this error:

Traceback (most recent call last):
  File "/Users/daydream/jax_installer/gaussian_process_regression.py", line 114, in <module>
    app.run(main)
  File "/Users/daydream/anaconda3/envs/jax_gpu/lib/python3.9/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/Users/daydream/anaconda3/envs/jax_gpu/lib/python3.9/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/Users/daydream/jax_installer/gaussian_process_regression.py", line 99, in main
    params, momentums, scales = train_step(params, momentums, scales, x, y)
  File "/Users/daydream/jax_installer/gaussian_process_regression.py", line 85, in train_step
    grads = grad_fun(params, x, y)
  File "/Users/daydream/jax_installer/jax/jax/_src/iree.py", line 191, in compile
    iree_binary = iree.compiler.compile_str(
  File "/Users/daydream/anaconda3/envs/jax_gpu/lib/python3.9/site-packages/iree/compiler/tools/core.py", line 278, in compile_str
    result = invoke_immediate(cl, immediate_input=input_bytes)
  File "/Users/daydream/anaconda3/envs/jax_gpu/lib/python3.9/site-packages/iree/compiler/tools/binaries.py", line 196, in invoke_immediate
    raise CompilerToolError(process)
iree.compiler.tools.binaries.CompilerToolError: Error invoking IREE compiler tool iree-compile
Diagnostics:
/Users/daydream/jax_installer/gaussian_process_regression.py:52:0: error: failed to legalize operation 'mhlo.cholesky' that was explicitly marked illegal
/Users/daydream/jax_installer/gaussian_process_regression.py:52:0: note: called from
compilation failed

Invoked with:
 iree-compile /Users/daydream/anaconda3/envs/jax_gpu/lib/python3.9/site-packages/iree/compiler/tools/../_mlir_libs/iree-compile - --iree-input-type=mhlo --iree-vm-bytecode-module-output-format=flatbuffer-binary --iree-hal-target-backends=vulkan-spirv --iree-llvm-embedded-linker-path=/Users/daydream/anaconda3/envs/jax_gpu/lib/python3.9/site-packages/iree/compiler/tools/../_mlir_libs/iree-lld --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-llvm-target-triple=arm64-apple-darwin21.5.0 --iree-flow-demote-i64-to-i32 --iree-vulkan-target-triple=m1-moltenvk-macos --iree-llvm-target-cpu-features=host --iree-mhlo-demote-i64-to-i32=false

rsuderman commented 1 year ago

mhlo.cholesky probably does not have a lowering to the linalg dialect. We should forward this issue to the mhlo team to see about adding a corresponding lowering.

stellaraccident commented 1 year ago

I suspect we will want to lower this to a library call of some kind.

benvanik commented 1 year ago

Nah, it's not that bad - something in linalg_ext maybe: https://rosettacode.org/wiki/Cholesky_decomposition - there are some CUDA implementations lying around that look like something we could generate: https://github.com/bhattmansi/Implementation-of-Cholesky-Decomposition-in-GPU-using-CUDA/blob/master/parallel_chol.cu#L18-L91
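
To make the "not that bad" point concrete, here is a minimal pure-Python sketch of a column-by-column Cholesky factorization. It is not IREE code and not the linked CUDA implementation; the function name cholesky_lower and the nested-list matrix layout are assumptions chosen for illustration. It only shows the loop structure a generated kernel would need to reproduce.

```python
import math


def cholesky_lower(a):
    """Return lower-triangular L such that L * L^T == a.

    `a` is a symmetric positive-definite matrix given as a list of rows.
    Hypothetical helper for illustration only; not part of IREE or JAX.
    """
    n = len(a)
    l = [[0.0] * n for _ in range(n)]
    for j in range(n):
        # Diagonal entry: uses only already-computed entries of row j.
        s = a[j][j] - sum(l[j][k] ** 2 for k in range(j))
        if s <= 0.0:
            raise ValueError("matrix is not positive definite")
        l[j][j] = math.sqrt(s)
        # Entries below the diagonal in column j do not depend on each
        # other, so this loop is the part that could run in parallel.
        for i in range(j + 1, n):
            dot = sum(l[i][k] * l[j][k] for k in range(j))
            l[i][j] = (a[i][j] - dot) / l[j][j]
    return l


if __name__ == "__main__":
    a = [[4, 12, -16], [12, 37, -43], [-16, -43, 98]]
    for row in cholesky_lower(a):
        print(row)
```

The outer loop over columns is inherently sequential, while the inner loop over rows is independent per row, which is presumably what makes the GPU versions practical.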

rsuderman commented 1 year ago

Some intermediate progress based on: https://www.quantstart.com/articles/Cholesky-Decomposition-in-Python-and-NumPy/

Available at: https://github.com/openxla/iree/commit/670d06b56c03771c6cc5a1847eda6589ba919cf8

For whoever works on this next.
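
For anyone picking this up, a small reference check may help validate a candidate lowering against known inputs. The sketch below is an assumption about how one might test it, not something taken from the linked commit: the helper name is_cholesky_factor and the tolerance are hypothetical. It confirms that a result is lower triangular and that multiplying it by its transpose reproduces the input.

```python
def is_cholesky_factor(l, a, tol=1e-9):
    """Check that `l` is lower triangular and l * l^T matches `a`.

    Both arguments are square matrices given as lists of rows.
    Hypothetical test helper; the name and tolerance are assumptions.
    """
    n = len(a)
    for i in range(n):
        for j in range(n):
            # Everything above the diagonal must be (numerically) zero.
            if j > i and abs(l[i][j]) > tol:
                return False
            # Entry (i, j) of l * l^T is the dot product of rows i and j.
            recon = sum(l[i][k] * l[j][k] for k in range(n))
            if abs(recon - a[i][j]) > tol:
                return False
    return True


if __name__ == "__main__":
    a = [[4.0, 12.0, -16.0], [12.0, 37.0, -43.0], [-16.0, -43.0, 98.0]]
    l = [[2.0, 0.0, 0.0], [6.0, 1.0, 0.0], [-8.0, 5.0, 3.0]]
    print(is_cholesky_factor(l, a))
```

Checking the reconstruction rather than comparing against a fixed expected factor keeps the test independent of which algorithm variant the lowering uses.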