asraa opened this issue 5 months ago
The internal code that runs the StableHLO to standard MLIR lowering does:
There is also stablehlo-legalize-to-tosa, although I don't know whether it will be fully supported. The downside here is that our TOSA lowering path requires bufferization.
Then there is also stablehlo-legalize-to-linalg.
IREE also has a number of passes...
Hi Asra. OpenXLA has some good tutorials for lowering JAX/PyTorch/TensorFlow to StableHLO.
For anyone who prefers video to text, this video from April 2024 covers the same information.
Thank you! Revisiting this now after the break. There are really two parts to this:
(1) Lowering from JAX/PyTorch/TF to StableHLO (watching the video right now!)
(2) Lowering StableHLO to standard MLIR (tensor + affine + arith) in HEIR
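For part (1), the JAX side can be sketched with JAX's ahead-of-time lowering API: jax.jit(f).lower(*args) traces a function for concrete argument shapes and dtypes, and as_text(dialect="stablehlo") prints the resulting module. This is a minimal sketch that assumes a recent JAX release supporting the "stablehlo" dialect option; dense_relu is a hypothetical toy model, not anything from HEIR.

```python
import jax
import jax.numpy as jnp


def dense_relu(x, w, b):
    # Hypothetical toy model: one dense layer followed by ReLU.
    return jax.nn.relu(jnp.dot(x, w) + b)


# Example inputs; only their shapes and dtypes matter for lowering.
x = jnp.ones((1, 4), dtype=jnp.float32)
w = jnp.ones((4, 2), dtype=jnp.float32)
b = jnp.zeros((2,), dtype=jnp.float32)

# Trace dense_relu for these argument types and lower it (no compilation yet).
lowered = jax.jit(dense_relu).lower(x, w, b)

# Print the lowered module as StableHLO MLIR text.
stablehlo_text = lowered.as_text(dialect="stablehlo")
print(stablehlo_text)
```

The printed module is the kind of StableHLO input that part (2) would then lower to standard MLIR. The exact op spellings (for example, whether the matmul appears as stablehlo.dot_general) can vary between JAX versions.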
I'll explore the TF -> StableHLO path (with the StableHLO quantizer) first so that I can get some tests and really focus on (2).
I've recently learned about StableHLO and have become pretty convinced that it should be a frontend to HEIR:
StableHLO is an open-source first project aiming to be a standard inside and outside of Google. It aims to be a portability layer between ML frameworks and ML compilers and is currently used by TensorFlow, JAX, PyTorch and XLA, IREE, and more.
StableHLO has lowerings to standard MLIR that avoid bufferization and preserve the original tensors. This would create a pathway from high-level ML programs to RLWE schemes that use tensor-based types and passes.
This would also enable a frontend for quantized PyTorch models (for example, Zama takes QAT PyTorch models and ingests them into concrete-ml through the ONNX format).
Support for quantizing TensorFlow models using StableHLO quantization
Support for qKeras models through a qKeras compilation to HLO. HLO has parity guarantees with StableHLO. qKeras offers full integer quantization, so it would let us quantize models more efficiently than our existing TensorFlow Lite quantization path.
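As a rough illustration of what "full integer quantization" means for weights, here is a minimal Python sketch of generic per-tensor affine int8 quantization, where each float x maps to q = clamp(round(x / scale) + zero_point). The helper names (affine_qparams, quantize, dequantize) are hypothetical, and this is not how qKeras or the StableHLO quantizer actually choose their parameters; it only shows the arithmetic that integer-only schemes build on.

```python
# Hypothetical helpers illustrating generic per-tensor affine (asymmetric)
# int8 quantization; not the qKeras or StableHLO quantizer implementation.


def affine_qparams(min_val: float, max_val: float, num_bits: int = 8):
    """Return (scale, zero_point, qmin, qmax) for a signed num_bits range."""
    qmin = -(2 ** (num_bits - 1))  # -128 for int8
    qmax = 2 ** (num_bits - 1) - 1  # 127 for int8
    # Widen the float range to include 0.0 so that zero is exactly representable.
    min_val = min(min_val, 0.0)
    max_val = max(max_val, 0.0)
    scale = (max_val - min_val) / (qmax - qmin)
    if scale == 0.0:  # all-zero tensor: any positive scale works
        scale = 1.0
    zero_point = int(round(qmin - min_val / scale))
    zero_point = max(qmin, min(qmax, zero_point))
    return scale, zero_point, qmin, qmax


def quantize(values, scale, zero_point, qmin, qmax):
    """Map floats to integers: q = clamp(round(x / scale) + zero_point)."""
    return [max(qmin, min(qmax, int(round(v / scale)) + zero_point)) for v in values]


def dequantize(qvalues, scale, zero_point):
    """Map integers back to approximate floats: x ~= scale * (q - zero_point)."""
    return [scale * (q - zero_point) for q in qvalues]


weights = [-1.0, -0.5, 0.0, 0.25, 1.0]
scale, zp, qmin, qmax = affine_qparams(min(weights), max(weights))
q = quantize(weights, scale, zp, qmin, qmax)
print(q)
print(dequantize(q, scale, zp))
```

Because 0.0 maps exactly to zero_point, zeros stay exact in the integer domain; the other values round-trip to within roughly one quantization step (scale).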
Will this replace TOSA?
Probably not, and it would be nice to keep both representations.
StableHLO to standard MLIR
I have some internal code that lowers StableHLO to standard MLIR (using func, affine loops, tensor, and arith, notably not memref). Some of it currently uses passes from TensorFlow's XLA compiler, so I'll attach a PR that adds that dependency, and perhaps create a standalone tool depending on feedback.