openxla / xla

A machine learning compiler for GPUs, CPUs, and ML accelerators
Apache License 2.0

MHLO Extraction from XLA Compiler #12847

Open GrandChariot opened 3 months ago

GrandChariot commented 3 months ago

Hello,

I have been exploring the XLA project recently and have a few questions about accessing MHLO from the XLA compiler. XLA offers a broad array of optimization passes, and I would like to incorporate them into my backend compiler, which targets specialized hardware. Since my compiler is built on MLIR, being able to obtain MHLO from the XLA compiler would be very useful. However, the XLA compiler does not appear to expose an exit point for this.

As such, I have two questions:

  1. Is there currently a way to access intermediate results in the form of MHLO?
  2. Are there plans to offer an interface for obtaining results as MHLO or LMHLO?

It would be great if such plans were initiated; this would let us reuse the XLA compiler's optimization passes in our backend compiler and hardware in the future.

Best Regards, Jaeyeon Kim

cheshire commented 3 months ago
  1. LMHLO doesn't exist anymore, I believe.

  2. Technically, round-tripping back to MHLO should work. Did you try using the xla-translate tool to convert the intermediate HLO result back to MHLO?

@GleasonK should be able to provide more details.

GleasonK commented 3 months ago

Hello!

Frameworks/compilers that want an MLIR interface to XLA should use StableHLO instead of MHLO. This should be possible as long as the HLO passes/optimizations you use are not hardware-specific. As an added benefit, StableHLO provides an easy entrypoint from JAX/TF/PT -- see the tutorials for details.
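To illustrate the JAX entrypoint, here is a minimal sketch of getting StableHLO text out of a jitted function. The function `f` is a stand-in for your own computation; the `jax.jit(...).lower(...)` / `as_text()` APIs are standard JAX lowering APIs:

```python
import jax
import jax.numpy as jnp

def f(x):
    # Placeholder computation; substitute your own model function.
    return jnp.sin(x) * 2.0

# Lower without executing, then emit the module as StableHLO MLIR text.
lowered = jax.jit(f).lower(jnp.ones((4,), jnp.float32))
print(lowered.as_text())  # StableHLO module text, e.g. a stablehlo.sine op
```

This gives you the program before XLA's backend-specific compilation, which you could then feed into an MLIR-based compiler.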

For exporting HLO to MHLO/StableHLO: PyTorch/XLA's StableHLO export goes via HLO->MHLO->StableHLO, so it provides a good code example of how to do this programmatically: stablehlo_helper.cc

As for APIs to "access intermediate results", I'm not sure exactly what you have in mind. You can build a custom pipeline that runs the HLO passes of interest and then exports to MHLO/StableHLO (see the APIs above). Alternatively, you could dump HLO at a specific point and use xla-translate (example) to convert from HLO to MHLO, and mlir-hlo-opt (example) if you want to go on to StableHLO via the available tooling.
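Concretely, the dump-and-translate route might look like the following shell sketch. The `--xla_dump_to` flag is XLA's standard HLO dump mechanism; the exact xla-translate flag and mlir-hlo-opt pass names below are assumptions to verify against your build, and the file names are placeholders:

```shell
# 1. Dump HLO from any XLA-backed run (script name is a placeholder).
XLA_FLAGS=--xla_dump_to=/tmp/xla_dump python train.py

# 2. Convert a dumped HLO text file to MHLO (flag name assumed).
xla-translate -hlo-text-to-mlir-hlo /tmp/xla_dump/<module>.txt -o module.mhlo.mlir

# 3. Legalize MHLO to StableHLO (pass name assumed).
mlir-hlo-opt --hlo-legalize-to-stablehlo module.mhlo.mlir -o module.stablehlo.mlir
```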

GrandChariot commented 3 months ago

@GleasonK Thank you! Your answer was very helpful. I will proceed with the StableHLO converted from Pytorch/xla. I do have a few more questions, however:

  1. Pytorch/xla is exclusively for TPU, but the hardware I intend to target with my compiler is not a TPU. Will there be any optimization or compatibility issues if I use the StableHLO generated by pytorch/xla?
  2. To my understanding, the operations of MHLO and StableHLO do not correspond 1:1. Can we still leverage all the optimizations offered by the XLA compiler if we use StableHLO instead of MHLO?

cheshire commented 3 months ago

> Pytorch/xla is exclusively for TPU

It isn't, we use it for GPUs as well.