mlevesquedion opened 7 months ago
Let's add array + symbolref support together
I wanted to check in on this issue because this just came up in JAX: https://github.com/jax-ml/jax/issues/24160
We rely on array attributes in the XLA foreign function interface (FFI) to accept vector inputs, but this doesn't currently work via the plugin interface (i.e., for any open-source users of JAX on GPU) because of this round-tripping issue. I'm going to look into being more flexible on the FFI side, but wanted to add a +1 here to say that support for round-tripping array attributes would be useful for us!
Request description
Array types are used sparingly in StableHLO, mostly in the form of `DenseI64ArrayAttr`, although there are also uses of `DenseBoolArrayAttr` and `ArrayAttr`. Support for array types in serialization/deserialization is limited.

`ArrayAttr` is handled generically, i.e. both registered and unregistered array attributes round-trip through VHLO. However, `DenseI64ArrayAttr` and `DenseBoolArrayAttr` are supported in a specific way. They are serialized using `vhlo::TensorV1Attr`, which means that they have to be converted back to arrays when they are deserialized. This conversion is handled with op type switches, so there is no generic support: arrays that occur outside the specific list of handled attributes deserialize into tensors rather than arrays.

There are other kinds of array attributes, such as `DenseI32ArrayAttr`, which are currently not supported.

We should consider implementing generic handling of dense array attributes so that all such attributes can round-trip through VHLO.
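To make the failure mode concrete, here is a toy Python model of the two strategies. It is a sketch, not the MLIR or VHLO API: `DenseArray`, `Tensor`, `ArrayPayload`, the allow-list, and the `my_dims` attribute name are all hypothetical stand-ins. The only behavior taken from the issue is that i64/bool arrays are encoded as tensors, converted back only for a fixed list of attributes, and that other element types (e.g. i32) have no encoding at all. The generic variant assumes a dedicated array encoding that remembers "this was an array", so deserialization needs no knowledge of the enclosing op.

```python
from dataclasses import dataclass
from typing import Tuple

# Toy stand-ins for MLIR attributes; these are NOT the real MLIR/VHLO classes.
@dataclass(frozen=True)
class DenseArray:
    """Models a DenseXXArrayAttr: a 1-D array with an element type like "i64"."""
    elem_type: str
    values: Tuple[int, ...]

@dataclass(frozen=True)
class Tensor:
    """Models a tensor payload (the role vhlo::TensorV1Attr plays in the issue)."""
    elem_type: str
    values: Tuple[int, ...]

@dataclass(frozen=True)
class ArrayPayload:
    """Hypothetical dedicated encoding for dense arrays (the generic proposal)."""
    elem_type: str
    values: Tuple[int, ...]

# --- Current scheme, as described in the issue --------------------------------
# Only i64 and i1 (bool) arrays have a serialization path.
SPECIFIC_ELEM_TYPES = {"i64", "i1"}
# Illustrative allow-list standing in for the per-op type switches.
KNOWN_ARRAY_ATTRS = {("stablehlo.transpose", "permutation")}

def serialize_specific(attr):
    if isinstance(attr, DenseArray):
        if attr.elem_type not in SPECIFIC_ELEM_TYPES:
            raise NotImplementedError(f"no encoding for {attr.elem_type} arrays")
        # The fact that this was an array is dropped here.
        return Tensor(attr.elem_type, attr.values)
    return attr

def deserialize_specific(op_name, attr_name, attr):
    # A tensor becomes an array again only for (op, attribute) pairs the
    # switch knows about; everything else stays a tensor.
    if isinstance(attr, Tensor) and (op_name, attr_name) in KNOWN_ARRAY_ATTRS:
        return DenseArray(attr.elem_type, attr.values)
    return attr

# --- Generic scheme (hypothetical proposal) ------------------------------------
def serialize_generic(attr):
    # Any dense array, whatever its element type, keeps its array-ness.
    if isinstance(attr, DenseArray):
        return ArrayPayload(attr.elem_type, attr.values)
    return attr

def deserialize_generic(attr):
    # No op or attribute name is needed: the payload itself says "array".
    if isinstance(attr, ArrayPayload):
        return DenseArray(attr.elem_type, attr.values)
    return attr

if __name__ == "__main__":
    perm = DenseArray("i64", (1, 0))
    # A listed attribute round-trips: prints True
    print(deserialize_specific("stablehlo.transpose", "permutation",
                               serialize_specific(perm)) == perm)
    # The same value under an unlisted (hypothetical) name comes back as a
    # tensor: prints Tensor(elem_type='i64', values=(1, 0))
    print(deserialize_specific("stablehlo.custom_call", "my_dims",
                               serialize_specific(perm)))
    # The generic scheme also round-trips i32 arrays: prints True
    i32 = DenseArray("i32", (3, 4))
    print(deserialize_generic(serialize_generic(i32)) == i32)
```

The design point the model illustrates is where the "array" information lives: in the specific scheme it lives in the deserializer's switch, so every new op or attribute needs a new case; in the generic scheme it travels with the serialized value itself.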
This feature request was originally motivated by a report in the StableHLO Discord channel: https://discordapp.com/channels/999073994483433573/999074539138990131/1219706845443129426