
Do I have to implement PjRtLoadedExecutable::GetHloModules when `XLA_STABLEHLO_COMPILE=1` ? #6759

Open Nullkooland opened 7 months ago

Nullkooland commented 7 months ago

❓ Questions and Help

Hi, I'm from a hardware vendor and we want to implement a PJRT plugin for our DSA accelerator. We have our own MLIR-based compiler stack and it takes StableHLO as the input IR.

I'm new to PJRT. According to its description, the PJRT API is supposed to be compiler-agnostic and should not assume that a PJRT plugin's compiler backend is XLA. However, in PyTorch/XLA's PJRT runtime, PjRtComputationClient::Compile calls PjRtLoadedExecutable::GetHloModules (which we left unimplemented in our PjRtLoadedExecutable implementation) and expects it to return valid xla::HloModule objects:

https://github.com/pytorch/xla/blob/19b83830ac4ee3a39d99abaf154f485c2399f47a/torch_xla/csrc/runtime/pjrt_computation_client.cc#L585
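For concreteness, here is roughly what our plugin does today; a minimal sketch, assuming the virtual method signature in xla/pjrt/pjrt_executable.h at the time of writing (the `DsaLoadedExecutable` class name is hypothetical):

```cpp
#include "xla/pjrt/pjrt_client.h"

// Our backend compiles StableHLO directly, so there is no xla::HloModule to
// hand back; callers that insist on one get an Unimplemented error.
class DsaLoadedExecutable : public xla::PjRtLoadedExecutable {
 public:
  absl::StatusOr<std::vector<std::shared_ptr<xla::HloModule>>> GetHloModules()
      const override {
    return absl::UnimplementedError(
        "GetHloModules: this backend has no XLA HLO representation");
  }
  // ... the remaining PjRtLoadedExecutable overrides are elided ...
};
```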

My question is: does PyTorch/XLA's PjRtComputationClient require these xla::HloModule objects for execution? If not, then when the user sets XLA_STABLEHLO_COMPILE=1, PyTorch/XLA should not expect the compiled PjRtLoadedExecutable to expose anything XLA/HLO-specific.
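For context, a sketch of the two Compile entry points on the client side, assuming the overloads in xla/pjrt/pjrt_client.h (`DsaClient` is again a hypothetical name). My understanding is that with XLA_STABLEHLO_COMPILE=1, PyTorch/XLA should go through the MLIR (StableHLO) overload, which our compiler can consume directly:

```cpp
#include "mlir/IR/BuiltinOps.h"
#include "xla/pjrt/pjrt_client.h"

class DsaClient : public xla::PjRtClient {
 public:
  // StableHLO path: our MLIR-based compiler consumes this module directly.
  absl::StatusOr<std::unique_ptr<xla::PjRtLoadedExecutable>> Compile(
      mlir::ModuleOp module, xla::CompileOptions options) override;

  // HLO path: not meaningful for a non-XLA backend.
  absl::StatusOr<std::unique_ptr<xla::PjRtLoadedExecutable>> Compile(
      const xla::XlaComputation& computation,
      xla::CompileOptions options) override;

  // ... other PjRtClient overrides elided ...
};
```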

Nullkooland commented 7 months ago

Related issue in OpenXLA repo: https://github.com/openxla/xla/issues/3282

It seems that PjRtLoadedExecutable::GetHloModules, which maps to PJRT_Executable_OptimizedProgram in the PJRT C API, is only used for debugging. So I would guess that PyTorch/XLA should not rely on this API and should function correctly if it is left unimplemented.

JackCaoG commented 7 months ago

@will-cromar can you take a look at this one?

will-cromar commented 7 months ago

Our runtime wrapper API does require an XlaComputation (i.e., an HLO module); this is how we access the input and output shapes/shardings. There are better APIs available in PJRT now for this purpose, but they came after our original integration with PJRT. Relying on GetHloModules is really just leftover technical debt -- in practice, the backends we developed around (TPU and StreamExecutor GPU) use XLA compilers, so it was always available.
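To make that direction concrete, a sketch of querying those properties without an HloModule, assuming the introspection methods on PjRtExecutable in recent OpenXLA checkouts (exact return types may differ by version):

```cpp
#include "xla/pjrt/pjrt_client.h"

// Queries output/parameter metadata straight from the executable instead of
// reconstructing it from GetHloModules()[0]->entry_computation().
void InspectWithoutHlo(const xla::PjRtLoadedExecutable& executable) {
  // Per-output element types and dimensions.
  auto element_types = executable.GetOutputElementTypes();
  auto dimensions = executable.GetOutputDimensions();
  // Shardings are exposed directly; they are optional when unsharded.
  auto output_shardings = executable.GetOutputShardings();
  auto parameter_shardings = executable.GetParameterShardings();
  // ... unwrap the StatusOr/optional results and build the wrapper's
  // shape/sharding metadata from them ...
}
```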

Skimming through our code, I believe it is possible to remove our hard dependency on XlaComputation (and thus GetHloModules) for compilation. I'm less sure how our persistent caching implementation would work without it, but that's optional. cc @jonb377

@lsy323, do you know if it's possible to produce an XlaComputation from a StableHLO module as a short-term hack?
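One plausible short-term route, sketched under the assumption that the helper in xla/pjrt/mlir_to_hlo.h still has roughly this shape (its parameter list has shifted across versions):

```cpp
#include "mlir/IR/BuiltinOps.h"
#include "xla/client/xla_computation.h"
#include "xla/pjrt/mlir_to_hlo.h"

// Round-trips a StableHLO mlir::ModuleOp through XLA's MLIR lowering
// (StableHLO -> MHLO -> HLO) to recover an XlaComputation.
absl::StatusOr<xla::XlaComputation> StablehloToXlaComputation(
    mlir::ModuleOp stablehlo_module) {
  xla::XlaComputation computation;
  absl::Status status = xla::MlirToXlaComputation(
      stablehlo_module, computation,
      /*use_tuple_args=*/false, /*return_tuple=*/false);
  if (!status.ok()) return status;
  return computation;
}
```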