It sounds reasonable to use ONNX shape inference to give a real tuple. The complication with shape inference is that it may give you some symbolic dimensions:
>>> ndx.array(shape=("N", "M"), dtype=ndx.utf8).spox_var()
<Var from spox.internal@0::Argument->arg of str[N][M]>
To comply with the standard, we could set all symbolic dimensions to None. It might be worth checking if in practice people (e.g. scikit-learn, where you investigated this) are checking for None dimensions, but that is out of scope for ndonnx, of course.
The other challenge is that there are some situations where even the rank is unknowable statically. Consider the following:
>>> x = ndx.array(shape=(2, 3), dtype=ndx.utf8)
>>> y = ndx.array(shape=("M",), dtype=ndx.int64)
>>> ndx.broadcast_to(x, y).spox_var()
<Var from ai.onnx@13::Expand->output of str[...]>
We would then need to raise in such situations (i.e. when shape inference gives us a None shape).
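A minimal sketch of how such a conversion could look, assuming a hypothetical helper that receives the ONNX-inferred shape as a tuple of ints, symbolic strings, or None entries (and None altogether when the rank is unknown); the name and signature are illustrative, not existing ndonnx API:

from __future__ import annotations

def to_standard_shape(
    inferred: tuple[int | str | None, ...] | None,
) -> tuple[int | None, ...]:
    if inferred is None:
        # Shape inference could not even determine the rank (e.g. after an
        # Expand with a dynamic shape input), so there is nothing to return.
        raise ValueError("The rank of this array is not known statically.")
    # Symbolic dimensions such as "N" become None to comply with the standard.
    return tuple(dim if isinstance(dim, int) else None for dim in inferred)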
We currently have the ndonnx.additional.shape function, which always provides an ndx.Array and is the better way to go about obtaining the dynamic shape anyway. It should also be overridable for custom dtypes. To that end, we should also consider offloading this calculation to the dtype itself so that user-defined types may implement it however they see fit, too.
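For illustration, a usage sketch of the dynamic-shape route (assuming the existing ndonnx.additional.shape function; the exact import style and repr are not shown here):

>>> import ndonnx as ndx
>>> import ndonnx.additional as nda
>>> x = ndx.array(shape=("N", 3), dtype=ndx.float64)
>>> s = nda.shape(x)  # lazy ndx.Array holding the dynamic shape of x
>>> s[0]              # first entry of the dynamic shape, still lazy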
To comply with the standard, we could set all symbolic dimensions to None
The standard mandates a shape of type Tuple[int | None, ...]. Rather than being non-compliant by adding str to the union, we should consider (a) upstreaming the new type signature to the standard and (b) adding a third shape-related variant to this package (ndonnx.additional in particular) that returns the full shape information with type tuple[int | None | str, ...].
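A rough sketch of what such a third variant could look like; the name static_shape is hypothetical, and the assumption that spox's Var.unwrap_tensor() exposes the inferred shape is mine:

from __future__ import annotations

import ndonnx as ndx

def static_shape(x: ndx.Array) -> tuple[int | None | str, ...]:
    # Assumption: the inferred shape is available on the underlying spox Var
    # as a tuple of ints, symbolic strings, and None entries (or None if even
    # the rank is unknown).
    inferred = x.spox_var().unwrap_tensor().shape
    if inferred is None:
        raise ValueError("The rank of this array is not known statically.")
    # Keep symbolic names (e.g. "N") as strings and unknown dims as None.
    return tuple(inferred)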
The other challenge is that there are some situations where even the rank is unknowable statically. Consider the following:
Are we not making the assumption that we always know the type and at least the rank of an array throughout the code base? I think we should indeed raise in such a situation.
Rather than being non-compliant by adding str to the union, we should consider (a) upstreaming the new type signature to the standard and (b) adding a third shape-related variant to this package
I wasn't suggesting adding str to the union but rather to set those dimensions to None.
Upstreaming Tuple[int | str | None, ...] into the Array API standard as the type signature seems reasonable. I'm not too sure how symbolic dimensions work across other array libraries. JAX, as an example, seems to have something similar, but it is probably worth asking a JAX expert whether it is relevant: https://jax.readthedocs.io/en/latest/export/shape_poly.html#shape-polymorphism.
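For reference, a hedged sketch of what JAX's shape polymorphism looks like based on the linked docs (the exact API may differ between JAX versions):

import jax
import jax.numpy as jnp
from jax import export

def f(x):
    return jnp.sum(x, axis=0)

# "N" is a symbolic dimension; the exported artifact is valid for any N.
spec = jax.ShapeDtypeStruct(export.symbolic_shape("N, 3"), jnp.float32)
exported = export.export(jax.jit(f))(spec)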
One of them was that libraries may use code such as len(array.shape) to get the rank. While there is an explicit Array.ndim property, the former is a perfectly fine way to get the same information, too.
Using len is strictly worse and non-idiomatic, I'd say, and it prevents keeping things lazy. I'd suggest changing the scikit-learn code.
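To make the contrast concrete with the arrays from the earlier examples (plain array-API usage, nothing ndonnx-specific beyond the constructor):

import ndonnx as ndx

x = ndx.array(shape=("N", 3), dtype=ndx.float64)

# x.ndim only needs the rank, which is known statically here.
rank = x.ndim
# len(x.shape) additionally requires shape to be a plain tuple whose entries
# are all concrete, which a symbolic dimension like "N" cannot satisfy.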
The standard mandates a shape of type Tuple[int | None, ...].
For context: it was hard to settle on a sentinel for "unknown", because there were already deviations between libraries using None and nan (e.g., Dask uses the latter).
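For example, to the best of my knowledge this is how the nan sentinel shows up in Dask after a data-dependent operation (hedged sketch):

import dask.array as da

x = da.ones((10,), chunks=5)
y = x[x > 0]      # result size depends on the data
print(y.shape)    # (nan,): Dask marks the unknown dimension with nan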
There are a few subtleties here, in particular: duck-typing objects like tuple[int, int] should be fine, as long as the semantics don't change. This may be desirable not only for lazy execution but also, e.g., to be able to keep everything on GPU. Such duck typing isn't really expressible with static typing, but if code works at execution time then it shouldn't matter that the int isn't actually a true integer.
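A minimal illustration of that point (the LazyDim class is purely illustrative and not from any library): any object implementing __index__ works wherever a true int is expected at execution time.

import numpy as np

class LazyDim:
    """Stand-in for a dimension whose value is only resolved on demand."""

    def __init__(self, resolve):
        self._resolve = resolve

    def __index__(self):
        # Only here is the concrete value actually needed.
        return self._resolve()

n = LazyDim(lambda: 4)
a = np.zeros((n, 3))   # accepted because n implements __index__
print(a.shape)         # (4, 3)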
It might be worth checking if in practice people (e.g. scikit-learn, where you investigated this) are checking for None dimensions
It's rare in my experience, and when it is necessary it's typically for code that uses the size of a dimension for something conditional (if x.shape[1] < 5: ...), and that won't work with lazy execution anyway.
While playing with ndonnx and the array-API compliant classes of scikit-learn I ran into some issues. One of them was that libraries may use code such as len(array.shape) to get the rank. While there is an explicit Array.ndim property, the former is a perfectly fine way to get the same information, too. In other cases, it may be necessary to get the number of features assuming a shape of (M, n_features) where n_features is statically known, by calling array.shape[1]. Neither example currently works.
On the other hand, I realize that there are also use cases where we need the shape as a lazy array, which is currently served by the Array.shape implementation. We should maintain that functionality but under a different name in order to serve both use cases and to maintain compliance with the standard.
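To summarise the proposal in code, a hypothetical sketch (the static behaviour of Array.shape shown here is the proposal, not current behaviour; nda.shape refers to the existing ndonnx.additional.shape):

import ndonnx as ndx
import ndonnx.additional as nda

x = ndx.array(shape=("M", 4), dtype=ndx.float64)

# Standard-compliant static view: a plain tuple with None for unknown dims,
# e.g. (None, 4) under this proposal.
static = x.shape
len(static)    # rank checks like len(array.shape) work again
static[1]      # -> 4, the statically known number of features

# Lazy dynamic view: an ndx.Array computed in the graph, for cases where the
# actual dimension values are only known at inference time.
dynamic = nda.shape(x)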