Open JuliaPoo opened 2 months ago
So the next step here would be to implement wrappers for this function in array_api_compat. That way we will be able to see just how complex the required changes are, and also verify that there aren't other incompatibilities, since the tests can't check later behavior in a test when earlier parts of it fail.
The way I would do this is to make a PR to the compat library. Sparse and JAX are currently not tested in the compat library CI, because their support lives entirely in the libraries themselves. So what you can do is make a simple wrapper namespace that wraps `top_k` and nothing else, then add a CI script like:
```yaml
# .github/workflows/array-api-tests-jax.yml
name: Array API Tests (JAX)

on: [push, pull_request]

jobs:
  array-api-tests-jax:
    uses: ./.github/workflows/array-api-tests.yml
    with:
      package-name: jax
      pytest-extra-args: -k top_k
```
This should hopefully be straightforward, but let me know if you run into any issues.
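For illustration, a minimal wrapper namespace exposing only `top_k` could look roughly like the sketch below. This is a NumPy-based sketch assuming the modified signature discussed in this thread (`largest: bool`, `axis` not keyword-only); the module layout and names are my assumptions, not array_api_compat's actual structure.

```python
# Hypothetical contents of a minimal wrapper module (module name is an assumption).
# Assumes the modified API: top_k(x, k, axis=-1, *, largest=True) -> (values, indices).
import numpy as np

def top_k(x, k, axis=-1, *, largest=True):
    """Return the k largest (or smallest) values along `axis`, with their indices."""
    x = np.asarray(x)
    order = np.argsort(x, axis=axis)       # indices in ascending order of value
    if largest:
        order = np.flip(order, axis=axis)  # descending order for the largest-k case
    indices = np.take(order, np.arange(k), axis=axis)
    values = np.take_along_axis(x, indices, axis=axis)
    return values, indices
```

A production wrapper would likely use `np.argpartition` for O(n) selection rather than a full sort; `argsort` just keeps the sketch short.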
Once the NumPy `top_k` PR is accepted we can reuse it for the NumPy 1.26 wrapper. For now, though, you can either copy it into your compat PR as the NumPy wrapper to verify it, or change the NumPy dev CI job to point at your NumPy PR. Here are some development notes for the compat library which should be helpful: https://data-apis.org/array-api-compat/dev/index.html, but also feel free to ask any questions here.
The main purpose is just to get an idea of what the wrappers will look like and to get a CI log showing the tests pass (or, if something is too hard to wrap, what the error is). So you don't need to worry too much about making everything perfect and mergeable. If `top_k` is eventually added to the standard we can clean up these PRs and use them.
If you see something that could be changed in array-api-compat or the test suite to make this process easier, make a note of it. We are going to want to be able to repeat this whole process in the future any time a new function is proposed for inclusion in the array API.
> `torch.topk` is only implemented for certain dtypes (e.g., `topk_cpu` does not implement `UInt16`).
Quite a few things in PyTorch don't work with the smaller uint dtypes. They are skipped in the CI, so you don't need to worry about that. If the torch wrapper is just `top_k = topk` and these tests pass, that will be a good sign that the proposed specification matches the existing PyTorch implementation.
The purpose of this PR is to continue several threads of discussion regarding `top_k`. This follows roughly the specification of `top_k` in data-apis/array-api#722, with slight modifications to the API to increase compatibility.

Modifications:

- `mode: Literal["largest", "smallest"]` is replaced with `largest: bool`.
- `axis` is no longer a kw-only arg. This makes `torch.topk` slightly more compatible.

The tests implemented here follow the proposed `top_k` implementation at numpy/numpy#26666.

Compatibility concerns with prior art:

- `numpy`: None, if numpy/numpy#26666 gets merged.
- `torch`: In torch the API is named `topk` instead, and `torch.topk` is only implemented for certain dtypes (e.g., `topk_cpu` does not implement `UInt16`).
- `tensorflow`: the `axis` keyword does not exist; behaves like `axis=-1`.
- `JAX`: Same as `tensorflow`.
- `Dask`: the `largest` keyword does not exist; the `largest` flag is instead determined by the sign of `k`.
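Since tensorflow and JAX only operate on the last axis, a compat wrapper would have to emulate the `axis` keyword itself, e.g. by moving the target axis to the end and back. The sketch below illustrates that technique in NumPy; `last_axis_top_k` is a hypothetical stand-in for such a last-axis-only backend function, not a real API.

```python
import numpy as np

def last_axis_top_k(x, k, largest=True):
    # Stand-in for a backend top_k that only works along the last axis.
    order = np.argsort(x, axis=-1)
    if largest:
        order = np.flip(order, axis=-1)
    idx = order[..., :k]
    return np.take_along_axis(x, idx, axis=-1), idx

def top_k(x, k, axis=-1, *, largest=True):
    # Emulate an arbitrary-axis top_k on a last-axis-only backend:
    # move `axis` to the end, select, then move the result axis back.
    x = np.asarray(x)
    moved = np.moveaxis(x, axis, -1)
    values, indices = last_axis_top_k(moved, k, largest=largest)
    return np.moveaxis(values, -1, axis), np.moveaxis(indices, -1, axis)
```

The returned indices index along the original `axis`, so `np.take_along_axis(x, indices, axis=axis)` recovers the values.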