Open akshatvishu opened 1 year ago
Actually, we are defaulting to `axis=-1`, i.e. the last axis, while `axis = axis % len(params.shape)` just makes sure the axis is within range (I guess something similar to `_validate_dim` and `get_positive_axis`). `torch.gather` doesn't default `dim` at all, while `tf.gather` and `paddle.gather` default to `axis=0` (for `tf.gather`, the default is the first non-batch dimension, which is 0 when `batch_dims=0`). I agree we should change the behavior of `ivy.gather` to match one of these frameworks.
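To make the comparison concrete, here is a small pure-Python sketch of the default-axis policies described above. The function names (`torch_style_axis`, `tf_style_axis`, `paddle_style_axis`, `ivy_current_axis`) are hypothetical: they only mimic each framework's documented default, and the shared range check is an assumption modeled on TF's `get_positive_axis`, not the frameworks' actual code.

```python
from typing import Optional


def _check_range(axis: int, ndim: int) -> int:
    """Map a possibly negative axis into [0, ndim), raising if it is out of range."""
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis={axis} out of bounds: expected {-ndim}<=axis<{ndim}")
    return axis + ndim if axis < 0 else axis


def torch_style_axis(dim: int, ndim: int) -> int:
    # torch.gather has no default: `dim` must always be passed explicitly.
    return _check_range(dim, ndim)


def tf_style_axis(ndim: int, axis: Optional[int] = None, batch_dims: int = 0) -> int:
    # tf.gather: axis=None means "first non-batch dimension", i.e. batch_dims.
    return _check_range(batch_dims if axis is None else axis, ndim)


def paddle_style_axis(ndim: int, axis: Optional[int] = None) -> int:
    # paddle.gather: axis=None is treated as 0.
    return _check_range(0 if axis is None else axis, ndim)


def ivy_current_axis(ndim: int, axis: int = -1) -> int:
    # Current ivy.gather behaviour: default to the last axis and wrap with modulo.
    return axis % ndim
```

Note that only the current Ivy policy accepts `axis=5` for a 3-D input: `5 % 3` wraps it to `2` instead of raising.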
Can we just use TF's `get_positive_axis`, define it in `ivy/utils/assertions.py`, and call it like this:
# path: ivy/utils/assertions.py
def get_positive_axis_for_gather(axis, ndims):
    if not isinstance(axis, int):
        raise TypeError(f"axis must be an int; got {type(axis).__name__}")
    if ndims is not None:
        if 0 <= axis < ndims:
            return axis
        elif -ndims <= axis < 0:
            return axis + ndims
        else:
            raise ValueError(
                f"axis={axis} out of bounds: expected {-ndims}<=axis<{ndims}"
            )
    elif axis < 0:
        raise ValueError("axis may only be negative if ndims is statically known.")
    return axis


# path: ivy/functional/backends/jax/general.py
from typing import Optional

import jax.numpy as jnp

import ivy
from ivy.functional.backends.jax import JaxArray


def gather(
    params: JaxArray,
    indices: JaxArray,
    /,
    *,
    axis: int = -1,
    batch_dims: int = 0,
    out: Optional[JaxArray] = None,
) -> JaxArray:
    # previously: axis = axis % len(params.shape)
    axis = ivy.utils.assertions.get_positive_axis_for_gather(axis, params.ndim)
    batch_dims = batch_dims % len(params.shape)
    ivy.utils.assertions.check_gather_input_valid(params, indices, axis, batch_dims)
    result = []
    if batch_dims == 0:
        result = jnp.take(params, indices, axis)
    else:
        # Pair up the slices of params and indices along every batch dimension.
        for b in range(batch_dims):
            if b == 0:
                zip_list = [(p, i) for p, i in zip(params, indices)]
            else:
                zip_list = [
                    (p, i) for z in [zip(p1, i1) for p1, i1 in zip_list] for p, i in z
                ]
        # Gather within each batch slice, then restore the batch dimensions.
        for z in zip_list:
            p, i = z
            r = jnp.take(p, i, axis - batch_dims)
            result.append(r)
        result = jnp.array(result)
        result = result.reshape([*params.shape[0:batch_dims], *result.shape[1:]])
    return result

and do the same for the rest of the backends,
OR
define this helper in each backend and call it there.
I think it would make more sense to have `_get_positive_axis` under `ivy/functional/general.py` and import it into all the backends from there. It's not a pure assertion after all; it does modify the given axis in most cases.
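A minimal sketch of that layout, collapsed into one file for illustration: a single shared `_get_positive_axis` (the name from the comment above) and a NumPy-backed `gather` that calls it instead of `axis % len(params.shape)`. The backend function name `numpy_gather` is hypothetical, its `axis=0` default assumes the proposed TF/Paddle-style default, and only the `batch_dims=0` path is shown.

```python
import numpy as np


# Shared helper: would live in one common module and be imported by every backend.
def _get_positive_axis(axis: int, ndims: int) -> int:
    """Return `axis` as a non-negative index, raising if it is out of range."""
    if not isinstance(axis, int):
        raise TypeError(f"axis must be an int; got {type(axis).__name__}")
    if not -ndims <= axis < ndims:
        raise ValueError(f"axis={axis} out of bounds: expected {-ndims}<=axis<{ndims}")
    # Negative axes are shifted, which is why this is not a pure assertion.
    return axis + ndims if axis < 0 else axis


# Hypothetical NumPy backend function (batch_dims=0 only).
def numpy_gather(params: np.ndarray, indices: np.ndarray, axis: int = 0) -> np.ndarray:
    axis = _get_positive_axis(axis, params.ndim)  # replaces `axis % len(params.shape)`
    return np.take(params, indices, axis=axis)
```

With `params = np.arange(6).reshape(2, 3)`, `numpy_gather(params, np.array([2, 0]), axis=1)` and `axis=-1` give the same columns, while `axis=2` now raises a `ValueError` instead of wrapping to `0`.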
Other than that, we would need to change the `axis` default value and make sure the functions that use `ivy.gather` don't start failing.
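To see why call sites need checking, compare the two defaults on the same inputs, using NumPy's `np.take` only as a stand-in for a backend gather: with `axis=-1` the indices select columns, with `axis=0` they select rows. Any caller that relied on the old default silently gets a differently shaped result, or an `IndexError` if an index is only valid along the last axis.

```python
import numpy as np

params = np.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]
indices = np.array([0, 1])

# Old ivy.gather default: axis=-1 (last axis) -> picks columns 0 and 1.
old_default = np.take(params, indices, axis=-1)

# Proposed default matching tf.gather / paddle.gather: axis=0 -> picks rows 0 and 1.
new_default = np.take(params, indices, axis=0)

print(old_default.shape, new_default.shape)  # (2, 2) (2, 3)
```

Callers that want the old behaviour would need to pass `axis=-1` explicitly once the default changes.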
I will test it and only push a PR if all the tests are passing.
please assign to me
Sure, feel free to work on it!
Bug Explanation
The current implementation of the `gather` function in Ivy normalizes the `axis` parameter differently from the native implementations in frameworks like TensorFlow, PyTorch or Paddle. Specifically, the line `axis = axis % len(params.shape)` may not be in line with how TensorFlow, PyTorch or Paddle handle `axis` normalization: it silently wraps an out-of-range axis (for example, `axis=3` on a 3-D array becomes `0`) instead of raising an error.
Steps to Reproduce Bug
We're setting the value of `axis` to `axis = axis % len(params.shape)` across all backends in Ivy:
- For the `jax` backend: https://github.com/unifyai/ivy/blob/16d690d318bb0044378be534cf621b2c4bb9160d/ivy/functional/backends/jax/general.py#L123
- For the `torch` backend: https://github.com/unifyai/ivy/blob/16d690d318bb0044378be534cf621b2c4bb9160d/ivy/functional/backends/torch/general.py#L174
- For the `tensorflow` backend: https://github.com/unifyai/ivy/blob/16d690d318bb0044378be534cf621b2c4bb9160d/ivy/functional/backends/tensorflow/general.py#L103
- For the NumPy backend: https://github.com/unifyai/ivy/blob/16d690d318bb0044378be534cf621b2c4bb9160d/ivy/functional/backends/numpy/general.py#L75
- For the Paddle backend: https://github.com/unifyai/ivy/blob/16d690d318bb0044378be534cf621b2c4bb9160d/ivy/functional/backends/paddle/general.py#L149

Now, this is how the `axis` value is set in the various frameworks' native APIs:
- For `torch`: PyTorch's `gather` uses a helper called `_validate_dim`.
- For TensorFlow: `tf.gather` uses a helper called `get_positive_axis`.
- For `paddle.gather`: `axis` is set to `0` by default.
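A backend-free reproduction of the problem, using NumPy only as a stand-in: `np.take` itself rejects an out-of-range axis, but putting `axis % len(params.shape)` in front of it (as the backends linked above do) turns the same call into a silent gather along axis 0. The wrapper `modulo_gather` is hypothetical and only mirrors that one normalization line, not the full Ivy backend.

```python
import numpy as np


def modulo_gather(params, indices, axis=-1):
    # Mirrors the normalization line used in the Ivy backends linked above.
    axis = axis % len(params.shape)
    return np.take(params, indices, axis=axis)


params = np.arange(24).reshape(2, 3, 4)
indices = np.array([0])

# axis=3 is out of range for a 3-D array, but 3 % 3 == 0, so this "works".
wrapped = modulo_gather(params, indices, axis=3)

# NumPy's own take refuses the same axis (its AxisError subclasses ValueError and IndexError).
try:
    np.take(params, indices, axis=3)
    native_raised = False
except (ValueError, IndexError):
    native_raised = True

print(wrapped.shape, native_raised)  # (1, 3, 4) True
```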
Environment
Linux + Docker + vscode
Ivy Version
1.1.9
Backend
Device
CPU