Nullkooland opened this issue 7 months ago
Related issue in OpenXLA repo: https://github.com/openxla/xla/issues/3282
It seems that `PjRtLoadedExecutable::GetHloModules`, which maps to `PJRT_Executable_OptimizedProgram` in the PJRT C API, is only used for debugging. So I guess PyTorch/XLA should not rely on this API and should function OK if it is left unimplemented.
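For illustration, here's a minimal sketch of what stubbing it out on the plugin side could look like (the class name is hypothetical, and the exact `StatusOr` spelling varies across XLA revisions):

```cpp
// Hypothetical non-XLA plugin executable. Since GetHloModules() returns a
// StatusOr in the PJRT C++ API, a backend that never produces HLO can just
// report Unimplemented rather than fabricating an xla::HloModule.
#include <memory>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/pjrt/pjrt_client.h"

class MyDsaLoadedExecutable : public xla::PjRtLoadedExecutable {
 public:
  absl::StatusOr<std::vector<std::shared_ptr<xla::HloModule>>> GetHloModules()
      const override {
    // PJRT_Executable_OptimizedProgram is described as a debugging aid, so a
    // compiler-agnostic plugin should be able to decline here.
    return absl::UnimplementedError(
        "GetHloModules: this backend does not compile through XLA HLO.");
  }
  // ... the remaining PjRtLoadedExecutable overrides go here ...
};
```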
@will-cromar can you take a look at this one?
Our runtime wrapper API does require an `XlaComputation` (i.e., an HLO module); this is how we access the input and output shapes/shardings. There are better APIs available in PJRT now for this purpose, but they came after our original integration with PJRT. Relying on `GetHloModules` is really just leftover technical debt -- in practice, the backends we developed around (TPU and StreamExecutor GPU) use XLA compilers, so this was always available.
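For reference, here's a sketch of that newer metadata surface (hedged: check `xla/pjrt/pjrt_executable.h` for the exact return types in your revision), which could replace the `XlaComputation` round-trip for shapes/shardings:

```cpp
// Sketch: read output types/dimensions and parameter/output shardings
// straight off the PjRtExecutable, with no HloModule involved. The sharding
// accessors return std::optional because not every compiler records them.
#include <optional>
#include <vector>

#include "xla/pjrt/pjrt_executable.h"

void InspectWithoutHlo(const xla::PjRtExecutable& executable) {
  auto output_types = executable.GetOutputElementTypes();  // per-program element types
  auto output_dims = executable.GetOutputDimensions();     // per-program dimensions
  std::optional<std::vector<xla::OpSharding>> param_shardings =
      executable.GetParameterShardings();
  std::optional<std::vector<xla::OpSharding>> output_shardings =
      executable.GetOutputShardings();
  (void)output_types;
  (void)output_dims;
  (void)param_shardings;
  (void)output_shardings;
}
```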
Skimming through our code, I do believe it is possible to remove our hard dependency on `XlaComputation` (and thus `GetHloModules`) for compilation. I'm less sure how our persistent caching implementation will work without it, but that's optional. cc @jonb377
@lsy323, do you know if it's possible to produce an `XlaComputation` from a StableHLO module as a short-term hack?
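One possible shape for that hack, sketched under assumptions: `ParseMlirModuleString` and `MlirToXlaComputation` live in `xla/pjrt/mlir_to_hlo.h` in recent XLA trees, but their signatures have shifted over time, so verify against the revision you build with:

```cpp
// Sketch of a StableHLO-to-XlaComputation round-trip. MlirToXlaComputation
// legalizes StableHLO to HLO internally; argument lists differ across XLA
// revisions, so treat this as pseudocode with real names.
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OwningOpRef.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"
#include "xla/client/xla_computation.h"
#include "xla/pjrt/mlir_to_hlo.h"

absl::StatusOr<xla::XlaComputation> StablehloToXlaComputation(
    absl::string_view stablehlo_text) {
  mlir::MLIRContext context;
  // Parse serialized StableHLO into an MLIR module.
  TF_ASSIGN_OR_RETURN(mlir::OwningOpRef<mlir::ModuleOp> module,
                      xla::ParseMlirModuleString(stablehlo_text, context));
  // Convert to an XlaComputation (HLO proto) for the legacy code paths.
  xla::XlaComputation computation;
  TF_RETURN_IF_ERROR(xla::MlirToXlaComputation(
      *module, computation, /*use_tuple_args=*/false, /*return_tuple=*/false));
  return computation;
}
```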
❓ Questions and Help
Hi, I'm from a hardware vendor and we want to implement a PJRT plugin for our DSA accelerator. We have our own MLIR-based compiler stack and it takes StableHLO as the input IR.
I'm new to PJRT. According to the description, the PJRT API is supposed to be compiler-agnostic and should not assume that a PJRT plugin's compiler backend must be XLA. However, in PyTorch/XLA's PJRT runtime, `PjRtComputationClient::Compile` calls `PjRtLoadedExecutable::GetHloModules` (which we left unimplemented in our `PjRtLoadedExecutable` implementation) and expects it to return valid `xla::HloModule`s:
https://github.com/pytorch/xla/blob/19b83830ac4ee3a39d99abaf154f485c2399f47a/torch_xla/csrc/runtime/pjrt_computation_client.cc#L585
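Paraphrasing, the dependency looks roughly like this (not the exact source, just the pattern):

```cpp
// Sketch of the pattern in PjRtComputationClient::Compile: the result of
// GetHloModules() must be valid, so a plugin returning Unimplemented fails
// here even though the PJRT API itself allows the method to fail.
#include <memory>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "tsl/platform/statusor.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/pjrt/pjrt_client.h"

absl::StatusOr<std::shared_ptr<xla::HloModule>> HloModuleFromExecutable(
    xla::PjRtLoadedExecutable& executable) {
  // Propagates the plugin's Unimplemented status to the caller.
  TF_ASSIGN_OR_RETURN(auto hlo_modules, executable.GetHloModules());
  if (hlo_modules.empty()) {
    return absl::InternalError("executable returned no HLO modules");
  }
  return hlo_modules.front();
}
```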
My question is: does PyTorch/XLA's `PjRtComputationClient` require these `xla::HloModule`s for execution? If not, then when the user sets `XLA_STABLEHLO_COMPILE=1`, PyTorch/XLA should not expect the compiled `PjRtLoadedExecutable` to have anything to do with XLA/HLO.