data-apis / array-api-extra

Extra array functions built on top of the array API standard.
http://data-apis.org/array-api-extra/
MIT License

ENH/API: xp-bound namespaces, array-api-compat #6

Open lucascolley opened 1 week ago

lucascolley commented 1 week ago

Currently, functions of this package require passing a standard-compatible namespace as xp=xp. This works fine, but there have been suggestions that it might be nice to avoid this requirement. There are at least a few ways we could go about this:

(1) xpx.bind_namespace

Usage:

import array_api_extra as xpx
...
xp = array_namespace(x)
xpx = xpx.bind_namespace(xp)
x = xpx.atleast_nd(x, ndim=2)
y = xp.sum(x)
z = xpx.some_func(y)

A potential implementation:

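import functools
from types import ModuleType

extra_funcs = {'atleast_nd': atleast_nd, ...}

def bind_namespace(xp: ModuleType) -> ModuleType:
    class BoundNamespace:
        def __getattr__(self, name: str):
            if name in extra_funcs:
                # pre-fill xp so callers can omit xp=xp
                return functools.partial(extra_funcs[name], xp=xp)
            else:
                # only 'extra' functions are exposed
                raise AttributeError(...)

    # xp is captured by the closure; the class takes no arguments
    return BoundNamespace()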

I like this idea. If we encounter use cases where a library wants to use multiple xpx functions in the same local scope and finds the xp=xp pattern too cumbersome, I think we should add this. I think we can leave it out for now until that situation arises.
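To make the behaviour of option (1) concrete, here is a minimal runnable sketch. It uses the standard library's math module as a stand-in for a standard-compatible namespace, and hypot3 is a hypothetical 'extra' function invented for this illustration; neither is how array-api-extra would actually be used.

```python
import functools
import math


def hypot3(x, y, z, *, xp):
    """Hypothetical 'extra' function: Euclidean norm built from xp.sqrt."""
    return xp.sqrt(x * x + y * y + z * z)


# Registry of 'extra' functions (hypothetical contents).
extra_funcs = {"hypot3": hypot3}


def bind_namespace(xp):
    class BoundNamespace:
        def __getattr__(self, name):
            if name in extra_funcs:
                # Pre-fill xp so the caller never passes it.
                return functools.partial(extra_funcs[name], xp=xp)
            raise AttributeError(f"no extra function named {name!r}")

    return BoundNamespace()


# ASSUMPTION: math only stands in for a real array namespace here.
xpx = bind_namespace(math)
print(xpx.hypot3(2.0, 3.0, 6.0))  # 7.0
print(hasattr(xpx, "sqrt"))       # False: standard functions are not exposed
```

The bound object exposes only the registered extra functions, so standard functions still have to come from xp itself.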

(2) xpx.extra_namespace

Usage:

import array_api_extra as xpx
...
xp = array_namespace(x)
xpx = xpx.extra_namespace(xp)
x = xpx.atleast_nd(x, ndim=2)
y = xpx.sum(x)  # XXX: xpx instead of xp
z = xpx.some_func(y)

A potential implementation:

import functools
from types import ModuleType

extra_funcs = {'atleast_nd': atleast_nd, ...}

def extra_namespace(xp: ModuleType) -> ModuleType:
    class ExtraNamespace:
        def __getattr__(self, name: str):
            if name in extra_funcs:
                # pre-fill xp so callers can omit xp=xp
                return functools.partial(extra_funcs[name], xp=xp)
            else:
                return getattr(xp, name)  # XXX: delegate to xp instead of error

    # xp is captured by the closure; the class takes no arguments
    return ExtraNamespace()

I would not want to add this yet. I think we should keep separation between the standard namespace and the 'extra' namespace, at least until this library matures.
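A small sketch of why merging the two namespaces blurs the separation: with the lookup order of option (2), an 'extra' function silently shadows a standard function of the same name. Here math again stands in for a real namespace, and the clamping sqrt is a hypothetical extra function made up purely to show the shadowing.

```python
import functools
import math


def sqrt(x, *, xp):
    """Hypothetical extra function that shares its name with xp.sqrt.

    It clamps negative inputs to zero before delegating to xp.sqrt.
    """
    return xp.sqrt(max(x, 0.0))


# Registry of 'extra' functions (hypothetical contents).
extra_funcs = {"sqrt": sqrt}


def extra_namespace(xp):
    class ExtraNamespace:
        def __getattr__(self, name):
            if name in extra_funcs:
                # Extra functions are looked up first, so they win on a clash.
                return functools.partial(extra_funcs[name], xp=xp)
            # Everything else is delegated to the wrapped namespace.
            return getattr(xp, name)

    return ExtraNamespace()


# ASSUMPTION: math only stands in for a real array namespace here.
xpx = extra_namespace(math)
print(xpx.floor(2.5))  # 2, delegated to math.floor
print(xpx.sqrt(-4.0))  # 0.0, the extra sqrt; math.sqrt(-4.0) would raise
```

With a single merged object, a reader of xpx.sqrt(...) cannot tell from the call site whether the standard or the extra function runs.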

(3) Use array_api_compat.array_namespace internally

This would provide the most flexible API and require the fewest lines of code at the call site. One could use xpx functions on standard-incompatible arrays and let array-api-compat handle the compatibility, without having to pass an xp argument.

We don't yet have a use case where it is clearly beneficial to be able to pass standard-incompatible arrays. Consumer libraries using array-api-extra would already be computing with standard-compatible arrays internally. I don't see the need to support the following use case:

import torch
import array_api_extra as xpx
...
x = torch.asarray([1, 2, 3])
xpx.some_func(x)             # works
torch.some_standard_func(x)  # does not work

Another complication is that consumer libraries like SciPy wrap array_namespace to provide custom behaviour for scalars and other types. We would want the internal array_namespace to be the consumer library's wrapped version rather than the base one from array-api-compat.

I'm also not sure that saving one line of code over option (1) of this post for standard-compatible arrays is worth introducing a dependency on array-api-compat.

Overall, this would complicate things a lot with situations of co-vendoring array-api-compat and array-api-extra, which is the primary use-case for the library right now. This might be a better idea in the future if a need for handling standard-incompatible arrays arises (for example, if one wants to use functions from xpx with just a single library).
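To illustrate the wrapped-array_namespace complication, here is a rough sketch of a function that infers the namespace internally but still lets a consumer supply its own. To stay dependency-free it uses the standard's __array_namespace__ method in place of array_api_compat.array_namespace, so it only covers standard-compatible inputs. ToyArray, toy_xp, some_func and consumer_xp are all hypothetical names invented for this example.

```python
import math
from types import SimpleNamespace


class ToyArray:
    """Toy stand-in for a standard-compatible array holding one float."""

    def __init__(self, value):
        self.value = value

    def __array_namespace__(self, *, api_version=None):
        # The standard's hook for discovering an array's namespace.
        return toy_xp


# Toy namespace with a single function; a real one would be a full module.
toy_xp = SimpleNamespace(sqrt=lambda a: ToyArray(math.sqrt(a.value)))


def some_func(x, *, xp=None):
    """Hypothetical extra function whose xp argument is optional."""
    if xp is None:
        # Internal inference: knows nothing about consumer-specific rules.
        xp = x.__array_namespace__()
    return xp.sqrt(x)


def wrapped_sqrt(a):
    # Stand-in for consumer-specific behaviour: also accept Python scalars.
    value = a.value if isinstance(a, ToyArray) else float(a)
    return ToyArray(math.sqrt(value))


# Stand-in for the namespace a consumer's wrapped array_namespace would return.
consumer_xp = SimpleNamespace(sqrt=wrapped_sqrt)

print(some_func(ToyArray(9.0)).value)       # 3.0, namespace inferred from x
print(some_func(16, xp=consumer_xp).value)  # 4.0, namespace supplied by consumer
```

Calling some_func(16) without xp fails here, because the internal inference has no notion of the consumer's scalar handling; that is the gap a consumer's wrapped array_namespace would otherwise fill.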

izaid commented 1 week ago

Hi @lucascolley! Thanks for writing this up and making this library. I think it's really helpful for the ecosystem to have something like this, and also potentially a good place for staging things the standard may or may not want to adopt.

I read through all the above and will add my two cents.

To be honest, I think (3) is the way to go. My reasoning is that we want to make this as easy to use as possible. I know it sounds silly, but I think adding extra functions like xpx.bind_namespace will discourage people from using the library. It's also nice for it to work in the same way as array-api-compat. Regarding vendoring, I would just make array-api-compat a hard dependency.

On the issue of Python scalars, lists and that kind of thing... I think there needs to be a solution to this, though I'm not sure whether it is in the scope of array-api-extra or something else.