jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Enforce that IR crossing the backend plugin API is only `stablehlo`. #15608

Open hawkinsp opened 1 year ago

hawkinsp commented 1 year ago

Currently JAX emits a mixture of chlo and stablehlo and expects the backend (plugin) to handle it.

chlo does not promise any form of backward or forward compatibility, so to have a stable plugin API JAX needs to lower away any chlo pieces before handing them to the backend.
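As a rough illustration of what "enforcing" this could look like, the sketch below scans the textual form of an MLIR module and reports any op whose dialect is outside an allowed set. It is not how JAX does or would implement the check: the function name `unstable_ops`, the allowed set (`stablehlo` plus `func` and `builtin` for module structure), and the line-based regex are all assumptions made for illustration. A real check would walk the parsed IR rather than match text.

```python
import re

# Assumption: dialects a "stable" module may contain. func/builtin carry the
# module and function structure around the StableHLO ops.
ALLOWED_DIALECTS = frozenset({"stablehlo", "func", "builtin"})

# Matches a dialect-qualified op name such as `stablehlo.add` or
# `"chlo.broadcast_add"` at the start of a statement, optionally preceded by
# a result list (`%0 = `, `%1:2 = `). Ops printed without a dialect prefix
# (e.g. `return`, `module`) are ignored.
_OP_RE = re.compile(r'^\s*(?:%[\w#:, %]+=\s*)?"?([A-Za-z_]\w*)\.[\w.$]+"?')


def unstable_ops(mlir_text):
    """Return (line_number, dialect) pairs for ops outside ALLOWED_DIALECTS."""
    found = []
    for lineno, line in enumerate(mlir_text.splitlines(), start=1):
        match = _OP_RE.match(line)
        if match and match.group(1) not in ALLOWED_DIALECTS:
            found.append((lineno, match.group(1)))
    return found


example = "\n".join([
    "module {",
    "  func.func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {",
    "    %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>",
    "    %1 = stablehlo.add %0, %arg0 : tensor<4xf32>",
    "    return %1 : tensor<4xf32>",
    "  }",
    "}",
])
print(unstable_ops(example))
```

Here the `chlo.broadcast_add` on the third line is the only op that would have to be lowered away before the module crosses the plugin boundary.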

jpienaar commented 1 year ago

The mix is actually the preferred situation for OpenXLA today, as CHLO has somewhat better dynamic shape support (as well as composed ops that have different libraries/expansions for different backends, which would have to be matched back up if lowered to StableHLO). Vice versa, StableHLO has some composed ops that probably should be in CHLO.

We'd also want MLProgram to be usable along this path.

Where there is no loss of representation or semantics, the more stable form is preferred. Even there, though, lowering has a cost that may not be required for a given usage.

Could this be a separate pass? E.g., convert_to_long_term_stable, which incurs the overhead of conversion to a years-stable format (which I believe is at least two dialect conversions) and so is a serialization step.

hawkinsp commented 1 year ago

@jpienaar it can be a separate pass, but the question is "what dialect crosses the plugin boundary"? Unless the plugin opts out of stability, that must be a stable dialect. So chlo is out at the moment. I'm fine with plugins saying "I don't care about stability" and accepting a superset.
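The rule described here (stable dialect by default, superset only for plugins that opt out of stability) can be sketched as a small policy function. Everything in the block is hypothetical: `STABLE_DIALECTS`, `accepted_dialects`, and `dialects_to_lower` are illustrative names, not part of JAX or of any plugin API, and the way a plugin would declare its preference is left as a plain boolean.

```python
# Hypothetical policy sketch; none of these names exist in JAX or the plugin API.
STABLE_DIALECTS = frozenset({"stablehlo"})


def accepted_dialects(wants_stability, extra_dialects=()):
    """Dialects a plugin accepts across the boundary.

    A plugin that wants stability gets only the stable set. A plugin that
    opts out of stability may accept a superset (for example, CHLO).
    """
    if wants_stability:
        return STABLE_DIALECTS
    return STABLE_DIALECTS | frozenset(extra_dialects)


def dialects_to_lower(module_dialects, wants_stability, extra_dialects=()):
    """Dialects that must be lowered away before handing the module over."""
    accepted = accepted_dialects(wants_stability, extra_dialects)
    return frozenset(module_dialects) - accepted


# A stability-requiring plugin forces CHLO to be lowered away...
print(dialects_to_lower({"stablehlo", "chlo"}, wants_stability=True))
# ...while a plugin that opts out and lists CHLO accepts the mix as-is.
print(dialects_to_lower({"stablehlo", "chlo"}, False, ["chlo"]))
```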

jpienaar commented 1 year ago

I believe it is even more than that: we need a compiler API to be able to query abilities, unless we require that all backends support all of StableHLO fully (for all types and shapes). Such an API would additionally be required once StableHLO grows its extension/capability mechanisms. The extension mechanism is stated to be dialects, so CHLO can be considered a mini version of what a StableHLO extension would be, and these extensions probably have weaker stability claims. E.g., I think CHLO would have been able to cross the boundary using MLIR bytecode with perhaps 2 breakages in the last year, which is not sufficient for long-term storage.

With the above, stability and type are under plugin control: if needed one does more work, if not, less (this includes allowing skipping VHLO during serialization and verification of discardable attributes). Stability becomes a spectrum and is handled uniformly along with other abilities.

Now, for a pure AOT workload where one doesn't even have the plugin loaded (not sure if that is considered), the spectrum runs from "decompose into Long Term Stable Primitives without any extensions" to "use any extension needed to best represent user intent".

burmako commented 1 year ago

"Could this be a separate pass? E.g., convert_to_long_term_stable which incur overheads of conversions to years stable format (which I believe is at least two dialect conversions) and so it is a serialize step". Right, this is the approach that we are taking in serialize_portable_artifact in tensorflow/compiler/xla/python which is used by jax2tf.

GleasonK commented 1 year ago

+1 to Eugene's comment. I'm not too familiar with JAX plugins, are they built against different versions of MLIR and we need to manage slightly different versions communicating? Curious if the requirement is "only emit StableHLO" or "the input to the plugin must be stable".

I had understood the current JAX position as "stability is a feature of jax2tf" which expects producer/consumer of different versions to communicate, with both forward and backward compatibility required. Emitting StableHLO is not enough in these cases, because only StableHLO serialized using the Serialization APIs has stability guarantees. This would mean plugins are required to deserialize the StableHLO portable artifacts.

We had discussed a prepare_for_serialization function that would take care of CHLO (and for dynamic programs would convert shape dialect ops generated by the CHLO decompositions), and emit pure StableHLO programs which can be serialized. That is what is implemented here: https://github.com/openxla/xla/blob/main/xla/python/mlir.cc#L168-L180 . If the requirement is "only emit StableHLO" then this could be split out.

The notion of "Stability becomes a spectrum" would be nice if we can do so in a way that avoids confusion, and is useful to plugin authors. Possibly: 1. stable format and dialect (StableHLO using Serialization APIs), 2. stable format, not dialect (bytecode), and 3. not stable (textual assembly), each with their own tradeoffs. Somewhat equates to "Production, Experimental, Debug" respectively.
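The three tiers listed above could be written down as an enumeration, which makes the tradeoffs explicit. This is only a sketch: `StabilityTier` and `strongest_common_tier` are hypothetical names, and the idea that a producer and a plugin would settle on the most stable tier both support is an assumption for illustration, not something proposed in this thread.

```python
import enum


class StabilityTier(enum.Enum):
    # (format is stable, dialect is stable, how the program is carried)
    PRODUCTION = (True, True, "StableHLO via the Serialization APIs")
    EXPERIMENTAL = (True, False, "MLIR bytecode")
    DEBUG = (False, False, "textual assembly")

    def __init__(self, stable_format, stable_dialect, carrier):
        self.stable_format = stable_format
        self.stable_dialect = stable_dialect
        self.carrier = carrier


def strongest_common_tier(producer_tiers, plugin_tiers):
    """Pick the most stable tier both sides support, or None if there is none."""
    for tier in StabilityTier:  # iterates in definition order: most stable first
        if tier in producer_tiers and tier in plugin_tiers:
            return tier
    return None


chosen = strongest_common_tier(
    {StabilityTier.PRODUCTION, StabilityTier.DEBUG},
    {StabilityTier.EXPERIMENTAL, StabilityTier.DEBUG},
)
print(chosen.name, "-", chosen.carrier)
```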

hawkinsp commented 1 year ago

"I'm not too familiar with JAX plugins, are they built against different versions of MLIR and we need to manage slightly different versions communicating?"

Yes. We expect to use plugins built by third-party vendors using a different version of MLIR.