Closed sxjscience closed 4 years ago
Besides that, it would be great if npx.scatter_nd were implemented for mx.np.ndarray.
@ZheyuYe @JiangZhaoh @yzhliu @haojin2
To understand the problem, let's consider two use cases. The first one can be solved via gather_nd, while the second one cannot be solved with the existing MXNet operators.
out[i, j, ...] = data[i, positions[i, j], ...]
In GluonNLP, positions holds the masked locations in the input for which we need to calculate the loss, and data holds the mapped hidden states of the sequences.
With advanced indexing + imperative API, we can do something like this:
import mxnet as mx
mx.npx.set_np()
data = mx.np.random.normal(0, 1, (5, 5, 5, 5))
positions = mx.np.random.randint(0, 5, (5, 4))
out = data[mx.np.expand_dims(mx.npx.arange_like(data, axis=0), axis=-1), positions]
print(out.asnumpy().shape)
In order to make the network hybridizable, we can implement it via gather_nd:
import numpy as np
from mxnet.util import use_np


@use_np
def select_vectors_by_position(F, data, positions):
    """Select each batch with the given positions.

    Once advanced indexing can be hybridized, we can revise the implementation.

        out[i, j, :] = data[i, positions[i, j], :]

    Parameters
    ----------
    F
    data
        Input tensor of contextualized token embeddings
        Shape (batch_size, seq_length, units)
    positions
        Input tensor of the positions.
        Shape (batch_size, num_sel_positions).
        For each sample in the batch, the values in this tensor must not exceed
        the length of the sequence.

    Returns
    -------
    out
        The selection result.
        Shape (batch_size, num_sel_positions, units)
    """
    # Here, we use gather_nd to select the output from data:
    # Need to compute
    #     out[i, j, :] = in[i, masked_position[i, j], :]
    # Thus, construct an indices tensor with shape [2, batch_size, num_masked_position], where
    #     indices[0, i, j] = i
    #     indices[1, i, j] = masked_position[i, j]
    # Then, out = gather_nd(in, indices)
    positions = positions.astype(np.int32)
    # batch_idx.shape = (batch_size, 1) as [[0], [1], [2], ...]
    batch_idx = F.np.expand_dims(F.npx.arange_like(positions, axis=0),
                                 axis=1).astype(np.int32)
    batch_idx = batch_idx + F.np.zeros_like(positions)
    indices = F.np.stack([batch_idx, positions])
    out = F.npx.gather_nd(data, indices)
    return out
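As a sanity check on the indexing scheme above, here is a NumPy-only sketch (plain advanced indexing stands in for gather_nd; the shapes are arbitrary) showing that stacking a batch index with the positions reproduces the advanced-indexing result:

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, seq_length, units, num_sel = 5, 7, 3, 4
data = rng.normal(size=(batch_size, seq_length, units))
positions = rng.integers(0, seq_length, size=(batch_size, num_sel))

# gather_nd-style indices: indices[0, i, j] = i, indices[1, i, j] = positions[i, j]
batch_idx = np.arange(batch_size, dtype=np.int32)[:, None] + np.zeros_like(positions)
indices = np.stack([batch_idx, positions])   # shape (2, batch_size, num_sel)
out = data[indices[0], indices[1]]           # NumPy equivalent of gather_nd(data, indices)

# Reference: plain advanced indexing, as in the imperative example above
ref = data[np.arange(batch_size)[:, None], positions]
assert np.array_equal(out, ref)
print(out.shape)  # (5, 4, 3)
```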
The second use case: if we need to replace the elements at some selected locations with our own generated elements, i.e.,
data[i, positions[i, j], ...] = update_val[i, j, ...]
With advanced indexing + imperative API, we can do something like this:
import mxnet as mx
import numpy.testing as npt
mx.npx.set_np()
data = mx.np.random.normal(0, 1, (5, 5, 5, 5))
positions = mx.np.random.randint(0, 5, (5, 4))
update_val = mx.np.random.normal(0, 1, (5, 4, 5, 5))
data[mx.np.expand_dims(mx.npx.arange_like(data, axis=0), axis=-1), positions] = update_val
print(data.asnumpy().shape)
# or do
data[mx.np.expand_dims(mx.npx.arange_like(data, axis=0), axis=-1), positions] += update_val
print(data.asnumpy().shape)
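As an aside on the semantics of the += form: in plain NumPy (and hence mx.np), in-place += through advanced indexing does not accumulate repeated positions, whereas a true scatter-add (np.add.at, or TF's tensor_scatter_nd_add) does. A minimal NumPy sketch of the difference:

```python
import numpy as np

data = np.zeros(3)
idx = np.array([0, 0, 2])   # index 0 appears twice
val = np.array([1.0, 1.0, 1.0])

a = data.copy()
a[idx] += val               # buffered: the duplicate index is written only once
b = data.copy()
np.add.at(b, idx, val)      # unbuffered scatter-add: duplicates accumulate

print(a)  # [1. 0. 1.]
print(b)  # [2. 0. 1.]
```

Any scatter_nd/index_add operator would need to pick one of these two behaviors explicitly.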
However, we cannot wrap it in an autograd.record() scope:
import mxnet as mx
import numpy.testing as npt
mx.npx.set_np()
data = mx.np.random.normal(0, 1, (5, 5, 5, 5))
positions = mx.np.random.randint(0, 5, (5, 4))
update_val = mx.np.random.normal(0, 1, (5, 4, 5, 5))
data.attach_grad()
with mx.autograd.record():
    data[mx.np.expand_dims(mx.npx.arange_like(data, axis=0), axis=-1), positions] = update_val
mx.npx.waitall()
Error message:
MXNetError: Traceback (most recent call last):
File "src/imperative/imperative.cc", line 203
MXNetError: Check failed: AGInfo::IsNone(*output): Assigning to NDArrays that are already in a computational graph will cause undefined behavior when evaluating gradients. Please call backward first to clear the graph or do this out side of a record section. Also note that you cannot use inplace operations like +=, *=, relu(x, out=x), y[idx]=x, etc inside a record section.
We will need to have a workaround solution to this use case.
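One candidate workaround (a sketch of my own in plain NumPy, forward pass only; not something this issue proposes) is to avoid in-place assignment entirely: build a one-hot mask over the sequence axis and blend the old and new values out-of-place, so that every op involved is differentiable and hybridizable:

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, seq_length, units = 2, 5, 3
data = rng.normal(size=(batch_size, seq_length, units))
positions = np.array([[0, 3], [1, 4]])               # unique within each row
update_val = rng.normal(size=(batch_size, 2, units))

# one_hot[i, j, k] = 1 iff positions[i, j] == k
one_hot = (positions[..., None] == np.arange(seq_length)).astype(data.dtype)
mask = one_hot.sum(axis=1)                           # (batch, seq): 1 at selected positions
# Scatter the updates out-of-place; einsum plays the role of a batched matmul
scattered = np.einsum('bjk,bju->bku', one_hot, update_val)
out = data * (1.0 - mask)[..., None] + scattered     # no in-place ops anywhere

# Matches the in-place assignment when positions are unique per row
ref = data.copy()
ref[np.arange(batch_size)[:, None], positions] = update_val
assert np.allclose(out, ref)
```

The obvious drawbacks are the O(seq_length) memory blow-up of the one-hot tensor and the requirement that positions be unique per row, which is why a dedicated operator is still preferable.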
The first use case would be improved through #18319, which was inspired by #17327.
We need the functionality to calculate b = index_add(a, indices, value), which mimics the outcome of a[indices] += value. This is similar to tensor_scatter_nd_add in TF: https://www.tensorflow.org/api_docs/python/tf/tensor_scatter_nd_add Also in JAX: https://jax.readthedocs.io/en/latest/_autosummary/jax.ops.index_update.html#jax.ops.index_update
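The requested semantics can be pinned down with a short NumPy sketch (a hypothetical out-of-place wrapper around np.add.at; a real operator would of course also need a backward pass):

```python
import numpy as np

def index_add(a, indices, value):
    """Out-of-place sketch of the proposed operator: return a copy b of a
    with b[indices] += value, where duplicate indices accumulate."""
    b = a.copy()
    np.add.at(b, indices, value)
    return b

a = np.zeros((2, 3))
indices = (np.array([0, 0, 1]), np.array([1, 1, 2]))  # (0, 1) appears twice
b = index_add(a, indices, 1.0)
print(b)
# [[0. 2. 0.]
#  [0. 0. 1.]]
assert np.array_equal(a, np.zeros((2, 3)))  # the input is left untouched
```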