Farama-Foundation / Jumpy

On-the-fly conversions between Jax and NumPy tensors
Apache License 2.0

Indexing according to jax API #31

Open bheijden opened 1 year ago

bheijden commented 1 year ago

This PR lets users take a view of an `onp.ndarray` that supports the JAX API for indexed array updates, so the `.at[idx].set(val)` API can be used on NumPy arrays.

Following this procedure, a subclass of `numpy.ndarray` was added. It essentially adds the `.at` property with all the necessary functionality.
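For readers unfamiliar with NumPy subclassing, the idea can be sketched roughly as follows. This is a minimal sketch that needs only NumPy and is not the code from this PR. `JpArray`, `_IndexUpdateHelper`, and `_IndexUpdateRef` are hypothetical names, and only `.set` and `.add` are implemented.

```python
import numpy as onp


class _IndexUpdateRef:
    """Mimics the object returned by jax's `x.at[idx]`, for a numpy array."""

    def __init__(self, array, index):
        self.array = array
        self.index = index

    def set(self, values):
        # Like jax, return an updated copy instead of mutating the original.
        out = self.array.copy()
        out[self.index] = values
        return out

    def add(self, values):
        # `onp.add.at` is unbuffered, so repeated indices accumulate as in jax.
        out = self.array.copy()
        onp.add.at(out, self.index, values)
        return out


class _IndexUpdateHelper:
    """Captures `idx` in `arr.at[idx]` via __getitem__."""

    def __init__(self, array):
        self.array = array

    def __getitem__(self, index):
        return _IndexUpdateRef(self.array, index)


class JpArray(onp.ndarray):
    """Hypothetical stand-in for `jumpy.jparray`: an ndarray with a jax-style `.at`."""

    @property
    def at(self):
        return _IndexUpdateHelper(self)


# Usage: a zero-copy view gains `.at`, and ufunc results keep the subclass.
base = onp.arange(5)
view = base.view(JpArray)
updated = view.at[0].set(10)  # updated copy; `base` and `view` are unchanged
shifted = onp.add(view, 1)  # still a JpArray, so `.at` remains available
```

Because `.set` returns a copy, the semantics stay aligned with jax, where arrays are immutable. Functions that explicitly convert to the base class, such as `onp.asarray`, return a plain `onp.ndarray` and lose `.at`. `onp.asanyarray` keeps the subclass.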

This API is completely optional. It only becomes available once users convert their `onp.ndarray` to a `jumpy.jparray` by taking a view of it, which does not copy data. The subclass is compatible with NumPy functions. When one is applied, the result is again a `jumpy.jparray` rather than a plain `onp.ndarray`. This is demonstrated below:

```python
import jumpy
import numpy as onp
import jax.numpy as jnp

# Construct a numpy array
arr = onp.arange(10)
# arr.at[0].set(1)  --> fails, because a regular numpy array has no .at property

# Create a view of the numpy array that has the .at property (no data is copied).
arr_jp = arr.view(jumpy.jparray)

# Mimic the jax API for numpy arrays
new_arr = arr_jp.at[0].set(1)  # Returns an updated copy, like indexed updates in jax.
another_arr = onp.add(arr_jp, 1)  # Ordinary numpy functions can still be applied as-is to the subclass.
hasattr(another_arr, "at")  # --> True, because the result keeps the jparray subclass.

# In `jumpy._indexing.py` we decorate the existing .view method of jax arrays
# so that it matches the behaviour of the numpy `.view` API.
arr = jnp.arange(10)
arr.at[0].set(1)  # --> Already works out of the box for jax arrays
arr_jp = arr.view(jumpy.jparray)  # --> Users may still call `.view(jumpy.jparray)`; for jax arrays it is a no-op.

# When the array type (jax or numpy) is not known in advance, which is common in jumpy code,
# but the user still wants the `.at[idx].set(val)` API, even for numpy arrays,
# they can call `.view(jumpy.jparray)` on the array first to ensure the API is available.
```
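The pattern in the last comment can be wrapped in a small helper, so type-agnostic code never needs to know whether it received a jax or a numpy array. This is a sketch under assumptions and not part of the PR. `JpArray`, `ensure_at`, and `zero_first` are hypothetical names. The helper uses duck typing (`hasattr(x, "at")`) instead of the PR's approach of decorating `.view` on jax arrays, and it needs only NumPy.

```python
import numpy as onp


class _SetRef:
    """Minimal stand-in for the object returned by `x.at[idx]` (only `.set`)."""

    def __init__(self, array, index):
        self.array = array
        self.index = index

    def set(self, values):
        out = self.array.copy()  # functional update: never mutate the input
        out[self.index] = values
        return out


class _AtHelper:
    def __init__(self, array):
        self.array = array

    def __getitem__(self, index):
        return _SetRef(self.array, index)


class JpArray(onp.ndarray):
    """Hypothetical stand-in for `jumpy.jparray`."""

    @property
    def at(self):
        return _AtHelper(self)


def ensure_at(x):
    """Return `x` as-is if it already has `.at` (e.g. a jax array or a JpArray);
    otherwise return a zero-copy JpArray view of it."""
    if hasattr(x, "at"):
        return x
    return onp.asarray(x).view(JpArray)


def zero_first(x):
    # Type-agnostic caller: the same code path works with or without `.at`.
    return ensure_at(x).at[0].set(0)
```

A jax array already has `.at`, so `ensure_at` returns it unchanged. This mirrors the no-op `.view(jumpy.jparray)` call on jax arrays described in the PR.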

See here for more info on in-place updates with jax.
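Mimicking jax's indexed updates in NumPy has one subtlety. The jax documentation notes that `x.at[idx].add(y)` applies every update when an index repeats, unlike NumPy's `x[idx] += y`. The NumPy-only sketch below shows that difference. `add_at` is a hypothetical helper, not an existing jumpy function.

```python
import numpy as onp


def add_at(x, idx, values):
    """Hypothetical NumPy analogue of jax's `x.at[idx].add(values)`."""
    out = x.copy()  # jax arrays are immutable, so work on a copy
    onp.add.at(out, idx, values)  # unbuffered: repeated indices accumulate
    return out


idx = onp.array([1, 1, 3])  # index 1 appears twice

# Plain NumPy augmented assignment mutates `buffered` in place and, because it
# is buffered, increments the repeated index 1 only once.
buffered = onp.zeros(4)
buffered[idx] += 1.0

# The functional version leaves `original` untouched and counts index 1 twice.
original = onp.zeros(4)
functional = add_at(original, idx, 1.0)
```

The jax docs also note that inside `jax.jit`, the copy implied by `.at[...]` can be compiled into a true in-place update when the input is not reused afterwards. A NumPy emulation always pays for the copy.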

Some open questions:

  1. I'm not sure about the subclass name `jparray`. Ideally it would be `ndarray`, but that name is already used as a variable for typing.
  2. Should we then remove the custom function `_base_fns.index_update`, or keep it?
  3. Should all the `_factory_fns` call `.view(jp.jparray)` by default, so that the `.view(jp.jparray)` calls are mostly abstracted away from users? Users would then only need to call `.view(jp.jparray)` when they explicitly create an array with `onp.array` and want to use the `.at[idx].set(val)` API.
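Question 3 could look roughly like the following if the factory functions wrapped their NumPy counterparts. This is a hypothetical sketch, not jumpy's actual `_factory_fns`. `JpArray`, `returns_jparray`, and the wrapped `arange`/`zeros`/`ones` are illustrative names. `.at` itself is omitted to keep the focus on the wrapping.

```python
import functools

import numpy as onp


class JpArray(onp.ndarray):
    """Hypothetical stand-in for `jumpy.jparray` (`.at` omitted for brevity)."""


def returns_jparray(factory):
    """Wrap a NumPy factory so its result is a zero-copy JpArray view."""

    @functools.wraps(factory)
    def wrapper(*args, **kwargs):
        return factory(*args, **kwargs).view(JpArray)

    return wrapper


# Hypothetical jumpy-style factories that hide the `.view(...)` call from users.
arange = returns_jparray(onp.arange)
zeros = returns_jparray(onp.zeros)
ones = returns_jparray(onp.ones)

# An array built directly with `onp.array` is still a plain ndarray and would
# need an explicit `.view(JpArray)`, as question 3 notes.
explicit = onp.array([1, 2, 3])
```

Because `.view` does not copy, wrapping the factories this way adds no data copies. It only changes the type of the returned array.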