jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Proposal `negative_as_oob` for `ndarray.at[...]` #10063

Open JeppeKlitgaard opened 2 years ago

JeppeKlitgaard commented 2 years ago

I am currently working on an Ising Model simulation using JAX and thought it might be nice to optionally disable negative indices in the .at[...] methods of arrays.

Currently, the mode parameter can be used to define out-of-bounds behaviour (which is very useful), but getting negative indices treated as out-of-bounds requires a bit of a 'hack'. Treating them that way would be useful in some situations, such as mine.

I propose adding a flag to the get, set, etc. methods that would treat negative indices as out-of-bounds (thus triggering the behaviour defined by mode).
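
For context, the kind of workaround alluded to above might look like the following minimal sketch. The helper name `get_negative_as_oob` is hypothetical and not part of JAX. It assumes a 1-D array and remaps negative indices to `arr.shape[0]`, which is always out of bounds, so that `mode="fill"` applies to them.

```python
import jax.numpy as jnp


def get_negative_as_oob(arr, idx, fill_value):
    """Hypothetical helper (not a JAX API): 1-D gather that treats idx < 0 as out of bounds."""
    idx = jnp.asarray(idx)
    # arr.shape[0] is one past the last valid index, so it is always out of bounds.
    safe_idx = jnp.where(idx < 0, arr.shape[0], idx)
    return arr.at[safe_idx].get(mode="fill", fill_value=fill_value)


arr = jnp.arange(3)
result = get_negative_as_oob(arr, jnp.array([-1, 0, 2, 4]), fill_value=-1)
print(result)  # [-1  0  2 -1]
```

This costs an extra `jnp.where` per lookup, which is the overhead the proposed flag would hide inside the gather itself.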

YouJiacheng commented 2 years ago

What about:

x.at[jnp.uint32(idx)]

In simple tests, it works. Interestingly, for promise_in_bounds mode with get and clip mode, x.at[jnp.uint32(-1)] will be clipped to x.at[0].

JeppeKlitgaard commented 2 years ago

I haven't tried anything like that, but it would still be a 'hacky' solution and doesn't look like something that would be guaranteed to stay stable across jax versions. I think the proposed solution (or something similar) would be a good thing to implement in the longer term.

YouJiacheng commented 2 years ago

@JeppeKlitgaard you are right! I just wanted to provide a well-performing workaround for interested people. However, I think we cannot simply add a mode, but need an extra parameter, since "negative_as_oob" can have different OOB behaviors! In addition, for promise_in_bounds mode with get and clip mode, what behavior should negative indices have?

If we need clipping-to-0 behavior, manually clipping the indices can be a nice choice. If we need dropping behavior, lax.bitcast_convert_type can be a nice choice. I think a bitcast between i32 and u32 is stable enough for most platforms. Of course, it doesn't work if 2^31-1 isn't out of bounds.
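
As a rough illustration of the two options just mentioned, the sketch below clips indices by hand and shows what `lax.bitcast_convert_type` does to a negative `int32` value. The array `x` and the index values are made up for the example. Whether indexing with the resulting unsigned indices is then treated as out of bounds is the behaviour reported in this thread, not something this snippet verifies.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(3)
idx = jnp.array([-1, 1, 5], dtype=jnp.int32)

# Clipping behaviour: clamp indices into [0, len(x) - 1] before indexing.
clipped = x[jnp.clip(idx, 0, x.shape[0] - 1)]
print(clipped)  # [0 1 2]

# Reinterpret the int32 bits as uint32: -1 becomes 2**32 - 1 = 4294967295,
# while non-negative values are unchanged.
as_unsigned = lax.bitcast_convert_type(idx, jnp.uint32)
print(as_unsigned.tolist())  # [4294967295, 1, 5]
```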

JeppeKlitgaard commented 2 years ago

@YouJiacheng

> However, I think we cannot simply add a mode, but need an extra parameter

I should clarify – negative_as_oob was proposed as a flag, not a mode. If it is True, then accessing negative indices would be considered out-of-bounds and thus trigger the behaviour specified by mode.

> In addition, for promise_in_bounds mode with get and clip mode, what behavior should negative indices have?

For promise_in_bounds with negative_as_oob, the user has effectively guaranteed that the index is not negative. If this is not the case, it would be undefined behaviour. A sensible fall-back would be similar to the current implementation, perhaps while emitting a warning (though I guess the idea is that no checks need to be performed).

> If we need clipping to 0 behavior, manually clip indices can be a nice choice.

This is already implemented by the clip mode and the expected behaviour should be obvious?

> If we need dropping behavior, lax.bitcast_convert_type can be a nice choice.

This seems like a non-obvious, hacky solution. It may well work, but negative_as_oob would allow us to avoid resorting to something like that.

JeppeKlitgaard commented 2 years ago

From a cursory look at jax._src.ops.scatter and jax._src.lax.slicing, it seems like this would require changes to _scatter_lower and _gather_lower, which might preclude me from being able to attempt it.

@jakevdp could you let me know if there is any appetite for this sort of functionality? The use case would be any problem where non-periodic boundary conditions on the array are required.

For example:


import jax.numpy as jnp

arr = jnp.arange(3)

print(arr)
# > [0 1 2]

print(arr.at[4].get(mode="fill", fill_value=-1))
# > -1  (as expected by boundary condition)

print(arr.at[-1].get(mode="fill", fill_value=-1))
# > 2  (cannot apply constant boundary condition to left side)

My proposal would yield:


# Proposed API: negative_as_oob does not exist in JAX, so this does not run today.
import jax.numpy as jnp

arr = jnp.arange(3)

print(arr)
# > [0 1 2]

print(arr.at[4].get(negative_as_oob=True, mode="fill", fill_value=-1))
# > -1  (as expected by boundary condition)

print(arr.at[-1].get(negative_as_oob=True, mode="fill", fill_value=-1))
# > -1  (as expected by boundary condition)
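
For updates rather than lookups, a comparable effect can be approximated today with `mode="drop"`. The following is only a sketch under the assumption of a 1-D array. The helper name `add_negative_as_oob` is hypothetical and not part of JAX; it moves negative indices to `arr.shape[0]`, which is out of bounds, so that the corresponding updates are dropped.

```python
import jax.numpy as jnp


def add_negative_as_oob(arr, idx, values):
    """Hypothetical helper (not a JAX API): 1-D scatter-add that drops updates at idx < 0."""
    idx = jnp.asarray(idx)
    # Move negative indices to arr.shape[0], which is out of bounds and thus dropped.
    safe_idx = jnp.where(idx < 0, arr.shape[0], idx)
    return arr.at[safe_idx].add(values, mode="drop")


arr = jnp.zeros(3)
idx = jnp.array([-1, 0, 4])

# Default negative-index handling: -1 wraps around to the last element.
wrapped = arr.at[idx].add(jnp.ones(3), mode="drop")
print(wrapped)  # [1. 0. 1.]

# With the remapping, the update at -1 is dropped like the one at 4.
dropped = add_negative_as_oob(arr, idx, jnp.ones(3))
print(dropped)  # [1. 0. 0.]
```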

jakevdp commented 2 years ago

I'd like @hawkinsp to weigh-in because he's been thinking a lot about scatter/gather modes recently.