data-apis / array-api-compat

Compatibility layer for common array libraries to support the Array API
https://data-apis.org/array-api-compat/
MIT License

Seeking alternatives for setting array values with integer indexing #62

Closed fcogidi closed 11 months ago

fcogidi commented 11 months ago

On a few occasions while using this library, I've run into the problem of having to set array values using advanced indexing. Here is an example:

import numpy as np
import numpy.array_api as anp

def one_hot_np(array, num_classes):
    n = array.shape[0]
    categorical = np.zeros((n, num_classes))
    categorical[np.arange(n), array] = 1
    return categorical

def one_hot_anp(array, num_classes):
    one_hot = anp.zeros((array.shape[0], num_classes))
    indices = anp.stack(
        (anp.arange(array.shape[0]), anp.reshape(array, (-1,))), axis=-1
    )
    indices = anp.reshape(indices, shape=(-1, indices.shape[-1]))

    for idx in range(indices.shape[0]):
        one_hot[tuple(indices[idx, ...])] = 1

    return one_hot

I'm using the numpy.array_api namespace because it follows the API standard closely.

Is there a different (better) way of setting values of an array using integer (array) indices that adheres to the 2021.12 version of the array API standard?

For the example I gave, I'm aware that I can do something like this (but not with numpy.array_api namespace, as it only supports v2021.12):

import numpy as np

def one_hot(array, num_classes):
    # Row i of the identity matrix is the one-hot vector for class i.
    id_arr = np.eye(num_classes)
    return np.take(id_arr, array, axis=0)

But I have other cases in my codebase that follow the first pattern - looping through array indices and using basic indexing to set array values. For example, using the indices from xp.argsort to mark the top-k values. Is there a better way than looping through the indices?

asmeurer commented 11 months ago

There's a plan to add a guide for this sort of thing to the standard https://github.com/data-apis/array-api/pull/668, although there's nothing there yet for alternatives to integer indexing. Most likely your best bet is to just manually use put or integer indexing for libraries that you know have that functionality.

See also https://github.com/data-apis/array-api/issues/177 and https://github.com/data-apis/array-api/issues/629
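
The "use integer indexing where you know it exists" suggestion could be sketched roughly as follows. The helper name `set_one_per_row`, its `xp` namespace argument, and the choice of exceptions to catch are assumptions made for illustration; none of them come from array-api-compat or the standard, and a given library may signal unsupported indexing differently.

```python
import numpy as np

def set_one_per_row(xp, out, cols, value=1):
    """Set out[i, cols[i]] = value for every row i (hypothetical helper)."""
    n = out.shape[0]
    try:
        # Fast path: integer-array ("advanced") indexing. NumPy supports this;
        # a strictly conforming namespace may reject it.
        out[xp.arange(n), cols] = value
    except (IndexError, TypeError):
        # Assumed failure modes. Fall back to basic indexing, one element
        # at a time, which is what the loop in the question does.
        for i in range(n):
            out[i, int(cols[i])] = value
    return out

# Usage with plain NumPy as the namespace
labels = np.asarray([1, 0, 3])
one_hot = set_one_per_row(np, np.zeros((3, 4)), labels)
```

The fast path mutates `out` in place, so this only suits libraries with mutable arrays.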

seberg commented 11 months ago

That doesn't mean that partial advanced-indexing isn't very useful in other cases.

But for one-hot, there is a conceptual rewrite available that is probably faster anyway, as long as num_classes is relatively small.

import numpy as np

def one_hot(array, num_classes):
    # Compare every label against each class id; the result has dtype bool.
    classes = np.arange(num_classes, dtype=array.dtype)
    return classes == array[..., np.newaxis]

(In NumPy, you could use out=... to force whichever dtype you want on the result, which should be faster when arrays are large, but also adds a fair bit of overhead.)
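
A minimal sketch of the `out=` variant mentioned here, assuming plain NumPy rather than an array API namespace. The function name `one_hot_float` and the default `dtype` are illustrative choices, not something from the thread.

```python
import numpy as np

def one_hot_float(array, num_classes, dtype=np.float64):
    # Class ids to compare against, in the same integer dtype as the labels.
    classes = np.arange(num_classes, dtype=array.dtype)
    # Preallocate the result in the requested dtype; np.equal writes its
    # boolean result into it, cast to that dtype.
    out = np.empty(array.shape + (num_classes,), dtype=dtype)
    np.equal(classes, array[..., np.newaxis], out=out)
    return out

encoded = one_hot_float(np.asarray([0, 2, 1]), 3)
```

This avoids a separate `astype` pass over a boolean intermediate, at the cost of being NumPy-specific.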

fcogidi commented 11 months ago

Thank you @asmeurer and @seberg for your quick and helpful reply!
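
For the top-k case raised in the original question, a similar loop-free rewrite seems possible using only `argsort` and a comparison. The sketch below is an assumption-laden illustration: `top_k_mask` and its `xp` argument are hypothetical names, and it is run against plain NumPy here, although it only calls `argsort` with an `axis` keyword.

```python
import numpy as np

def top_k_mask(xp, x, k):
    """Boolean mask marking the k largest entries along the last axis."""
    # argsort of argsort yields each element's rank (0 = smallest).
    order = xp.argsort(x, axis=-1)
    ranks = xp.argsort(order, axis=-1)
    n = x.shape[-1]
    # The k largest elements are those with rank n-k, ..., n-1.
    return ranks >= (n - k)

scores = np.asarray([0.1, 0.9, 0.5, 0.7])
mask = top_k_mask(np, scores, 2)
```

Because the ranks form a permutation, exactly `k` entries are marked per row even when values tie; which of the tied entries is picked depends on the sort's tie-breaking.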