google / heir

A compiler for homomorphic encryption
https://heir.dev/
Apache License 2.0

feat: StableHLO as a front end for HEIR #738

Open: asraa opened this issue 3 months ago

asraa commented 3 months ago

I've recently learned about StableHLO and have become pretty convinced that it should be a frontend to HEIR:

  1. StableHLO is an open-source-first project that aims to be a standard both inside and outside Google. It is designed as a portability layer between ML frameworks and ML compilers, and is currently used by frameworks such as TensorFlow, JAX, and PyTorch and by compilers such as XLA and IREE, among others.

  2. StableHLO can be lowered to standard MLIR without bufferization, so the original tensors are preserved. This would create a pathway from high-level ML programs to RLWE schemes, whose types and passes are built on tensor types.

  3. This would also enable a frontend for quantized PyTorch models (for comparison, Zama takes quantization-aware-trained (QAT) PyTorch models and ingests them into concrete-ml through the ONNX format).

  4. It would add support for quantizing TensorFlow models using StableHLO quantization.

  5. It would add support for qKeras models, since qKeras models can be compiled to HLO, and HLO has parity guarantees with StableHLO. qKeras offers full integer quantization, so it would let us quantize models more efficiently than our existing use of TensorFlow Lite quantization.

Will this replace TOSA?

Probably not, and it would be nice to keep both representations.

StableHLO to standard MLIR

I have some internal code that lowers StableHLO to standard MLIR (func, affine loops, tensor, and arith; notably not memref). Some of it currently relies on passes from TensorFlow's XLA compiler, so I'll attach a PR that adds that dependency and, depending on feedback, perhaps create a standalone tool.
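To make "notably not memref" concrete: with tensor (value) semantics, a loop never writes into a buffer in place. Each iteration produces a new tensor value that is carried into the next iteration, which in MLIR is roughly an affine.for with iter_args plus tensor.extract and tensor.insert. The Python below is only an analogy of that style, not HEIR or XLA code; immutable tuples stand in for tensors, and insert and elementwise_loop are hypothetical helpers made up for illustration.

```python
from typing import Callable, Tuple

# An immutable 1-D "tensor": any update yields a new value (value semantics).
Tensor = Tuple[float, ...]


def insert(t: Tensor, value: float, index: int) -> Tensor:
    """Hypothetical analogue of tensor.insert: returns a new tensor, t is unchanged."""
    return t[:index] + (value,) + t[index + 1:]


def elementwise_loop(a: Tensor, b: Tensor,
                     op: Callable[[float, float], float]) -> Tensor:
    """Hypothetical analogue of an affine.for over [0, n) with the result as a loop-carried value."""
    if len(a) != len(b):
        raise ValueError("operands must have the same length")
    # Initial loop-carried value (a stand-in for tensor.empty; zeros are used here).
    acc: Tensor = tuple(0.0 for _ in a)
    for i in range(len(a)):
        # Read both operands (tensor.extract), apply the scalar op (arith),
        # and rebind acc to the updated tensor (tensor.insert, then yield).
        acc = insert(acc, op(a[i], b[i]), i)
    return acc


x = (1.0, 2.0, 3.0)
y = (10.0, 20.0, 30.0)
z = elementwise_loop(x, y, lambda p, q: p + q)
print(z)  # (11.0, 22.0, 33.0)
print(x)  # (1.0, 2.0, 3.0): the input tensor was never mutated
```

With memref semantics the same loop would instead store into a preallocated buffer, which is the bufferization step this path avoids.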

asraa commented 3 months ago

The internal code that runs the StableHLO-to-standard-MLIR lowering does the following:

  1. stablehlo-legalize-to-hlo
  2. Create an XLA HLO module.
  3. Partition the computation using XLA utilities.
  4. Convert the subgraphs to MLIR functions in standard MLIR. This is where the (internal) lowerings from linalg to affine loops over tensors live.
  5. Simplify some bound checks and canonicalize.

There is also stablehlo-legalize-to-tosa, although I don't know whether it will be fully supported. The downside there is that our TOSA lowering path requires bufferization.

There is also stablehlo-legalize-to-linalg.

IREE also has a number of passes...

johnmatter commented 3 months ago

Hi Asra. OpenXLA has some good tutorials on lowering JAX, PyTorch, and TensorFlow to StableHLO.

For anyone who prefers video to text, this video from April 2024 covers the same information.
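For a quick taste of the JAX route to StableHLO, here is a minimal sketch. It assumes a recent JAX release, in which jax.jit(f).lower(...) returns a Lowered object whose as_text(dialect="stablehlo") prints the traced program as an MLIR module in the StableHLO dialect; the function dense_relu and the shapes are made up for the example.

```python
import jax
import jax.numpy as jnp


def dense_relu(x, w, b):
    # A tiny dense layer: matrix multiply, bias add, then ReLU.
    return jnp.maximum(jnp.dot(x, w) + b, 0.0)


x = jnp.ones((1, 4), dtype=jnp.float32)
w = jnp.ones((4, 2), dtype=jnp.float32)
b = jnp.zeros((2,), dtype=jnp.float32)

# Trace and lower (without compiling) to an MLIR module in the StableHLO dialect.
lowered = jax.jit(dense_relu).lower(x, w, b)
stablehlo_text = lowered.as_text(dialect="stablehlo")
print(stablehlo_text)
```

The printed module is a func.func built from stablehlo.* ops on tensor types, which is the kind of input a StableHLO-to-standard-MLIR lowering would start from.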

asraa commented 2 months ago

Thank you! Revisiting this now after the break. There are really two parts to this:

(1) Lowering from JAX/PyTorch/TF to StableHLO (watching the video right now!)
(2) Lowering StableHLO to standard MLIR (tensor+affine+arith) in HEIR

I'll explore the TF -> StableHLO path (with the StableHLO quantizer) first, so I can get some tests in place and really focus on (2).