NeilGirdhar opened 2 months ago
In scikit-learn we have our own get_namespace function, which uses the get_namespace of the Array API but also contains some useful stuff (that would otherwise be repeated in many places). https://github.com/scikit-learn/scikit-learn/blob/19c068f64249f95f745962b42a4dd581c7738218/sklearn/utils/_array_api.py#L473
Could you do something like that in efax? Or asked differently, aren't you going to end up having something like this sooner or later anyway, in which case it could also take care of dealing with "things that aren't arrays but contain them"?
How is the linked function related? It doesn't deal with aggregate structures, which is the motivation for this proposal.
get_namespace isn't actually part of the array API; it's part of the compat library. The array API defines x.__array_namespace__. The compat library's get_namespace() (which is also called array_namespace()) is just a wrapper that calls this method and, when necessary, returns the compat-layer namespace instead. Maybe this should be made clearer in the documentation.
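To make the distinction concrete, here is a heavily simplified sketch of what such a wrapper does for arrays that already implement the standard method. ToyArray and the lib_a/lib_b namespaces are hypothetical stand-ins; the real array_api_compat.array_namespace also handles arrays from libraries such as NumPy, CuPy, and PyTorch by returning its compat wrapper namespaces, which this sketch omits.

```python
from types import ModuleType

# Hypothetical namespaces standing in for two different array libraries.
lib_a = ModuleType("lib_a")
lib_b = ModuleType("lib_b")


class ToyArray:
    """Hypothetical array type that implements the standard dunder method."""

    def __init__(self, namespace):
        self._namespace = namespace

    def __array_namespace__(self, /, *, api_version=None):
        # Signature as defined by the array API standard.
        return self._namespace


def array_namespace(*xs, api_version=None):
    """Simplified sketch: ask every array for its namespace and require exactly one."""
    namespaces = set()
    for x in xs:
        # array_api_compat also recognizes e.g. NumPy arrays and returns its compat
        # namespace for them; this sketch only accepts objects with the method.
        if not hasattr(x, "__array_namespace__"):
            raise TypeError(f"{type(x).__name__} does not implement __array_namespace__")
        namespaces.add(x.__array_namespace__(api_version=api_version))
    if len(namespaces) != 1:
        raise TypeError(f"expected exactly one array namespace, found {len(namespaces)}")
    return namespaces.pop()
```

Everything beyond the dunder call (compat fallbacks, extra flags) is what the compat library layers on top.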
I'm a little confused how your proposal would work. If a function takes a Distribution object, then that function already needs to know how to extract the relevant arrays from that object in order to use the array API on them, no?
> In scikit-learn we have our own get_namespace function which uses the get_namespace of the Array API but also contains some useful stuff (that would be repeated in many places). scikit-learn/scikit-learn@19c068f/sklearn/utils/_array_api.py#L473
Happy to upstream some of those features to array_api_compat. We already implement some flags on top of __array_namespace__. A feature to skip certain types seems like it would be generally useful and easy to implement.
> I'm a little confused how your proposal would work. If a function takes a Distribution object, then that function already needs to know how to extract the relevant arrays from that object in order to use the array API on them, no?
Yes, but it's just a question of convenience. Sometimes you have a method that accepts an aggregate object (say, self) and some arrays (say, x). I guess you could expand the aggregate object into its component arrays and pass them to get_namespace(self.a, self.b, self.c, x). I'm proposing the convenience of get_namespace(self, x). It's just a matter of simplicity.
> Happy to upstream some of those features to array_api_compat. We already implement some flags on top of __array_namespace__. A feature to skip certain types seems like it would be generally useful and easy to implement.
If you want to, go for it. No strong feelings from my side. I have a slight preference for keeping the get_namespace in the compat library simple. I can see a future where it accumulates "all the useful things" from the various array-consuming libraries and then becomes quite unwieldy.
The reason I linked to the custom get_namespace in scikit-learn is that it is an example of an array-consuming library having a custom version of get_namespace that implements things that are convenient for it. efax could define its own get_namespace that makes dealing with the types that occur in efax convenient.
> The reason I linked to the custom get_namespace in scikit-learn is that it is an example of an array-consuming library having a custom version of get_namespace that implements things that are convenient for it.
Ah, right, that makes sense!
> efax could define its own get_namespace that makes dealing with the types that occur in efax convenient.
Right, which is what I'm doing. The reason I suggested upstreaming aggregate structure support is in case there are ever functions that accept aggregate structure types from different libraries.
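A library-local helper along these lines might look like the following sketch. It assumes, purely for illustration, that the library's distributions are dataclasses deriving from a common Distribution base whose fields are all arrays; Distribution, Normal, and ToyArray are hypothetical names, not efax's actual classes.

```python
import dataclasses
from types import ModuleType

lib_a = ModuleType("lib_a")  # hypothetical array-library namespace


class ToyArray:
    """Hypothetical array type implementing the standard dunder method."""

    def __init__(self, namespace):
        self._namespace = namespace

    def __array_namespace__(self, /, *, api_version=None):
        return self._namespace


@dataclasses.dataclass
class Distribution:
    """Hypothetical base class; every field of a subclass holds an array."""


@dataclasses.dataclass
class Normal(Distribution):
    mean: ToyArray
    variance: ToyArray


def get_namespace(*xs):
    """Hypothetical library-local wrapper that understands the library's own types."""
    arrays = []
    for x in xs:
        if isinstance(x, Distribution):
            # Expand the aggregate into its array-valued fields.
            arrays.extend(getattr(x, field.name) for field in dataclasses.fields(x))
        else:
            arrays.append(x)
    # A real helper would likely return array_api_compat.array_namespace(*arrays).
    namespaces = {a.__array_namespace__() for a in arrays}
    if len(namespaces) != 1:
        raise TypeError("expected arrays from exactly one namespace")
    return namespaces.pop()
```

The isinstance check is what keeps this local: it only recognizes the library's own aggregates, so an aggregate type from a different library would not be expanded.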
Background
Consider a probability distribution library such as efax. Each probability distribution class contains a number of array-valued parameters, and once Array API support is added, it makes sense for them all to come from the same namespace. Thus, a standard pattern for methods that use the parameters will be to query the namespace of self by feeding all of the parameters to get_namespace. This pattern is not unique to efax. I imagine it will pop up in SciPy's future distribution classes (which are being developed and will support the Array API). It could be added to any object exposing the JAX PyTree interface (see the registry), or more generally to any aggregate structure with a homogeneous set of arrays.
Motivation
To simplify getting the namespace in functions that interact with aggregate structures containing homogeneous sets of arrays.
Example
Suppose that Distribution is an aggregate structure containing array-valued parameters. Instead of:
we would like to simply have:
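A minimal sketch of the status-quo call, under stated assumptions: ToyArray, the toy namespace with its sqrt and multiply functions, Normal, and scale are all hypothetical, and array_namespace is a simplified stand-in for the compat library's function. The desired one-argument call appears only as a comment, because it requires aggregate support in get_namespace.

```python
import math
from types import ModuleType


class ToyArray:
    """Hypothetical scalar 'array' tagged with the namespace it belongs to."""

    def __init__(self, value, namespace):
        self.value = value
        self.namespace = namespace

    def __array_namespace__(self, /, *, api_version=None):
        return self.namespace


# Hypothetical array-library namespace with two elementwise functions.
toy = ModuleType("toy")
toy.sqrt = lambda a: ToyArray(math.sqrt(a.value), toy)
toy.multiply = lambda a, b: ToyArray(a.value * b.value, toy)


def array_namespace(*xs):
    """Simplified stand-in for array_api_compat.array_namespace."""
    namespaces = {x.__array_namespace__() for x in xs}
    if len(namespaces) != 1:
        raise TypeError("expected arrays from exactly one namespace")
    return namespaces.pop()


class Normal:
    """Hypothetical aggregate structure with array-valued parameters."""

    def __init__(self, mean, variance):
        self.mean = mean
        self.variance = variance


def scale(d, x):
    """Multiply x by the distribution's standard deviation."""
    # Status quo: every parameter of the aggregate is listed by hand.
    xp = array_namespace(d.mean, d.variance, x)
    # Desired, with aggregate support: xp = get_namespace(d, x)
    return xp.multiply(x, xp.sqrt(d.variance))
```

Every new parameter added to Normal would also have to be added to each such call, which is the repetition the proposal aims to remove.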
Proposal
Extend get_namespace(o) to first read o.__namespace_arrays__, which returns an iterable of arrays that get_namespace can use as before. Thus, instead of aggregate structures providing a method that queries the namespace like this function, we would instead have:
A simple recursive extension of get_namespace is illustrated here.
Alternative proposal
One alternative is to support __array_namespace__ on all inputs to get_namespace. Thus, we would have:
The problem with this is that it complicates extending the parameter specification of get_namespace.
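For comparison, a sketch of the alternative under similar assumptions (ToyArray, Normal, and the simplified array_namespace are hypothetical). The aggregate implements __array_namespace__ itself and delegates to its parameters. One way to read the stated drawback: every such implementation has to accept and forward whatever keywords get_namespace passes, so adding a new option to get_namespace would mean updating every aggregate class.

```python
from types import ModuleType

lib_a = ModuleType("lib_a")  # hypothetical array-library namespace


class ToyArray:
    """Hypothetical array type implementing the standard dunder method."""

    def __init__(self, namespace):
        self._namespace = namespace

    def __array_namespace__(self, /, *, api_version=None):
        return self._namespace


def array_namespace(*xs, api_version=None):
    """Simplified resolver that forwards api_version to every input."""
    namespaces = {x.__array_namespace__(api_version=api_version) for x in xs}
    if len(namespaces) != 1:
        raise TypeError("expected arrays from exactly one namespace")
    return namespaces.pop()


class Normal:
    """Hypothetical aggregate that answers __array_namespace__ for its parameters."""

    def __init__(self, mean, variance):
        self.mean = mean
        self.variance = variance

    def __array_namespace__(self, /, *, api_version=None):
        # Must mirror the resolver's keywords: any option later added to the
        # resolver would have to be accepted and forwarded here as well.
        return array_namespace(self.mean, self.variance, api_version=api_version)
```

By contrast, __namespace_arrays__ only hands back arrays, so get_namespace can grow new parameters without any change to the aggregate classes.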