Open JeppeKlitgaard opened 2 years ago
What about:
x.at[jnp.uint32(idx)]
In simple tests, it works. Interestingly, for get with promise_in_bounds mode, and for clip mode, x.at[jnp.uint32(-1)] will be clipped to x.at[0].
I haven't tried anything like that, but it would still be a 'hacky' solution and doesn't look like something that would be guaranteed to be stable across JAX versions. I think the proposed solution (or something similar) would be a good thing to implement in the longer term.
@JeppeKlitgaard you are right! I just wanted to provide a workaround with good performance for anyone interested.
However, I think we cannot simply add a mode, but need an extra parameter, since "negative_as_oob" can have different OOB behaviors! In addition, for get with promise_in_bounds mode, and for clip mode, what behavior should negative indices have?
If we need clip-to-0 behavior, manually clipping the indices can be a nice choice.
If we need dropping behavior, lax.bitcast_convert_type can be a nice choice. I think a bitcast between i32 and u32 is stable enough on most platforms. Of course, it doesn't work if indices of 2^31 or above are in bounds, since negative i32 values bitcast to u32 values in [2^31, 2^32 - 1].
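A minimal sketch of both workarounds on a 1-D array, assuming int32 indices, the default 32-bit mode (x64 disabled) and an axis shorter than 2^31: jnp.clip gives clip-to-0 behaviour, and lax.bitcast_convert_type reinterprets negative int32 indices as large uint32 values so that mode="fill" treats them as out of bounds. The array and index values are illustrative only, and the bitcast variant relies on how current JAX indexing handles unsigned indices rather than on a documented guarantee.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(3)
idx = jnp.array([-1, 0, 2, 4], dtype=jnp.int32)

# Clip-to-0 workaround: clamp every index into [0, n - 1] before indexing,
# so -1 maps to 0 instead of wrapping to the last element.
clipped = x[jnp.clip(idx, 0, x.shape[0] - 1)]

# Fill/drop workaround: reinterpret the int32 bits as uint32, so -1 becomes
# 2**32 - 1. Unsigned indices are not wrapped by NumPy-style normalization.
# Assumes the axis is shorter than 2**31, so every bitcast negative is OOB.
uidx = lax.bitcast_convert_type(idx, jnp.uint32)
filled = x.at[uidx].get(mode="fill", fill_value=-1)
```

The clipped result is [0, 0, 2, 2], while the bitcast version fills both the negative index and the index past the end.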
@YouJiacheng
> However, I think we cannot simply add a mode, but need an extra parameter
I should clarify: negative_as_oob was proposed as a flag, not a mode. If it is True, then accessing negative indices would be considered out-of-bounds and thus trigger the behaviour specified by mode.
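To make the proposed flag semantics concrete, here is a minimal sketch that emulates negative_as_oob=True for get on a 1-D array using only the existing mode parameter. The helper name get_negative_as_oob is hypothetical and not part of JAX; it only handles "fill" and "clip", and it assumes that clipping negatives to 0 is the intended behaviour for "clip".

```python
import jax.numpy as jnp

def get_negative_as_oob(x, idx, mode="fill", fill_value=None):
    """Hypothetical emulation of get(..., negative_as_oob=True) for a 1-D array."""
    idx = jnp.asarray(idx)
    n = x.shape[0]
    if mode == "fill":
        # Index n is always out of bounds, so negatives receive fill_value
        # instead of wrapping around to the end of the axis.
        idx = jnp.where(idx < 0, n, idx)
        return x.at[idx].get(mode="fill", fill_value=fill_value)
    if mode == "clip":
        # Assumed semantics: negatives clip to 0 rather than wrapping.
        idx = jnp.clip(idx, 0, n - 1)
        return x.at[idx].get(mode="clip")
    raise ValueError(f"unsupported mode in this sketch: {mode!r}")

arr = jnp.arange(3)
left_edge = get_negative_as_oob(arr, -1, fill_value=-1)
```

Mapping negatives to n (rather than a large sentinel) keeps the remapped index within int32 range while still being out of bounds for any axis length.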
> In addition, for promise_in_bounds mode with get and clip mode, what behavior should negative indices have?
For promise_in_bounds with negative_as_oob, the user has effectively guaranteed that the index is not negative. If this is not the case, the behaviour would be undefined. A sensible fall-back would be similar to the current implementation, perhaps while emitting a warning (though I guess the idea is that no checks need to be performed).
> If we need clipping to 0 behavior, manually clip indices can be a nice choice.
This is already implemented by the clip mode, so the expected behaviour should be obvious?
> If we need dropping behavior, lax.bitcast_convert_type can be a nice choice.
This seems like a non-obvious, hacky solution. It may well work, but negative_as_oob would allow us to avoid resorting to something like that.
From a cursory look at jax._src.ops.scatter and jax._src.lax.slicing, it seems like this would require changes to _scatter_lower and _gather_lower, which might be more than I am able to attempt myself.
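For context on where such a flag might live, a minimal sketch (my own illustration, not from the thread) suggesting that the wrap-around comes from NumPy-style index normalization in the jnp indexing layer rather than from gather itself: as far as I can tell, calling lax.gather directly with FILL_OR_DROP already treats a negative start index as out of bounds. The array and indices are illustrative only.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(3)
# One scalar start index per row; the last axis is the index vector dimension.
start_indices = jnp.array([[-1], [1], [3]], dtype=jnp.int32)

dnums = lax.GatherDimensionNumbers(
    offset_dims=(),             # no window dimensions kept in the output
    collapsed_slice_dims=(0,),  # each slice has size 1 along axis 0; collapse it
    start_index_map=(0,),       # the single index component addresses axis 0
)

# No negative-index wrapping happens at this level: -1 and 3 are both
# out of bounds, so FILL_OR_DROP returns fill_value for them.
out = lax.gather(
    x, start_indices, dnums, slice_sizes=(1,),
    mode=lax.GatherScatterMode.FILL_OR_DROP, fill_value=-1,
)
```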
@jakevdp could you let me know if there is any appetite for this sort of functionality? The use-case would be any problem where non-periodic boundary conditions on the array are required.
For example:
import jax.numpy as jnp
arr = jnp.arange(3)
print(arr)
> [0 1 2]
print(arr.at[4].get(mode="fill", fill_value=-1))
> -1 # As expected by boundary condition
print(arr.at[-1].get(mode="fill", fill_value=-1))
> 2 # Cannot apply constant boundary condition to left side
My proposal would yield:
arr = jnp.arange(3)
print(arr)
> [0 1 2]
print(arr.at[4].get(negative_as_oob=True, mode="fill", fill_value=-1))
> -1 # As expected by boundary condition
print(arr.at[-1].get(negative_as_oob=True, mode="fill", fill_value=-1))
> -1 # As expected by boundary condition
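For the non-periodic boundary use-case mentioned above, a minimal sketch of a nearest-neighbour sum for a 1-D Ising chain with open (zero-padded) boundaries, written against today's API by remapping the left neighbour of site 0 to an out-of-bounds index. The function name neighbour_sum_open and the coupling J = 1 are assumptions for the example, not part of the proposal.

```python
import jax.numpy as jnp

def neighbour_sum_open(s):
    """Sum of left and right neighbours of a 1-D spin chain with open boundaries (hypothetical helper)."""
    n = s.shape[0]
    i = jnp.arange(n)
    # Remap the would-be index -1 to n, which is always out of bounds,
    # so mode="fill" supplies 0 instead of wrapping to the last spin.
    left = jnp.where(i - 1 < 0, n, i - 1)
    right = i + 1  # index n is already out of bounds on the right edge
    return (s.at[left].get(mode="fill", fill_value=0)
            + s.at[right].get(mode="fill", fill_value=0))

s = jnp.array([1, -1, -1, 1])
nbrs = neighbour_sum_open(s)
# Energy with J = 1: each bond appears twice in s * nbrs, hence the 0.5.
energy = -0.5 * jnp.sum(s * nbrs)
```

With negative_as_oob=True the remapping line would become unnecessary, which is essentially the point of the proposal.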
I'd like @hawkinsp to weigh in because he's been thinking a lot about scatter/gather modes recently.
I am currently working on an Ising Model simulation using JAX and thought it might be nice to optionally disable negative indices in the .at[...] methods of arrays.
Currently the mode parameter can be used to define out-of-bounds behaviour (which is very useful), but it requires a bit of a 'hack' to get negative indices to be regarded as out-of-bounds, which might be useful in some situations, such as mine.
I would propose adding a flag to the get, set, etc. methods, which would regard negative indices as out-of-bounds (thus triggering the behaviour defined by mode):
negative_as_oob: bool - If True, negative indices are considered out-of-bounds instead of wrapping around the axis.
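Since the proposal also covers set, a minimal sketch of the equivalent workaround for scatter on a 1-D array: negative indices are remapped to the axis length, which mode="drop" silently ignores. The helper name set_negative_as_oob is hypothetical and only emulates what negative_as_oob=True combined with mode="drop" might do.

```python
import jax.numpy as jnp

def set_negative_as_oob(x, idx, value):
    """Hypothetical emulation of set(..., negative_as_oob=True, mode="drop") for a 1-D array."""
    idx = jnp.asarray(idx)
    # Send negative indices to x.shape[0], which is always out of bounds,
    # so the corresponding updates are dropped instead of wrapping around.
    idx = jnp.where(idx < 0, x.shape[0], idx)
    return x.at[idx].set(value, mode="drop")

x = jnp.zeros(3)
unchanged = set_negative_as_oob(x, -1, 5.0)
```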