data-apis / array-api-tests

Test suite for the PyData Array APIs standard
https://data-apis.org/array-api-tests/
MIT License

WIP: `top_k` tests #274

Open JuliaPoo opened 2 months ago

JuliaPoo commented 2 months ago

The purpose of this PR is to continue several threads of discussion regarding top_k.

This roughly follows the specification of top_k in data-apis/array-api#722, with slight modifications to the API to increase compatibility:

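from __future__ import annotations  # keeps `array` an unevaluated annotation

from typing import Optional, Tuple

def top_k(
    x: array,  # `array` is the standard's placeholder for a conforming array type
    k: int,
    /,
    axis: Optional[int] = None,
    *,
    largest: bool = True,
) -> Tuple[array, array]:
    # Stub: returns a pair of arrays (presumably values and indices).
    ...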

Modifications:

The tests implemented here follow the proposed top_k implementation in numpy/numpy#26666.
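The PR's tests are not reproduced here. As a rough illustration of invariants a `top_k` test can check without depending on output order or tie-breaking, here is a hypothetical NumPy helper (`check_top_k` is an illustrative name, not part of the test suite). It assumes `axis=None` means the flattened array, and it checks that the outputs have length `k` along the axis, that the indices are distinct and in range, that `values` equals `x` gathered at `indices`, and that the values form the same multiset as the `k` most extreme entries of a full sort.

```python
import numpy as np


def check_top_k(x, k, values, indices, axis=-1, largest=True):
    # Hypothetical checker (not from array-api-tests); raises AssertionError.
    x = np.asarray(x)
    if axis is None:
        # Assumption: axis=None means the flattened array.
        x, axis = x.reshape(-1), -1
    axis = axis % x.ndim  # normalize negative axes
    expected_shape = x.shape[:axis] + (k,) + x.shape[axis + 1:]
    assert values.shape == expected_shape, "values has the wrong shape"
    assert indices.shape == expected_shape, "indices has the wrong shape"
    # Indices must be in range and must not repeat along the axis.
    assert np.all((indices >= 0) & (indices < x.shape[axis]))
    assert np.all(np.diff(np.sort(indices, axis=axis), axis=axis) != 0)
    # Every returned value must be the element of x at the returned index.
    assert np.array_equal(np.take_along_axis(x, indices, axis=axis), values)
    # Independent oracle: the k most extreme entries from a full sort.
    ordered = np.sort(x, axis=axis)
    if largest:
        ordered = np.flip(ordered, axis=axis)
    expected = np.take(ordered, np.arange(k), axis=axis)
    # Compare as sorted multisets so output order and ties don't matter.
    assert np.array_equal(np.sort(values, axis=axis), np.sort(expected, axis=axis))


x = np.array([3, 1, 4, 1, 5])
check_top_k(x, 2, np.array([4, 5]), np.array([2, 4]))  # passes: order is ignored
```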

Compatibility concerns with prior art:

asmeurer commented 2 months ago

So the next step here would be to implement wrappers for this function in array_api_compat. That way we can see how complex the required changes are, and also verify that there aren't other incompatibilities, since a test stops at its first failing check and anything after it never runs.

The way I would do this is to make a PR to the compat library that

  1. Modifies the actions file to point to this PR: https://github.com/data-apis/array-api-compat/blob/ac15c526d9769f77c780958a00097dfd183a2a37/.github/workflows/array-api-tests.yml#L53
  2. Adds compat wrappers for the different libraries.
  3. Sparse and JAX are currently not tested in the compat library CI, because their array API support lives entirely in the libraries themselves. Instead, you can make a simple wrapper namespace that wraps top_k and nothing else, and then add a CI workflow like

    # .github/workflows/array-api-tests-jax.yml
    # (GitHub Actions only runs workflow files under .github/workflows/)
    name: Array API Tests (JAX)

    on: [push, pull_request]

    jobs:
      array-api-tests-jax:
        uses: ./.github/workflows/array-api-tests.yml
        with:
          package-name: jax
          pytest-extra-args: -k top_k

    This should hopefully be straightforward, but let me know if you run into any issues.

  4. CuPy cannot be tested on CI. However, CuPy should be identical to NumPy, so if you don't have access to a CUDA machine, I wouldn't worry about it for now.
  5. Don't worry about TensorFlow. It hasn't been included in the compat library at all yet.
  6. Since https://github.com/numpy/numpy/pull/26666 is a simple pure-Python implementation, if top_k is accepted we can reuse it for the NumPy 1.26 wrapper. For now, though, you can either copy it into your compat PR as the NumPy wrapper to verify it, or change the NumPy dev CI job to point to your NumPy PR.

Here are some development notes for the compat library that should be helpful: https://data-apis.org/array-api-compat/dev/index.html. Feel free to ask any questions here as well.

The main purpose is just to get an idea of what the wrappers will look like and to get a CI log showing that the tests pass (or, if something is too hard to wrap, what the error is). So you don't need to worry too much about making everything perfect and mergeable. If top_k is eventually added to the standard, we can clean up these PRs and use them.

If you see something that could be changed in array-api-compat or the test suite to make this process easier, make a note of it. We are going to want to be able to repeat this whole process in the future any time a new function is proposed for inclusion in the array API.

asmeurer commented 2 months ago

torch.topk is only implemented for certain dtypes (e.g., topk_cpu does not implement UInt16).

Quite a few things in PyTorch don't work with smaller uint dtypes. They are skipped in the CI, so you don't need to worry about that. If the torch wrapper is just top_k = topk and these tests pass, that will be a good sign that the proposed specification matches the existing PyTorch implementation.