apache / mxnet

Lightweight, Portable, Flexible Distributed/Mobile Deep Learning with Dynamic, Mutation-aware Dataflow Dep Scheduler; for Python, R, Julia, Scala, Go, Javascript and more
https://mxnet.apache.org
Apache License 2.0

[Operator] Add `index_add` or `index_update` to numpy extension #17823

Closed sxjscience closed 4 years ago

sxjscience commented 4 years ago

We need the functionality to calculate `b = index_add(a, indices, value)`, which mimics the outcome of `a[indices] += value`.

This is similar to `tensor_scatter_nd_add` in TF: https://www.tensorflow.org/api_docs/python/tf/tensor_scatter_nd_add

Also in JAX: https://jax.readthedocs.io/en/latest/_autosummary/jax.ops.index_update.html#jax.ops.index_update
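For concreteness, here is a minimal NumPy sketch of the desired out-of-place semantics (this `index_add` is a hypothetical reference implementation, not an existing API):

import numpy as np

def index_add(a, indices, value):
    # Out-of-place version of a[indices] += value; a itself is left untouched.
    b = a.copy()
    np.add.at(b, indices, value)  # accumulates repeated indices, unlike b[indices] += value
    return b

a = np.zeros(5)
b = index_add(a, np.array([0, 0, 2]), np.array([1.0, 2.0, 3.0]))
print(b)  # [3. 0. 3. 0. 0.] -- the repeated index 0 accumulates 1.0 + 2.0
print(a)  # [0. 0. 0. 0. 0.] -- the input is unchanged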

zheyuye commented 4 years ago

Besides that, it would be great if `npx.scatter_nd` were implemented for `mx.np.ndarray`.
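Worth noting that the legacy NDArray API already exposes this op, so the request is for an `mx.np.ndarray`-compatible counterpart. A small usage sketch of the existing legacy op:

import mxnet as mx

# indices has shape (M, N): M index dimensions, N points to scatter into.
indices = mx.nd.array([[0, 1], [1, 2]])  # points (0, 1) and (1, 2)
updates = mx.nd.array([3.0, 4.0])
out = mx.nd.scatter_nd(updates, indices, shape=(3, 3))
print(out.asnumpy())  # all zeros except out[0, 1] == 3 and out[1, 2] == 4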

sxjscience commented 4 years ago

@ZheyuYe @JiangZhaoh @yzhliu @haojin2

To understand the problem, let's consider two use cases. The first can be solved via `gather_nd`; the second cannot be solved with existing MXNet operators.

Take elements at specific locations from the input data

out[i, j, ...] = data[i, positions[i, j], ...]

In GluonNLP, `positions` holds the masked locations in the input at which we need to calculate the loss, and `data` holds the mapped hidden states of the sequences.

With advanced indexing + imperative API, we can do something like this:

import mxnet as mx
mx.npx.set_np()

data = mx.np.random.normal(0, 1, (5, 5, 5, 5))
positions = mx.np.random.randint(0, 5, (5, 4))
out = data[mx.np.expand_dims(mx.npx.arange_like(data, axis=0), axis=-1), positions]
print(out.asnumpy().shape)
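As a sanity check (an illustrative addition, not part of the original snippet), the result should match a per-sample NumPy loop:

import numpy as np
import numpy.testing as npt

pos_np = positions.asnumpy().astype(np.int64)
expected = np.stack([data.asnumpy()[i][pos_np[i]] for i in range(data.shape[0])])
npt.assert_allclose(out.asnumpy(), expected)  # out[i, j, ...] == data[i, positions[i, j], ...]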

In order to make the network hybridizable, we can implement it via `gather_nd`:

import numpy as np
from mxnet.util import use_np

@use_np
def select_vectors_by_position(F, data, positions):
    """Select each batch with the given positions.
    Once advanced indexing can be hybridized, we can revise the implementation.
    out[i, j, :] = data[i, positions[i, j], :]
    Parameters
    ----------
    F
        The namespace through which the `np`/`npx` operators are accessed
        (`mx.nd` or `mx.sym` when hybridized, or `mx` for imperative use).
    data
        Input tensor of contextualized token embeddings
        Shape (batch_size, seq_length, units)
    positions
        Input tensor of the positions.
        Shape (batch_size, num_sel_positions).
        For each sample in the batch, the values in this tensor must not exceed
        the length of the sequence.
    Returns
    -------
    out
        The selection result.
        Shape (batch_size, num_sel_positions, units)
    """
    # Here, we use gather_nd to select the output from data:
    # Need to compute
    #   out[i, j, :] = in[i, masked_position[i, j], :]
    # Thus, construct an indices tensor of shape [2, batch_size, num_masked_position], where
    #     indices[0, i, j] = i
    #     indices[1, i, j] = masked_position[i, j]
    # Then, out = gather_nd(in, indices)
    positions = positions.astype(np.int32)
    # batch_idx.shape = (batch_size, 1) as [[0], [1], [2], ...]
    batch_idx = F.np.expand_dims(F.npx.arange_like(positions, axis=0),
                                 axis=1).astype(np.int32)
    batch_idx = batch_idx + F.np.zeros_like(positions)
    indices = F.np.stack([batch_idx, positions])
    out = F.npx.gather_nd(data, indices)
    return out
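In imperative mode, the helper can be exercised by passing `mx` itself as `F`, since `mx.np`/`mx.npx` expose the same namespaces (a usage sketch under that assumption):

data = mx.np.random.normal(0, 1, (2, 6, 3))     # (batch_size, seq_length, units)
positions = mx.np.random.randint(0, 6, (2, 4))  # (batch_size, num_sel_positions)
out = select_vectors_by_position(mx, data, positions)
print(out.shape)  # (2, 4, 3)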

Update elements at specific locations of the input data

For example, we may need to replace the elements at selected locations with our own generated elements, i.e.,

data[i, positions[i, j], ...] = update_val[i, j, ...]

With advanced indexing + imperative API, we can do something like this:

import mxnet as mx
import numpy.testing as npt
mx.npx.set_np()

data = mx.np.random.normal(0, 1, (5, 5, 5, 5))
positions = mx.np.random.randint(0, 5, (5, 4))
update_val = mx.np.random.normal(0, 1, (5, 4, 5, 5))
data[mx.np.expand_dims(mx.npx.arange_like(data, axis=0), axis=-1), positions] = update_val
print(data.asnumpy().shape)

# or, to accumulate instead of assign:
data[mx.np.expand_dims(mx.npx.arange_like(data, axis=0), axis=-1), positions] += update_val
print(data.asnumpy().shape)
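Using the `npt` import above, we can also sanity-check the assignment against plain NumPy (an illustrative addition; `data` is snapshotted first so the check holds even after the updates above):

import numpy as np

before = data.asnumpy().copy()
batch_idx = np.expand_dims(np.arange(before.shape[0]), -1)  # (batch_size, 1)
expected = before.copy()
expected[batch_idx, positions.asnumpy().astype(np.int64)] = update_val.asnumpy()
data[mx.np.expand_dims(mx.npx.arange_like(data, axis=0), axis=-1), positions] = update_val
npt.assert_allclose(data.asnumpy(), expected)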

However, we cannot wrap this in an autograd recording session:


import mxnet as mx
mx.npx.set_np()

data = mx.np.random.normal(0, 1, (5, 5, 5, 5))
positions = mx.np.random.randint(0, 5, (5, 4))
update_val = mx.np.random.normal(0, 1, (5, 4, 5, 5))
data.attach_grad()
with mx.autograd.record():
    data[mx.np.expand_dims(mx.npx.arange_like(data, axis=0), axis=-1), positions] = update_val
mx.npx.waitall()

Error message:

MXNetError: Traceback (most recent call last):
  File "src/imperative/imperative.cc", line 203
MXNetError: Check failed: AGInfo::IsNone(*output): Assigning to NDArrays that are already in a computational graph will cause undefined behavior when evaluating gradients. Please call backward first to clear the graph or do this out side of a record section. Also note that you cannot use inplace operations like +=, *=, relu(x, out=x), y[idx]=x, etc inside a record section.

We will need a workaround for this use case.
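Until such an operator lands, one possible workaround is to build the updated tensor out of place from a one-hot selection mask, which keeps the whole computation differentiable. A minimal sketch for the 3-D (batch_size, seq_length, units) case, assuming the positions within each sample are unique (duplicates would make the mask and the scattered updates accumulate):

import mxnet as mx
mx.npx.set_np()

def index_update_out_of_place(data, positions, update_val):
    # data: (batch, seq_len, units); positions: (batch, num_sel); update_val: (batch, num_sel, units)
    seq_len = data.shape[1]
    pos = positions.astype('float32')  # match arange's default float32 dtype for the comparison
    # onehot[i, j, k] = 1 iff positions[i, j] == k
    onehot = (mx.np.expand_dims(pos, -1) == mx.np.arange(seq_len)).astype(data.dtype)
    # scattered[i, k, :] = update_val[i, j, :] where positions[i, j] == k, else 0
    scattered = mx.npx.batch_dot(onehot, update_val, transpose_a=True)
    # mask[i, k, 0] = 1 at the updated rows
    mask = mx.np.expand_dims(onehot.sum(axis=1), -1)
    return data * (1 - mask) + scattered

data = mx.np.random.normal(0, 1, (5, 5, 8))
positions = mx.np.array([[0, 1, 2, 4]] * 5)  # unique positions per sample for this sketch
update_val = mx.np.random.normal(0, 1, (5, 4, 8))
data.attach_grad()
with mx.autograd.record():
    out = index_update_out_of_place(data, positions, update_val)
out.backward()  # gradients now flow to data, unlike the in-place assignment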

zheyuye commented 4 years ago

The first use case would be improved through #18319, which was inspired by #17327.