ivy-llc / ivy

Convert Machine Learning Code Between Frameworks
https://ivy.dev

[Bug]: Inconsistent `axis` Normalization in `ivy.gather` Across Backends #23044

Open akshatvishu opened 1 year ago

akshatvishu commented 1 year ago

Bug Explanation

The current implementation of the gather function in Ivy normalizes the axis parameter differently than the native implementations in frameworks such as TensorFlow, PyTorch, and Paddle. Specifically, the line axis = axis % len(params.shape) may not be in line with how TensorFlow, PyTorch, or Paddle handle axis normalization.
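For illustration, here is a minimal sketch (hypothetical, not Ivy source) of how the modulo-based normalization differs from the range check used by the native frameworks:

# hypothetical sketch: axis % ndim silently wraps an out-of-range axis,
# whereas a TF/PyTorch-style range check rejects it
ndim = 2  # e.g. params.shape == (3, 4)

axis = 5
print(axis % ndim)  # -> 1, no error is raised

def validate_axis(axis, ndim):
    # reject out-of-range axes, normalize negative ones into [0, ndim)
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis={axis} out of bounds for ndim={ndim}")
    return axis + ndim if axis < 0 else axis

print(validate_axis(-1, ndim))  # -> 1, negative axes are still accepted
print(validate_axis(5, ndim))   # raises ValueError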

Steps to Reproduce Bug

We set the value of axis as axis = axis % len(params.shape) across all backends in Ivy:

For the JAX backend: https://github.com/unifyai/ivy/blob/16d690d318bb0044378be534cf621b2c4bb9160d/ivy/functional/backends/jax/general.py#L123

For the torch backend: https://github.com/unifyai/ivy/blob/16d690d318bb0044378be534cf621b2c4bb9160d/ivy/functional/backends/torch/general.py#L174

For the TensorFlow backend: https://github.com/unifyai/ivy/blob/16d690d318bb0044378be534cf621b2c4bb9160d/ivy/functional/backends/tensorflow/general.py#L103

For the NumPy backend: https://github.com/unifyai/ivy/blob/16d690d318bb0044378be534cf621b2c4bb9160d/ivy/functional/backends/numpy/general.py#L75

For the Paddle backend: https://github.com/unifyai/ivy/blob/16d690d318bb0044378be534cf621b2c4bb9160d/ivy/functional/backends/paddle/general.py#L149

For comparison, this is how the axis value is handled by each framework's native API:

For torch: torch.gather uses a helper called _validate_dim

For TensorFlow: tf.gather uses a helper called get_positive_axis

For Paddle: paddle.gather defaults to axis = 0

Environment

Linux + Docker + vscode

Ivy Version

1.1.9

Backend

Device

CPU

AnnaTz commented 1 year ago

Actually we are defaulting axis=-1, i.e. the last axis, while axis = axis % len(params.shape) just makes sure the axis is within range (I guess something similar to _validate_dim and get_positive_axis). torch.gather doesn't default dim at all, while tf.gather and paddle.gather default to axis=0. I agree we should change the behavior of ivy.gather to match one of these frameworks.
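To make the difference in defaults concrete, here is a minimal sketch using numpy.take as a stand-in for gather along an axis (illustrative only, not taken from running ivy itself):

# illustration of the differing default axes
import numpy as np

params = np.array([[1, 2, 3],
                   [4, 5, 6]])
indices = np.array([1, 0])

print(np.take(params, indices, axis=0))   # tf/paddle-style default: picks rows
# [[4 5 6]
#  [1 2 3]]
print(np.take(params, indices, axis=-1))  # ivy's current default: picks columns per row
# [[2 1]
#  [5 4]]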

akshatvishu commented 1 year ago

Actually we are defaulting axis=-1, i.e. the last axis, while axis = axis % len(params.shape) just makes sure the axis is within range (I guess something similar to _validate_dim and get_positive_axis). torch.gather doesn't default dim at all, while tf.gather and paddle.gather default to axis=0. I agree we should change the behavior of ivy.gather to match one of these frameworks.

Can we just use TF's get_positive_axis, define it at ivy/utils/assertions.py, and call it like this:

# path: ivy/utils/assertions.py

def get_positive_axis_for_gather(axis, ndims):
    # Validate `axis` and normalize negative values into the range [0, ndims),
    # mirroring TensorFlow's get_positive_axis helper.
    if not isinstance(axis, int):
        raise TypeError(f"axis must be an int; got {type(axis).__name__}")
    if ndims is not None:
        if 0 <= axis < ndims:
            return axis
        elif -ndims <= axis < 0:
            return axis + ndims
        else:
            raise ValueError(
                f"axis={axis} out of bounds: expected {-ndims}<=axis<{ndims}"
            )
    elif axis < 0:
        raise ValueError("axis may only be negative if ndims is statically known.")
    return axis

# path: ivy/functional/backends/jax/general.py

def gather(
    params: JaxArray,
    indices: JaxArray,
    /,
    *,
    axis: int = -1,
    batch_dims: int = 0,
    out: Optional[JaxArray] = None,
) -> JaxArray:
    # use the proposed helper instead of: axis = axis % len(params.shape)
    axis = ivy.utils.assertions.get_positive_axis_for_gather(axis, params.ndim)
    batch_dims = batch_dims % len(params.shape)
    ivy.utils.assertions.check_gather_input_valid(params, indices, axis, batch_dims)
    result = []
    if batch_dims == 0:
        result = jnp.take(params, indices, axis)
    else:
        for b in range(batch_dims):
            if b == 0:
                zip_list = [(p, i) for p, i in zip(params, indices)]
            else:
                zip_list = [
                    (p, i) for z in [zip(p1, i1) for p1, i1 in zip_list] for p, i in z
                ]
        for z in zip_list:
            p, i = z
            r = jnp.take(p, i, axis - batch_dims)
            result.append(r)
        result = jnp.array(result)
        result = result.reshape([*params.shape[0:batch_dims], *result.shape[1:]])
    return result

"""
and do the same for rest of the backends

OR

can define this helper at each backend and then call it
"""
AnnaTz commented 1 year ago

I think it would make more sense to have _get_positive_axis under ivy/functional/general.py and import it into all the backends from there. It's not a pure assertion after all; it modifies the given axis in most cases.
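A rough sketch of that layout, with paths and names as proposed above (hypothetical, not merged code):

# ivy/functional/general.py (proposed location)
def _get_positive_axis(axis, ndims):
    # normalize a possibly-negative axis into [0, ndims); raise if out of range
    if not -ndims <= axis < ndims:
        raise ValueError(f"axis={axis} out of bounds for ndims={ndims}")
    return axis + ndims if axis < 0 else axis

# and in each backend, e.g. ivy/functional/backends/jax/general.py:
# from ivy.functional.general import _get_positive_axis
# ...
# axis = _get_positive_axis(axis, params.ndim)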

AnnaTz commented 1 year ago

Other than that, we would need to change the axis default value and make sure the functions that use ivy.gather don't start failing.
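For instance, a hedged sketch of the kind of regression check this would need (assumes ivy.set_backend, ivy.array, and ivy.gather as used in the linked backend files; values are illustrative):

import ivy

ivy.set_backend("numpy")
params = ivy.array([[1, 2, 3], [4, 5, 6]])
indices = ivy.array([1, 0])

# with axis=0 (the tf/paddle default), a 1-D index gathers whole rows;
# callers relying on the current axis=-1 default would need updating
out = ivy.gather(params, indices, axis=0)
assert tuple(out.shape) == (2, 3)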

akshatvishu commented 1 year ago

Other than that, we would need to change the axis default value and make sure the functions that use ivy.gather don't start failing.

I will test it and only push a PR if all the tests are passing.

Aryan8912 commented 1 year ago

please assign to me

akshatvishu commented 1 year ago

please assign to me

Sure, feel free to work on it!