openxla / stablehlo

Backward compatible ML compute opset inspired by HLO/MHLO
Apache License 2.0

Clean up VHLO support for arrays, functions #2121

Open mlevesquedion opened 7 months ago

mlevesquedion commented 7 months ago
### Tasks
- [ ] Add support for DenseArray
- [x] Add support for SymbolRef

Request description

Array types are used sparingly in StableHLO, mostly in the form of DenseI64ArrayAttr, although there are also uses of DenseBoolArrayAttr and ArrayAttr. Support for array types in serialization/deserialization is limited.

ArrayAttr is handled in a generic manner, i.e. both registered and unregistered array attributes will round trip through VHLO.

However, DenseI64ArrayAttr and DenseBoolArrayAttr are handled case by case. They are serialized as vhlo::TensorV1Attr, so they have to be converted back to arrays when they are deserialized. That conversion is done with op type switches, so there is no generic support: arrays that occur outside the specific list of handled attributes deserialize into tensors rather than arrays:

$ stablehlo-translate --serialize --target=current <<<"func.func @foo() {
  stablehlo.custom_call @foo() {foo = array<i64: 1, 2, 3>} : () -> ()
  func.return
}" | stablehlo-translate --deserialize
func.func @foo() {
  stablehlo.custom_call @foo() {foo = dense<[1, 2, 3]> : tensor<3xi64>} : () -> ()
  return
}
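To make the failure mode above concrete, here is a minimal Python model of the deserialization side, assuming the op type switches behave like an allow-list of (op, attribute) pairs. The classes `DenseArrayAttr` and `TensorAttr`, the function `deserialize_attr`, and the allow-list entries are all hypothetical stand-ins, not the real StableHLO/VHLO code. It shows why `foo` on `stablehlo.custom_call` comes back as a tensor: nothing on the list says it was ever an array.

```python
from dataclasses import dataclass
from typing import Tuple, Union

# Hypothetical stand-ins for MLIR attributes (not the real classes).
@dataclass(frozen=True)
class DenseArrayAttr:      # models array<i64: 1, 2, 3>
    elem_type: str
    values: Tuple[int, ...]

@dataclass(frozen=True)
class TensorAttr:          # models dense<[1, 2, 3]> : tensor<3xi64>
    elem_type: str
    values: Tuple[int, ...]

# Hypothetical allow-list standing in for the op type switches: only these
# (op, attribute) pairs are known to hold arrays. Entries are illustrative.
ARRAY_ATTRS_BY_OP = {
    "stablehlo.transpose": {"permutation"},
    "stablehlo.broadcast_in_dim": {"broadcast_dimensions"},
}

def deserialize_attr(op_name: str, attr_name: str,
                     attr: TensorAttr) -> Union[DenseArrayAttr, TensorAttr]:
    """Turn a tensor read back from VHLO into an array only if the op is known."""
    if attr_name in ARRAY_ATTRS_BY_OP.get(op_name, set()):
        return DenseArrayAttr(attr.elem_type, attr.values)
    # Anything else (e.g. a custom_call's free-form attribute) stays a tensor.
    return attr

serialized = TensorAttr("i64", (1, 2, 3))
print(deserialize_attr("stablehlo.transpose", "permutation", serialized))
print(deserialize_attr("stablehlo.custom_call", "foo", serialized))
```

The first call recovers an array because the pair is on the list; the second returns the tensor unchanged, mirroring the `dense<[1, 2, 3]> : tensor<3xi64>` output above.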

There are other kinds of array attributes, such as DenseI32ArrayAttr, that are currently not supported:

$ stablehlo-translate --serialize --target=current <<<"func.func @foo() {
  stablehlo.custom_call @foo() {foo = array<i32: 1, 2, 3>} : () -> ()
  func.return
}" | stablehlo-translate --deserialize
...
Failed to convert: array<i32: 1, 2, 3>
<stdin>:2:3: error: failed to legalize operation 'stablehlo.custom_call' that was explicitly marked illegal
  stablehlo.custom_call @foo() {foo = array<i32: 1, 2, 3>} : () -> ()
  ^
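The error names `stablehlo.custom_call` rather than a VHLO op, which suggests it is raised while legalizing StableHLO to VHLO during `--serialize`. A minimal Python sketch of that side, assuming the serializer only has conversion cases for `i64` and `i1` dense arrays: `DenseArrayAttr`, `TensorV1Attr`, `ConversionError`, `SUPPORTED_ARRAY_ELEM_TYPES`, and `serialize_attr` are hypothetical names for illustration, not the real implementation.

```python
from dataclasses import dataclass
from typing import Tuple

@dataclass(frozen=True)
class DenseArrayAttr:     # hypothetical stand-in for array<iN: ...>
    elem_type: str
    values: Tuple[int, ...]

@dataclass(frozen=True)
class TensorV1Attr:       # hypothetical stand-in for vhlo::TensorV1Attr
    elem_type: str
    values: Tuple[int, ...]

class ConversionError(Exception):
    pass

# Assumption: only these dense array element types have a conversion case.
SUPPORTED_ARRAY_ELEM_TYPES = {"i64", "i1"}

def serialize_attr(attr: DenseArrayAttr) -> TensorV1Attr:
    if attr.elem_type not in SUPPORTED_ARRAY_ELEM_TYPES:
        # No conversion case matches, so the enclosing op fails to legalize.
        raise ConversionError(
            f"Failed to convert: array<{attr.elem_type}: "
            + ", ".join(map(str, attr.values)) + ">")
    # Supported arrays are stored as tensors (the array-ness is dropped).
    return TensorV1Attr(attr.elem_type, attr.values)

try:
    serialize_attr(DenseArrayAttr("i32", (1, 2, 3)))
except ConversionError as e:
    print(e)  # Failed to convert: array<i32: 1, 2, 3>
```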

We should consider implementing generic handling of dense array attributes so that all such attributes can round trip through VHLO.
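One possible shape for generic handling, sketched in Python under the assumption that VHLO would gain a dedicated versioned attribute for dense arrays that records the element type. The issue does not prescribe a design, and `VersionedArrayAttr`, `serialize`, and `deserialize` below are hypothetical. Because the encoding itself marks the value as an array, deserialization needs no op-specific knowledge, and any element type (including `i32`) round-trips.

```python
from dataclasses import dataclass
from typing import Tuple, Union

@dataclass(frozen=True)
class DenseArrayAttr:        # hypothetical stand-in for array<T: ...>
    elem_type: str
    values: Tuple

@dataclass(frozen=True)
class TensorAttr:            # hypothetical stand-in for dense<...> : tensor<...>
    elem_type: str
    values: Tuple

@dataclass(frozen=True)
class VersionedArrayAttr:    # hypothetical VHLO-side encoding for dense arrays
    elem_type: str
    values: Tuple

Attr = Union[DenseArrayAttr, TensorAttr]
Versioned = Union[VersionedArrayAttr, TensorAttr]

def serialize(attr: Attr) -> Versioned:
    # Every dense array, whatever its element type, keeps its array identity.
    if isinstance(attr, DenseArrayAttr):
        return VersionedArrayAttr(attr.elem_type, attr.values)
    return attr

def deserialize(attr: Versioned) -> Attr:
    # No allow-list of ops needed: the encoding says it was an array.
    if isinstance(attr, VersionedArrayAttr):
        return DenseArrayAttr(attr.elem_type, attr.values)
    return attr

for original in (DenseArrayAttr("i64", (1, 2, 3)),
                 DenseArrayAttr("i32", (1, 2, 3)),
                 DenseArrayAttr("i1", (True, False)),
                 TensorAttr("i64", (1, 2, 3))):
    assert deserialize(serialize(original)) == original
print("all attributes round-trip")
```

The design choice being illustrated is moving the "is this an array?" information into the serialized form instead of reconstructing it from the op and attribute name.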

This feature request was originally motivated by a report in the stablehlo Discord channel: https://discordapp.com/channels/999073994483433573/999074539138990131/1219706845443129426

GleasonK commented 6 months ago

Let's add array + symbolref support together

dfm commented 2 weeks ago

I wanted to check in on this issue because this just came up in JAX: https://github.com/jax-ml/jax/issues/24160

We rely on array attributes in the XLA foreign function interface for accepting vector inputs, but this doesn't currently work via the plugin interface (i.e., for any open-source users of JAX on GPU) because of this round-tripping issue. I'm going to look into being more flexible on the FFI side, but wanted to add a +1 here to say that having support for round-tripping array attributes would be useful for us!