data-apis / array-api

RFC document, tooling and other content related to the array API standard
https://data-apis.github.io/array-api/latest/
MIT License

Add API specifications for returning the `k` largest elements #722

Open kgryte opened 6 months ago

kgryte commented 6 months ago

This PR

Prior Art

As illustrated in the API comparison, there is currently no consistent API across array libraries for returning the k largest or smallest values.

Proposed APIs

This PR attempts to synthesize the common themes and best ideas for "top k" APIs as observed among array libraries and attempts to define APIs which adhere to specification precedent in order to promote consistent design and reduce cognitive load.

top_k

def top_k(
    x: array,
    k: int,
    /,
    *,
    axis: Optional[int] = None,
    mode: Literal["largest", "smallest"] = "largest",
) -> Tuple[array, array]

Returns a tuple containing the k largest (or smallest) elements in x.

def top_k_indices(
    x: array,
    k: int,
    /,
    *,
    axis: Optional[int] = None,
    mode: Literal["largest", "smallest"] = "largest",
) -> array

Returns an array containing the indices of the k largest (or smallest) elements in x.

def top_k_values(
    x: array,
    k: int,
    /,
    *,
    axis: Optional[int] = None,
    mode: Literal["largest", "smallest"] = "largest",
) -> array

Returns an array containing the k largest (or smallest) elements in x.
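To make the three proposed signatures concrete, here is a minimal NumPy sketch. It is not part of the PR and not a reference implementation: it assumes a full stable argsort is acceptable (a real implementation would likely use a partial sort), it assumes results are returned in sorted order, and it leaves tie-breaking and NaN handling unspecified.

```python
import numpy as np


def top_k(x, k, /, *, axis=None, mode="largest"):
    """Illustrative sketch only: return (values, indices) of the k extreme elements."""
    if mode not in ("largest", "smallest"):
        raise ValueError(f"unknown mode: {mode!r}")
    if axis is None:
        # axis=None searches over the flattened array.
        x = x.reshape(-1)
        axis = 0
    # Assumption: a full sort is used for clarity, not for speed.
    order = np.argsort(x, axis=axis, kind="stable")
    if mode == "largest":
        order = np.flip(order, axis=axis)
    indices = np.take(order, np.arange(k), axis=axis)
    values = np.take_along_axis(x, indices, axis=axis)
    return values, indices


def top_k_indices(x, k, /, *, axis=None, mode="largest"):
    return top_k(x, k, axis=axis, mode=mode)[1]


def top_k_values(x, k, /, *, axis=None, mode="largest"):
    return top_k(x, k, axis=axis, mode=mode)[0]


x = np.array([[5, 1, 9, 3], [2, 8, 4, 7]])
values, indices = top_k(x, 2, axis=1)
print(values)   # [[9 5] [8 7]]
print(indices)  # [[2 0] [1 3]]
```

With `axis=None` the returned indices refer to positions in the flattened input, which is one of the behaviors the open questions ask about.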

Design Decision Rationale

Questions

Considerations

The APIs included in this PR have implications for the following array libraries:

Related Links

kgryte commented 6 months ago

cc @ogrisel for visibility since you originally opened #629.

kgryte commented 5 months ago

@ogrisel Do you have opinions on whether having three separate APIs is preferable to having just argtop_k and top_k?

I tried tracking down actual usages of top k in the wild, but I wasn't able to get a good sense on whether having only two APIs suffices or having three separate APIs is more desirable.

There may also be other combinations. The PR currently specifies

We could, e.g., only specify

or strictly complementary

If you have a feel for what is preferable, that would be great to hear!

ogrisel commented 3 months ago

Thank you very much for the survey of current implementations and the API proposal. From a potential API consumer's point of view, the main proposal seems good to me.

About the name: topk seems to be a bit more popular than top_k in existing libraries. Using top_k might help reduce breaking changes when adopting this spec, but it would be better to hear from library implementers.

Should we be more strict in specifying what should happen when k exceeds the number of elements?

I believe so. As a user I would accept the call to fail with a standard exception type such as ValueError.
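As a sketch of what failing with a standard exception could look like, the check below raises ValueError when k exceeds the number of elements being searched. The helper name `validate_k` is hypothetical and not part of the proposal, and whether k <= 0 should also be rejected is left open here.

```python
import numpy as np


def validate_k(x, k, axis=None):
    """Hypothetical helper: raise ValueError if k exceeds the searched element count."""
    # With axis=None the search runs over the flattened array.
    n = x.size if axis is None else x.shape[axis]
    if k > n:
        raise ValueError(
            f"k={k} exceeds the number of elements ({n}) along axis={axis}"
        )
    return n


x = np.zeros((2, 4))
validate_k(x, 5)  # fine: 8 elements when flattened

try:
    validate_k(x, 5, axis=1)  # only 4 elements along axis 1
except ValueError as exc:
    message = str(exc)
    print(message)
```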

In argmin and argmax, the specification requires returning the index of first occurrence when a minimum/maximum value occurs multiple times. Given that top_k* can be implemented as a partial sort, presumably we do not want to specify a first occurrence restriction. Is this a reasonable assumption?

I think it's fine to allow topk to be faster by not enforcing this constraint. It would always be possible to add a stable=False boolean kwarg later so that users can request stable results, possibly at the cost of a performance penalty, as done for the xp.sort function.
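The trade-off can be illustrated with NumPy. This is a sketch, not a proposal: a partial sort (`np.argpartition`) makes no promise about which of several tied elements it returns, while a stable full sort resolves ties in index order, matching the first-occurrence convention of `argmax`.

```python
import numpy as np

x = np.array([3, 7, 7, 1, 7])
k = 2

# argmax is documented to return the first occurrence of the maximum.
first = int(np.argmax(x))  # 1

# Partial sort: the last k positions hold the k largest values, but which of
# the tied 7s are picked (and in what order) is not guaranteed.
part = np.argpartition(x, -k)[-k:]

# Stable sort on negated values: ties are resolved in index order.
# (Negation is only safe here because x is a signed integer array.)
stable = np.argsort(-x, kind="stable")[:k]  # [1, 2]

print(first, part, stable)
```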

Should we defer adding support for specifying multiple axes until a future revision of the specification or should we go ahead and add now for parity with min and max?

I wouldn't like this requirement to hurt speed of adoption by backing libraries. I don't think many users need that in practice.

Are we okay with None being the default for axis, where the default behavior is searching over a flattened array?

I don't think users have a need for this: they can flatten the input themselves if needed. But I have the impression that NumPy (and therefore the Array API) often uses axis=None as the default convention to work on a flattened 1D array anyway, so I am fine with staying consistent in that regard.

However, for this particular case, the default in numpy.argpartition is axis=-1 instead of axis=None. No strong opinion on which is most natural.
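The two default conventions mentioned above can be compared directly in NumPy. A small sketch: `np.argmax` defaults to `axis=None` and returns an index into the flattened array, while `np.argpartition` defaults to `axis=-1` and works row by row on a 2D input.

```python
import numpy as np

x = np.array([[5, 1, 9], [2, 8, 4]])

# axis=None default: a single index into the flattened array.
flat_argmax = int(np.argmax(x))  # 2

# axis=-1 default: one partition per row; the last column holds the
# index of each row's largest element.
per_row = np.argpartition(x, -1)[:, -1]  # [2, 1]

# argpartition also accepts axis=None explicitly and then flattens.
flat_part = int(np.argpartition(x, -1, axis=None)[-1])  # 2

print(flat_argmax, per_row, flat_part)
```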

Should we defer adding support for specifying multiple axes until a future revision of the specification or should we go ahead and add now for parity with min and max?

That would remove a lot of value by preventing efficient parallelization by the underlying backend when running topk on a 2D array with axis=0 or axis=1. This pattern was the original motivation for #629 (e.g. k-nearest neighbors classification in scikit-learn).
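For context, the k-nearest-neighbors pattern referred to here looks roughly like the following in plain NumPy. This is a sketch with made-up random data; scikit-learn's actual implementation differs. The key step is selecting the k smallest distances along axis=1 of a 2D distance matrix, which is what a `top_k(..., axis=1, mode="smallest")` call would express.

```python
import numpy as np

rng = np.random.default_rng(0)
train = rng.normal(size=(50, 3))  # 50 training points (illustrative data)
query = rng.normal(size=(4, 3))   # 4 query points

# Squared Euclidean distances, shape (n_query, n_train) = (4, 50).
dist = ((query[:, None, :] - train[None, :, :]) ** 2).sum(axis=-1)

k = 5
# Indices of the k nearest training points per query row (unordered).
nn_idx = np.argpartition(dist, k - 1, axis=1)[:, :k]
nn_dist = np.take_along_axis(dist, nn_idx, axis=1)

print(nn_idx.shape)  # (4, 5)
```

Each row is an independent selection problem, which is why per-axis support lets a backend parallelize across rows.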

I tried tracking down actual usages of top k in the wild, but I wasn't able to get a good sense on whether having only two APIs suffices or having three separate APIs is more desirable.

The benefit of the three-function API would be the ability to always optimize memory usage by not allocating unnecessary arrays when k is large enough for this to matter. However, I have no good feeling for how much of a problem this would really be in practice (or how much the underlying implementation would be able to skip the extra contiguous memory allocation internally).

Something that seems to be missing from this spec is the handling of NaN values. Maybe an extra kwarg is needed to specify whether they should be considered as either smallest or largest, or whether they need to be filtered out from the result (but then the result size would be data-dependent and the result potentially an empty array, which might also cause problems).

Also I assume that nan values are always smaller than +inf and larger than -inf but maybe not all libraries agree on that.
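For reference, here is how NumPy currently behaves on these points. This sketch describes NumPy only; other libraries may differ. Under IEEE 754, NaN is unordered, so comparisons against +inf and -inf are both False, and `np.sort` places NaNs at the end of an ascending sort.

```python
import numpy as np

x = np.array([1.0, np.nan, np.inf, -np.inf, 2.0])

# Ascending sort: NaN is placed last, after +inf.
s = np.sort(x)  # [-inf, 1., 2., inf, nan]

# NaN is unordered: neither smaller than +inf nor larger than -inf.
nan_lt_inf = bool(np.nan < np.inf)   # False
nan_gt_ninf = bool(np.nan > -np.inf)  # False

# Filtering NaNs out makes the result shape depend on the data.
filtered = x[~np.isnan(x)]  # 4 elements here, possibly 0 in general

print(s, nan_lt_inf, nan_gt_ninf, filtered.shape)
```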

rgommers commented 3 weeks ago

Revisiting this topic in preparation for helping it move forward. Quick first comment on:

JAX has top_k which only returns values,

I am not sure if it changed in the meantime or you just misread the JAX docs at the time, but https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.top_k.html says that both values and indices are returned.

I had a peek at the jax.lax implementation as well, but couldn't tell so quickly because I'm not familiar with the .bind method. @jakevdp would you mind confirming that the docs are correct here?

So it looks like PyTorch, JAX and TF all return (values, indices). That seems to align well enough with the top_k definition proposed here.

jakevdp commented 3 weeks ago

JAX's current top_k functions are in the jax.lax namespace, while the array API implementations will be in the jax.numpy namespace. So there is no issue with having different API conventions here.

The rendered documentation is misleading due to a misplaced colon in the Returns block, but JAX returns (values, indices):

In [2]: x = jax.numpy.arange(100, 110)

In [3]: jax.lax.top_k(x, 3)
Out[3]: [Array([109, 108, 107], dtype=int32), Array([9, 8, 7], dtype=int32)]

rgommers commented 3 weeks ago

Here is a PR with a draft implementation for NumPy: https://github.com/numpy/numpy/pull/26666, aligned with the top_k signature in this PR.

seberg commented 2 weeks ago

The most sane NaN handling, IMO, is to always sort NaNs to the end. However, that means that if you implement sort="desc" or largest values here, NaNs should also end up at the end, which is the opposite of what happens for an ascending sort/smallest values.

The annoyance with that is that sort behavior then diverges between asc/desc sorts beyond a [::-1], even for an unstable sort.
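The divergence can be seen in a small NumPy sketch. The "NaNs last" descending order below is built by hand for illustration and is an assumption about what such a mode would return, not an existing NumPy option.

```python
import numpy as np

x = np.array([2.0, np.nan, 3.0, 1.0])

# Ascending sort: NaN goes to the end.
asc = np.sort(x)  # [1., 2., 3., nan]

# Simply reversing the ascending sort puts NaN first.
reversed_asc = asc[::-1]  # [nan, 3., 2., 1.]

# A descending sort that keeps NaNs at the end must treat them separately.
nan_mask = np.isnan(x)
desc_nan_last = np.concatenate([np.sort(x[~nan_mask])[::-1], x[nan_mask]])
# [3., 2., 1., nan]

print(asc, reversed_asc, desc_nan_last)
```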

Of course one can just leave it unspecified here. OTOH, I dunno how much that limits the usability.