data-apis / array-api

RFC document, tooling and other content related to the array API standard
https://data-apis.github.io/array-api/latest/
MIT License

A "lazy" / "meta" implementation of the array api? #728

Closed adonath closed 3 months ago

adonath commented 5 months ago

In addition to the already available implementations of the array API, I think it could be interesting to have a lazy / "meta" implementation of the standard. What I mean is a small, minimal-dependency, standalone library, compatible with the array API, that infers the shape and dtype of resulting arrays without ever allocating data or executing any FLOPs.

PyTorch already has something like this with the "meta" device. For example:

import torch

data = torch.ones((1000, 1000, 100), device="meta")
kernel = torch.ones((100, 10), device="meta")

result = torch.matmul(data, kernel)
print(result.shape)
print(result.dtype)

However, this misses device handling, for example, as the device is constrained to "meta". I presume Dask must have something very similar. JAX must as well for its jitted computations, though I think it is only exposed to users through a different API, jax.eval_shape(), rather than via a "meta" array object.
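For reference, the jax.eval_shape() route looks roughly like this (a sketch, assuming JAX is installed; the function f is mine): it traces the function abstractly, so no data is allocated and no FLOPs are executed.

```python
import jax
import jax.numpy as jnp

def f(data, kernel):
    return jnp.matmul(data, kernel)

# ShapeDtypeStruct carries only shape and dtype, no data
data = jax.ShapeDtypeStruct((1000, 1000, 100), jnp.float32)
kernel = jax.ShapeDtypeStruct((100, 10), jnp.float32)

out = jax.eval_shape(f, data, kernel)
print(out.shape, out.dtype)  # (1000, 1000, 10) float32
```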

Analogously to the torch example, one would use a hypothetical library lazy_array_api:

import lazy_array_api as xp

data = xp.ones((1000, 1000, 100), device="cpu")
kernel = xp.ones((100, 10), device="cpu")

result = xp.matmul(data, kernel)
print(result.shape)
print(result.dtype)

The use case I have in mind is mostly debugging, validation and testing of computationally intensive algorithms ("dry runs"). For now I just wanted to share the idea and bring it up for discussion.

rgommers commented 5 months ago

This would be super useful indeed. It's not a small amount of work I suspect. For indexing there is https://github.com/Quansight-Labs/ndindex/, which basically implements this "meta" idea. That's probably one of the most hairy parts to do, and a good start. But correctly doing all shape calculations for all functions in the API is also a large job. Perhaps others know of reusable functionality elsewhere for this? For PyTorch I believe it's too much baked into the library to be able to reuse it standalone.

adonath commented 5 months ago

Thanks a lot @rgommers for the response!

It's not a small amount of work I suspect.

I only partly agree, because there is nothing particularly difficult about it. The expected behavior is well defined and the API is already specified, so there is no tricky code to figure out. It is just a matter of implementing the already defined behavior "diligently".

For indexing there is https://github.com/Quansight-Labs/ndindex/, which basically implements this "meta" idea.

This is indeed a great start, I was not aware of this project.

But correctly doing all shape calculations for all functions in the API is also a large job.

I think the effort could actually be limited, because looking at https://github.com/numpy/numpy/tree/main/numpy/array_api, the files already pre-group the API into operations with the same shape-computation behavior, i.e. element-wise, indexing, searching, statistical, etc. For each group the behavior only needs to be defined once; the rest is filling in boilerplate code. In addition there are broadcasting and indexing, which always apply. I'm less sure about the dtype promotion, but this must have been coded somewhere already as well.
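For the element-wise group, for instance, the shape rule is just the array API broadcasting algorithm, which fits in a few lines of pure Python. A sketch (the function name is mine):

```python
def broadcast_shapes(shape1, shape2):
    """Result shape of broadcasting two shapes, per the array API rules."""
    n = max(len(shape1), len(shape2))
    # right-align both shapes, padding missing leading dimensions with 1
    s1 = (1,) * (n - len(shape1)) + tuple(shape1)
    s2 = (1,) * (n - len(shape2)) + tuple(shape2)
    result = []
    for d1, d2 in zip(s1, s2):
        if d1 == 1:
            result.append(d2)
        elif d2 == 1 or d1 == d2:
            result.append(d1)
        else:
            raise ValueError(f"shapes {shape1} and {shape2} are not broadcastable")
    return tuple(result)

print(broadcast_shapes((8, 1, 6, 1), (7, 1, 5)))  # (8, 7, 6, 5)
```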

Perhaps others know of reusable functionality elsewhere for this? For PyTorch I believe it's too much baked into the library to be able to reuse it standalone.

I agree PyTorch is too large a dependency. From a quick search I only found https://github.com/NeuralEnsemble/lazyarray, which seems to be unmaintained. It also takes a different approach, building a graph and then delaying the evaluation.
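To make the idea concrete, here is a minimal shape-only sketch of what such a library could look like (all names hypothetical; matmul is restricted to a 2-D second operand for brevity):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class MetaArray:
    """An array that carries only shape, dtype and device -- no data."""
    shape: tuple
    dtype: str
    device: str

def ones(shape, *, dtype="float64", device="cpu"):
    # no allocation happens; only the metadata is recorded
    return MetaArray(tuple(shape), dtype, device)

def matmul(x1, x2):
    # shape-only matmul; this sketch only covers a 2-D second operand
    if x1.device != x2.device:
        raise ValueError("inputs must be on the same device")
    if len(x2.shape) != 2:
        raise NotImplementedError("sketch handles only 2-D second operands")
    if x1.shape[-1] != x2.shape[0]:
        raise ValueError("contracted dimensions do not match")
    return MetaArray(x1.shape[:-1] + (x2.shape[1],), x1.dtype, x1.device)

data = ones((1000, 1000, 100), device="cpu")
kernel = ones((100, 10), device="cpu")
result = matmul(data, kernel)
print(result.shape, result.dtype)  # (1000, 1000, 10) float64
```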

adonath commented 5 months ago

I'd like to get a better idea of the actual implementation effort and just share some more thoughts on this idea.

lucascolley commented 5 months ago

Sometimes it might be useful to work with "un-initialized shapes". For example, the length of a "batch" axis might not be known, and often one would not care about it either. As "batches" are treated independently, (None, 2, 3, 4) could be handled as a valid shape. After specific operations, such as the mean along the batch axis, the shape becomes known. However, it is not a valid shape in the array API.

Is this not a valid shape? The spec for shape says

out (Tuple[Optional[int], …]) – array dimensions. An array dimension must be None if and only if a dimension is unknown.

adonath commented 5 months ago

Indeed this is already part of the spec. I missed that before!
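A shape-only reduction could then simply drop the possibly unknown dimension. A minimal sketch (function name hypothetical):

```python
def reduce_shape(shape, axis):
    """Result shape of a reduction (e.g. mean) along `axis`; works even
    when the reduced dimension is unknown (None)."""
    axis = axis % len(shape)  # support negative axes
    return shape[:axis] + shape[axis + 1:]

print(reduce_shape((None, 2, 3, 4), axis=0))  # (2, 3, 4) -- the unknown batch dim is gone
```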

adonath commented 5 months ago

Actually I'd be interested in starting a repo and playing around with this a bit. Any preference for a name @rgommers or @lucascolley? What about ndshape or xpshape for example?