zheyuye opened this issue 4 years ago
@ZheyuYe Regarding mx.npx.pick, we can support multi-dimensional selection as long as the index array covers the leading dimensions:
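import mxnet as mx
import mxnet.numpy as np
mx.npx.set_np()
# a: data of shape (10, 10, 5); pick selects along the last axis (axis=-1 by default)
a = mx.np.random.uniform(-1, 1, (10, 10, 5))
# b: one index in [0, 5) for every position of the leading (10, 10) dimensions
b = np.random.randint(0, 5, (10,)).broadcast_to((10, 10))
mx.npx.pick(a, b)  # result shape: (10, 10)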
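For readers without MXNet at hand, the selection performed by the pick call above can be approximated in plain NumPy. This is a minimal sketch assuming mx.npx.pick's defaults (axis=-1, keepdims=False); `pick_reference` is a hypothetical helper, not part of MXNet.

```python
import numpy as np

def pick_reference(data, index):
    """Hypothetical NumPy stand-in for mx.npx.pick(data, index) with axis=-1.

    `index` has the shape of `data` without its last axis; for every leading
    position it selects one element along the last axis.
    """
    # Add a trailing axis so take_along_axis pairs each index with one row.
    picked = np.take_along_axis(data, index[..., None].astype(np.intp), axis=-1)
    # Drop the trailing axis again (mirrors keepdims=False).
    return picked[..., 0]

rng = np.random.default_rng(0)
a = rng.uniform(-1, 1, (10, 10, 5))
# Same construction as the snippet: a length-10 index vector broadcast to (10, 10).
b = np.broadcast_to(rng.integers(0, 5, (10,)), (10, 10))
out = pick_reference(a, b)
print(out.shape)  # (10, 10)
```

Each output element is `a[i, j, b[i, j]]`, which is exactly the "leading dimensions" restriction: the index array must match every axis of `a` except the one being picked from.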
Also, could you provide a reproducible code snippet (like the one above) showing which type of indexing operation is missing?
@sxjscience It would be great if the indexing operation in the last line of the snippet below could be supported after hybridization.
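import numpy as np
import mxnet as mx
# new implementation in deep numpy
mx.npx.set_np()
# sequence: [batch_size, seq_length, hidden_size] = (8, 32, 768)
sequence = mx.np.array(np.random.normal(0, 1, (8, 32, 768)), dtype=np.float32)
# pick_ids: [batch_size, num_picked] = (8, 2), positions to select from each sequence
pick_ids = mx.np.random.randint(0, 31, (8, 2), dtype=np.int32)
# batch_idx: [0, 0, 1, 1, ..., 7, 7], the batch entry of each flattened pick
idx_arange = mx.npx.arange_like(pick_ids.reshape((-1,)), axis=0)
batch_idx = mx.np.floor(idx_arange / 2).astype(np.int32)
# advanced indexing: works imperatively but fails after hybridize(); result shape (16, 768)
encoded = sequence[batch_idx, pick_ids.reshape((-1,))]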
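The indexing line above pairs `batch_idx` and the flattened `pick_ids` element-wise, so it returns a (16, 768) array that still has to be reshaped to (8, 2, 768). A small plain-NumPy reproduction (random data, same shapes as the snippet) makes this explicit and checks it against an explicit loop; it runs without MXNet and says nothing about hybridization.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, seq_len, hidden, num_pick = 8, 32, 768, 2
sequence = rng.normal(0, 1, (batch_size, seq_len, hidden)).astype(np.float32)
# pick_ids[b, k] is the position of the k-th token picked from batch entry b.
pick_ids = rng.integers(0, seq_len, (batch_size, num_pick))

# Batch entry for every flattened pick: [0, 0, 1, 1, ..., 7, 7].
batch_idx = np.arange(batch_size * num_pick) // num_pick
# Advanced indexing pairs the two index arrays element-wise -> (16, 768).
flat = sequence[batch_idx, pick_ids.reshape(-1)]
encoded = flat.reshape(batch_size, num_pick, hidden)

# Reference: explicit loop over batch entries.
expected = np.stack([sequence[b, pick_ids[b]] for b in range(batch_size)])
print(flat.shape, encoded.shape)  # (16, 768) (8, 2, 768)
```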
My aim was to pick items from the sequence to obtain a result of shape (8, 2, 768), which mx.npx.pick cannot handle. In the deep numpy environment, I used NumPy-style advanced (integer-array) indexing, encoded = sequence[batch_idx, pick_ids.reshape((-1,))],
which fails after hybridize() with the following error:
IndexError: Only integer, slice, or tuple of these types are supported! Received key=(<_Symbol albertmodel0_floor0>, <_Symbol albertmodel0_reshape4>)
The full test code can be found here.
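The error message indicates that the hybridized path accepts only integer, slice, or tuple keys, so a key made of two index arrays is rejected. One way to express the same selection with a single index array is to flatten the batch and sequence axes and compute linear offsets `b * seq_len + pos`. The sketch below shows the idea in plain NumPy; `gather_by_offset` is a hypothetical helper, and whether the corresponding mx.np operations (for example reshape and take) behave the same inside a hybridized block is an assumption to verify against MXNet.

```python
import numpy as np

def gather_by_offset(sequence, pick_ids):
    """Hypothetical helper: select sequence[b, pick_ids[b, k]] for every b, k.

    Uses one 1-D index into a (batch * seq_len, hidden) view instead of
    advanced indexing with two index arrays.
    """
    batch_size, seq_len, hidden = sequence.shape
    num_pick = pick_ids.shape[1]
    # Linear position of each picked token in the flattened (batch * seq_len) axis.
    offsets = np.arange(batch_size)[:, None] * seq_len + pick_ids  # (batch, num_pick)
    flat_seq = sequence.reshape(batch_size * seq_len, hidden)
    picked = np.take(flat_seq, offsets.reshape(-1), axis=0)  # (batch * num_pick, hidden)
    return picked.reshape(batch_size, num_pick, hidden)

rng = np.random.default_rng(1)
seq = rng.normal(0, 1, (8, 32, 768)).astype(np.float32)
ids = rng.integers(0, 32, (8, 2))
out = gather_by_offset(seq, ids)
```

Because the offsets come from ordinary arithmetic and broadcasting, the only selection step left is a take along axis 0 with a single 1-D index array.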
Description
Indexing is an important NumPy feature that supports complex element selection and assignment (see Numpy | Indexing). MXNet deep numpy only supplies basic indexing operations such as
whereas the advanced indexing operation above fails after hybridize() (see here). Currently, the most practical alternative is mx.npx.pick, but it only selects along a single axis, unlike the gather_nd operation (available for nd.NDArray), which can select array elements along multiple dimensions.
A temporary solution would be to add a gather_nd operation for mx.np.ndarray; hybridized advanced indexing could be added in future work. @sxjscience @haojin2
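For reference, gather_nd-style selection (as documented for mx.nd.gather_nd) takes an `indices` array of shape (M, Y_0, ..., Y_{K-1}) whose first axis lists coordinates into the leading M axes of `data`; in NumPy this is equivalent to indexing with `tuple(indices)`. The sketch below is a plain-NumPy reference under that assumption (`gather_nd_reference` is hypothetical, not an existing mx.np function) applied to the (8, 2, 768) case from this issue.

```python
import numpy as np

def gather_nd_reference(data, indices):
    """Hypothetical NumPy reference for gather_nd-style selection.

    indices has shape (M, Y_0, ..., Y_{K-1}); indices[:, y] is a coordinate into
    the first M axes of data. Output shape: (Y_0, ..., Y_{K-1}) + data.shape[M:].
    """
    indices = np.asarray(indices)
    if indices.shape[0] > data.ndim:
        raise ValueError("indices.shape[0] must not exceed data.ndim")
    # Unpacking the first axis gives M index arrays that NumPy pairs element-wise.
    return data[tuple(indices)]

rng = np.random.default_rng(2)
sequence = rng.normal(0, 1, (8, 32, 768)).astype(np.float32)
pick_ids = rng.integers(0, 32, (8, 2))
# (batch, position) coordinates for every pick, stacked along a new leading axis.
batch_idx = np.broadcast_to(np.arange(8)[:, None], pick_ids.shape)
indices = np.stack([batch_idx, pick_ids])  # shape (2, 8, 2)
encoded = gather_nd_reference(sequence, indices)
print(encoded.shape)  # (8, 2, 768)
```

With this layout the result already has the desired (8, 2, 768) shape, so no flattening of `pick_ids` or follow-up reshape is needed.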