Closed adonath closed 3 months ago
This would be super useful indeed. It's not a small amount of work I suspect. For indexing there is https://github.com/Quansight-Labs/ndindex/, which basically implements this "meta" idea. That's probably one of the most hairy parts to do, and a good start. But correctly doing all shape calculations for all functions in the API is also a large job. Perhaps others know of reusable functionality elsewhere for this? For PyTorch I believe it's too much baked into the library to be able to reuse it standalone.
Thanks a lot @rgommers for the response!
> It's not a small amount of work I suspect.

I only partly agree, because there is nothing particularly difficult about it. The expected behavior is well defined, the API is already defined, so there is no tricky code to be figured out. It is just a matter of implementing the already defined behavior "diligently".
> For indexing there is https://github.com/Quansight-Labs/ndindex/, which basically implements this "meta" idea.
This is indeed a great start, I was not aware of this project.
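As a tiny illustration of the "indexing without data" part that ndindex covers far more completely, the result shape of basic slicing can be computed from the shape alone. A rough sketch in plain Python (the helper name is hypothetical):

```python
def sliced_shape(shape, slices):
    """Result shape of basic slicing, computed without any array data."""
    out = []
    for dim, sl in zip(shape, slices):
        # slice.indices() clamps start/stop/step the same way Python and NumPy do
        start, stop, step = sl.indices(dim)
        out.append(max(0, (stop - start + (step - (1 if step > 0 else -1))) // step))
    out.extend(shape[len(slices):])  # trailing axes are untouched
    return tuple(out)

print(sliced_shape((10, 10), (slice(2, 8), slice(None, None, 2))))  # (6, 5)
```

This matches `np.ones((10, 10))[2:8, ::2].shape` without ever allocating the array.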
> But correctly doing all shape calculations for all functions in the API is also a large job.

I think the effort could actually be limited, because looking at https://github.com/numpy/numpy/tree/main/numpy/array_api, the files already pre-group the API into operations with the same behavior in terms of shape computation, i.e. element-wise, indexing, searching, statistical, etc. For each group the behavior only needs to be defined once; the rest is filling in boilerplate code. In addition there is broadcasting and indexing, which always apply. I'm less sure about the dtype promotion, but this must have been coded somewhere already as well.
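To make the "define the behavior once per group" idea concrete, here is a rough, self-contained sketch (all function names hypothetical) of what the shape rules for the element-wise and statistical groups could look like:

```python
def broadcast_shapes(*shapes):
    """NumPy-style broadcasting on shape tuples only -- no data involved."""
    ndim = max(len(s) for s in shapes)
    result = []
    for dims in zip(*(((1,) * (ndim - len(s)) + tuple(s)) for s in shapes)):
        dim = 1
        for d in dims:
            if d != 1:
                if dim != 1 and d != dim:
                    raise ValueError(f"shapes are not broadcastable: {shapes}")
                dim = d
        result.append(dim)
    return tuple(result)

def elementwise_shape(*shapes):
    # One rule covers add, multiply, equal, ... -- the whole element-wise group.
    return broadcast_shapes(*shapes)

def reduction_shape(shape, axis=None, keepdims=False):
    # One rule covers sum, mean, max, ... -- the whole statistical group.
    if axis is None:
        axes = set(range(len(shape)))
    elif isinstance(axis, int):
        axes = {axis % len(shape)}
    else:
        axes = {a % len(shape) for a in axis}
    out = []
    for i, d in enumerate(shape):
        if i in axes:
            if keepdims:
                out.append(1)
        else:
            out.append(d)
    return tuple(out)
```

Each function in a group then reduces to boilerplate delegating to its group rule, e.g. `elementwise_shape((3,), (2, 3))` gives `(2, 3)` and `reduction_shape((2, 3, 4), axis=1)` gives `(2, 4)`.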
> Perhaps others know of reusable functionality elsewhere for this? For PyTorch I believe it's too much baked into the library to be able to reuse it standalone.

I agree PyTorch is too large a dependency. From a quick search I only found https://github.com/NeuralEnsemble/lazyarray, which seems to be unmaintained. It also takes a different approach of building a graph and then delaying the evaluation.
I'd like to get a better idea of the actual implementation effort and just share some more thoughts on this idea.
- `unique` cannot really be supported, only if one allows "un-initialized" shapes (see below).
- `ndindex` is great! It also already provides https://quansight-labs.github.io/ndindex/api.html#ndindex.broadcast_shapes and does the indexing. So it will definitely be a dependency.
- `ndindex` has NumPy and Cython as optional dependencies. I think the situation will be exactly the same for such a lazy implementation. I don't think anything else would be needed.
- Sometimes it might be useful to work with "un-initialized" shapes. For example, the length of a "batch" axis might not be known, and often one would not care about it either. As "batches" are treated independently, `(None, 2, 3, 4)` could be handled as a valid shape. After specific operations, such as the mean along the batch axis, the shape becomes known. However, it is not a valid shape in the Array API.
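As a rough illustration of how such "un-initialized" dimensions could propagate through broadcasting, here is a hypothetical helper (not existing API) where `None` stands for an unknown dimension:

```python
def broadcast_dim(a, b):
    """Combine two dimensions, where None means 'unknown'."""
    if a is None or b is None:
        other = b if a is None else a
        # Unknown vs. 1 (or vs. another unknown) stays unknown; against a
        # concrete dim > 1 it must equal that dim (or fail later at runtime).
        return None if other in (1, None) else other
    if a == 1 or b == 1 or a == b:
        return max(a, b)
    raise ValueError(f"cannot broadcast dimensions {a} and {b}")

def broadcast_shapes(s1, s2):
    """Broadcast two shapes, propagating unknown (None) dimensions."""
    ndim = max(len(s1), len(s2))
    s1 = (1,) * (ndim - len(s1)) + tuple(s1)
    s2 = (1,) * (ndim - len(s2)) + tuple(s2)
    return tuple(broadcast_dim(a, b) for a, b in zip(s1, s2))

print(broadcast_shapes((None, 2, 3, 4), (3, 4)))  # (None, 2, 3, 4)
```

The unknown batch axis simply survives element-wise operations and disappears after a reduction over it, as described above.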
Is this not a valid shape? The spec for `shape` says:

> out (Tuple[Optional[int], …]) – array dimensions. An array dimension must be None if and only if a dimension is unknown.
Indeed this is already part of the spec. I missed that before!
Actually I'd be interested in starting a repo and playing around with this a bit. Any preference for a name @rgommers or @lucascolley? What about `ndshape` or `xpshape`, for example?
In addition to the already available implementations of the Array API, I think it could be interesting to have a lazy / "meta" implementation of the standard. What I mean is a small, minimal-dependency, standalone library, compatible with the Array API, that provides inference of the shape and dtype of resulting arrays, without ever initializing the data or executing any flops.
PyTorch already has something like this with the `"meta"` device: tensors on `"meta"` carry shape and dtype, but no data is allocated and no computation runs. However this misses, for example, the device handling, as the device is constrained to `"meta"`. I presume that Dask must have something very similar. JAX must also have something very similar for jitted computations, however I think it is only exposed to users with a different API, via `jax.eval_shape()`,
and not via a `"meta"` array object.

Similarly to the torch example, one would use a hypothetical library `lazy_array_api` with the same creation functions and operators. The use case I have in mind is mostly debugging, validation and testing of computationally intensive algorithms ("dry runs"). For now I just wanted to share the idea and bring it up for discussion.
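To make the intended usage concrete, here is a minimal self-contained mock of what such a `lazy_array_api` could provide (all names hypothetical; only shape, dtype and device are tracked, never data):

```python
class LazyArray:
    """A 'meta' array: carries shape, dtype and device, but no data."""

    def __init__(self, shape, dtype="float64", device="cpu"):
        self.shape, self.dtype, self.device = tuple(shape), dtype, device

    def __repr__(self):
        return f"LazyArray(shape={self.shape}, dtype={self.dtype}, device={self.device})"

    def __add__(self, other):
        # Element-wise rule: broadcast the shapes; dtype promotion omitted here.
        ndim = max(len(self.shape), len(other.shape))
        s1 = (1,) * (ndim - len(self.shape)) + self.shape
        s2 = (1,) * (ndim - len(other.shape)) + other.shape
        out = []
        for a, b in zip(s1, s2):
            if a != 1 and b != 1 and a != b:
                raise ValueError(f"cannot broadcast {self.shape} and {other.shape}")
            out.append(max(a, b))
        return LazyArray(out, self.dtype, self.device)

def ones(shape, dtype="float64", device="cpu"):
    # Creation function: records the metadata, allocates nothing.
    return LazyArray(shape, dtype, device)

x = ones((3, 1, 4))
y = ones((7, 4))
z = x + y
print(z.shape)  # (3, 7, 4)
```

A real implementation would fill in the remaining operations per API group, plus the standard's broadcasting, indexing and type-promotion rules, but the dry-run workflow would look exactly like this: build the computation, inspect shapes and dtypes, never pay for the flops.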