data-apis / array-api

RFC document, tooling and other content related to the array API standard
https://data-apis.github.io/array-api/latest/
MIT License

Should get_namespace support more than arrays? #799

Open NeilGirdhar opened 2 months ago

NeilGirdhar commented 2 months ago

Background

Consider a probability distribution library such as efax. Each probability distribution class contains a number of array-valued parameters, and once Array API support is added, it makes sense for all of them to come from the same namespace. Thus, a standard pattern for methods that use the parameters will be to query the namespace of self by passing all of the parameters to get_namespace.

This pattern is not unique to efax. I imagine it will pop up in SciPy's future distribution classes (which are being developed and will support the Array API). It could be added to any object exposing the JAX PyTree interface (see the registry), or more generally to any aggregate structure containing a homogeneous set of arrays.

Motivation

To simplify getting the namespace in functions that interact with aggregate structures containing homogeneous sets of arrays.

Example

Suppose that Distribution is an aggregate structure containing array-valued parameters. Instead of:

def f(x: Distribution, y: Distribution, z: Array):
    xp = x.get_namespace()  # Call a method to get the namespace.
    assert y.get_namespace() == xp  # Call the method and check that it's the same namespace.
    assert get_namespace(z) == xp  # Check that it's the same namespace.

we would like to simply have:

def f(x: Distribution, y: Distribution, z: Array):
    xp = get_namespace(x, y, z)  # One simple line

Proposal

Extend get_namespace(o) so that it first checks for an o.__namespace_arrays__() method; if present, it is called to obtain an iterable of arrays, which get_namespace then processes as before.

Thus, instead of each aggregate structure providing its own method that queries the namespace (like x.get_namespace() in the example above), we would have:

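from collections.abc import Iterable
from dataclasses import dataclass, fields
@dataclass  # fields() only works on dataclasses
class Distribution:
    def __namespace_arrays__(self) -> Iterable[Array]:
        return (getattr(self, field.name) for field in fields(self))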

A simple recursive extension of get_namespace is illustrated here.

Alternative proposal

One alternative is to support __array_namespace__ on all inputs to get_namespace. Thus, we would have:

class Distribution:
    def __array_namespace__(self, api_version: str, use_compat: bool) -> ArrayNameSpace:
        return get_namespace(*(getattr(self, field.name) for field in fields(self)),
                             api_version=api_version,
                             use_compat=use_compat)

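The problem with this is that it complicates extending the parameter specification of get_namespace: every aggregate's __array_namespace__ has to mirror get_namespace's keyword parameters (here api_version and use_compat) and forward them, so adding a parameter to get_namespace means updating every such class.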
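A minimal sketch of this alternative, assuming a hypothetical get_namespace that simply calls `__array_namespace__` on each input. Per the standard, an array's `__array_namespace__` takes a keyword-only `api_version`; the aggregate must accept and forward it, and would need to change again whenever get_namespace gains a parameter. `FakeArray`, `fake_xp` and `Normal` are illustrative stand-ins, not real library types.

```python
from __future__ import annotations

from dataclasses import dataclass, fields
from types import SimpleNamespace
from typing import Any


def get_namespace(*objs: Any, api_version: str | None = None) -> Any:
    """Hypothetical: call __array_namespace__ on every input, require one result."""
    namespaces: list[Any] = []
    for obj in objs:
        ns = obj.__array_namespace__(api_version=api_version)
        if not any(ns is seen for seen in namespaces):
            namespaces.append(ns)
    if len(namespaces) != 1:
        raise TypeError(f"expected one array namespace, found {len(namespaces)}")
    return namespaces[0]


fake_xp = SimpleNamespace(name="fake_xp")  # Stand-in for a namespace module.


class FakeArray:
    def __array_namespace__(self, *, api_version: str | None = None) -> Any:
        return fake_xp


@dataclass
class Normal:
    mean: FakeArray
    std: FakeArray

    # The aggregate poses as an array: it must mirror get_namespace's keyword
    # parameters and forward them, so adding a parameter to get_namespace
    # means updating every class like this one.
    def __array_namespace__(self, *, api_version: str | None = None) -> Any:
        return get_namespace(*(getattr(self, f.name) for f in fields(self)),
                             api_version=api_version)


xp = get_namespace(Normal(FakeArray(), FakeArray()), FakeArray())
```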

betatim commented 2 months ago

In scikit-learn we have our own get_namespace function which uses the get_namespace of the Array API but also contains some useful stuff (that would be repeated in many places). https://github.com/scikit-learn/scikit-learn/blob/19c068f64249f95f745962b42a4dd581c7738218/sklearn/utils/_array_api.py#L473

Could you do something like that in efax? Or, asked differently, aren't you going to end up having something like this sooner or later anyway, in which case it could also take care of dealing with "things that aren't arrays but contain them"?

NeilGirdhar commented 2 months ago

How is the linked function related? It doesn't deal with aggregate structures, which is the motivation for this proposal.

asmeurer commented 2 months ago

get_namespace isn't actually part of the array API; it's part of the compat library. The array API defines x.__array_namespace__. The compat library's get_namespace() (also available as array_namespace()) is just a wrapper around this method that manually returns the compat layer namespace when necessary. Maybe this should be made clearer in the documentation.

I'm a little confused how your proposal would work. If a function takes a Distribution object, then that function already needs to know how to extract the relevant arrays from that object in order to use the array API on them, no?

In scikit-learn we have our own get_namespace function which uses the get_namespace of the Array API but also contains some useful stuff (that would be repeated in many places). scikit-learn/scikit-learn@19c068f/sklearn/utils/_array_api.py#L473

Happy to upstream some of those features to array_api_compat. We already implement some flags on top of __array_namespace__. A feature to skip certain types seems like it should be generally useful and easy to implement.

NeilGirdhar commented 2 months ago

I'm a little confused how your proposal would work. If a function takes a Distribution object, then that function already needs to know how to extract the relevant arrays from that object in order to use the array API on them, no?

Yes, but it's just a question of convenience. Sometimes you have a method that accepts an aggregate object (say, self) and some arrays (say, x). I guess you could expand the aggregate object into its component arrays and pass them as get_namespace(self.a, self.b, self.c, x). I'm proposing the convenience of get_namespace(self, x). It's purely a matter of simplicity.

betatim commented 2 months ago

Happy to upstream some of those features to array_api_compat. We already implement some flags on top of __array_namespace__. A feature to skip certain types seems like it should be generally useful and easy to implement.

If you want to, go for it. No strong feelings from my side. I have a slight preference for keeping the get_namespace in the compat library simple. I can see a future where it accumulates "all the useful things" from the various array-consuming libraries and becomes quite unwieldy.

The reason I linked to the custom get_namespace in scikit-learn is that it is an example of an array-consuming library having a custom version of get_namespace that implements things that are convenient for it. efax could define its own get_namespace that makes dealing with the types that occur in efax convenient.
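A minimal sketch of such a library-local wrapper, assuming a hypothetical library whose aggregates share a base class `Distribution`; `_base_namespace` is a toy stand-in for the compat library's lookup on plain arrays, and `FakeArray`/`fake_xp` are illustrative stand-ins. Because the wrapper recognizes aggregates with an `isinstance` check against its own base class (and expands only one level), it understands only that library's types.

```python
from __future__ import annotations

from dataclasses import dataclass, fields
from types import SimpleNamespace
from typing import Any

fake_xp = SimpleNamespace(name="fake_xp")  # Stand-in for a namespace module.


class FakeArray:
    def __array_namespace__(self, *, api_version: str | None = None) -> Any:
        return fake_xp


def _base_namespace(*arrays: Any) -> Any:
    """Toy stand-in for the compat library's lookup on plain arrays."""
    namespaces: list[Any] = []
    for a in arrays:
        ns = a.__array_namespace__()
        if not any(ns is seen for seen in namespaces):
            namespaces.append(ns)
    if len(namespaces) != 1:
        raise TypeError(f"expected one array namespace, found {len(namespaces)}")
    return namespaces[0]


@dataclass
class Distribution:
    """Hypothetical base class for this library's array-valued aggregates."""


@dataclass
class Normal(Distribution):
    mean: FakeArray
    std: FakeArray


def get_namespace(*objs: Any) -> Any:
    """Library-local wrapper: expand this library's aggregates, then delegate."""
    arrays: list[Any] = []
    for obj in objs:
        if isinstance(obj, Distribution):
            arrays.extend(getattr(obj, f.name) for f in fields(obj))
        else:
            arrays.append(obj)  # Anything else is assumed to be an array.
    return _base_namespace(*arrays)


xp = get_namespace(Normal(FakeArray(), FakeArray()), FakeArray())
```

An aggregate from a different library would fall through to the `else` branch and be treated as an array, which is the case a shared protocol would cover.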

NeilGirdhar commented 2 months ago

The reason I linked to the custom get_namespace in scikit-learn is that it is an example of an array-consuming library having a custom version of get_namespace that implements things that are convenient for it.

Ah, right, that makes sense!

efax could define its own get_namespace that makes dealing with the types that occur in efax convenient.

Right, which is what I'm doing. The reason I suggested upstreaming aggregate structure support is in case there are ever functions that accept aggregate structure types from different libraries.