apache / mxnet

Lightweight, Portable, Flexible Distributed/Mobile Deep Learning with Dynamic, Mutation-aware Dataflow Dep Scheduler; for Python, R, Julia, Scala, Go, Javascript and more
https://mxnet.apache.org
Apache License 2.0

np.ndarray indexing after hybridize #17327

Open zheyuye opened 4 years ago

zheyuye commented 4 years ago

Description

Indexing is an important NumPy feature that supports complex element selection and assignment (see Numpy | Indexing). In MXNet's deep NumPy (mx.np), indexing with integer arrays works imperatively, for example:

gathered_data = sequence[indices_x, indices_y]

However, this operation fails after hybridize() is called (see here). Currently, the only practical alternative is mx.npx.pick, but it selects along a single axis only. Unlike the gather_nd operator available for nd.NDArray, it cannot select array elements across multiple dimensions.

A temporary solution would be to add a gather_nd operator for mx.np.ndarray; hybridized indexing could then be added in future work. @sxjscience @haojin2
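For reference, the gather_nd semantics requested above can be expressed with plain NumPy integer-array indexing. The sketch below assumes the index layout documented for MXNet's nd.gather_nd, where the first axis of `indices` enumerates the data axes being indexed. `gather_nd_reference` is a hypothetical helper for illustration, not an MXNet API.

```python
import numpy as np

def gather_nd_reference(data, indices):
    """Hypothetical NumPy reference for an MXNet-style gather_nd.

    Assumes indices has shape (M, Y_0, ..., Y_{K-1}), where row i holds the
    coordinates for axis i of data. The result then has shape
    (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}).
    """
    indices = np.asarray(indices)
    # Unpacking the leading axis into a tuple turns it into NumPy
    # integer-array (advanced) indexing over the first M axes of data.
    return data[tuple(indices)]

data = np.arange(2 * 3 * 4).reshape(2, 3, 4)
# Two lookups: (batch 0, row 2) and (batch 1, row 0).
indices = np.array([[0, 1],
                    [2, 0]])
out = gather_nd_reference(data, indices)
print(out.shape)  # (2, 4)
```

This is the same operation as `sequence[indices_x, indices_y]`, which is why a gather_nd operator for mx.np.ndarray would cover the use case.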

sxjscience commented 4 years ago

@ZheyuYe Regarding mx.npx.pick, it can support multi-dimensional selection as long as we are selecting along the leading dimensions:

import mxnet as mx
import mxnet.numpy as np
mx.npx.set_np()
a = mx.np.random.uniform(-1, 1, (10, 10, 5))
b = np.random.randint(0, 5, (10,)).broadcast_to((10, 10))
mx.npx.pick(a, b) 
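To make the pick semantics concrete, here is a plain NumPy sketch of what the snippet above computes, assuming mx.npx.pick's default of picking along the last axis and dropping that axis from the result: for every leading position (i, j), one element `a[i, j, b[i, j]]` is selected.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.uniform(-1, 1, (10, 10, 5))
# Same index layout as the snippet above: one index in [0, 5) per (i, j).
b = np.broadcast_to(rng.integers(0, 5, (10,)), (10, 10))

# Pick along the last axis: picked[i, j] = a[i, j, b[i, j]].
picked = np.take_along_axis(a, b[..., None], axis=-1)[..., 0]
print(picked.shape)  # (10, 10)
```

The index array must cover every leading axis, so pick cannot return whole trailing slices such as full (768,) hidden vectors.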

Also, could you provide a reproducible code snippet (like the one above) showing which type of indexing operation is missing?

zheyuye commented 4 years ago

@sxjscience It would be great if indexing like the last line of the snippet below could be supported after hybridization.

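import numpy as np
import mxnet as mx
import numpy.testing as npt

# New implementation using deep NumPy
mx.npx.set_np()
# sequence: (batch_size=8, seq_len=32, hidden=768)
sequence = mx.np.array(np.random.normal(0, 1, (8, 32, 768)), dtype=np.float32)
# pick_ids: (batch_size, num_picked) positions to gather from each sequence
pick_ids = mx.np.random.randint(0, 31, (8, 2), dtype=np.int32)

# batch_idx = [0, 0, 1, 1, ..., 7, 7]: the batch index of each flattened pick
idx_arange = mx.npx.arange_like(pick_ids.reshape((-1, )), axis=0)
batch_idx = mx.np.floor(idx_arange / 2).astype(np.int32)

# Indexing with two index arrays gives shape (16, 768);
# this works imperatively but raises IndexError after hybridize()
encoded = sequence[batch_idx, pick_ids.reshape((-1,))]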

My aim was to pick items from the sequence to obtain an array of shape (8, 2, 768), which mx.npx.pick cannot do. Under the deep NumPy environment, I used NumPy-style indexing, encoded = sequence[batch_idx, pick_ids.reshape((-1,))], which fails after hybridize() with the following error:

IndexError: Only integer, slice, or tuple of these types are supported! Received key=(<_Symbol albertmodel0_floor0>, <_Symbol albertmodel0_reshape4>)

The full test code can be found here.
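One way to express this gather without indexing by two arrays is to convert each (batch, position) pair into a row offset of a flattened (batch_size * seq_len, hidden) view and gather rows with a single take along axis 0. The sketch below is plain NumPy and checks the result against the direct indexing from the snippet. Whether the corresponding mx.np reshape/take calls are accepted inside a hybridized block in a given MXNet version is an assumption not established in this thread.

```python
import numpy as np

batch_size, seq_len, hidden = 8, 32, 768
rng = np.random.default_rng(0)
sequence = rng.normal(0, 1, (batch_size, seq_len, hidden)).astype(np.float32)
# pick_ids: (batch_size, num_picked) positions inside each sequence.
pick_ids = rng.integers(0, seq_len, (batch_size, 2))
num_picked = pick_ids.shape[1]

# Reference: indexing with two index arrays (the form that fails once hybridized).
batch_idx = np.arange(batch_size).repeat(num_picked)  # [0, 0, 1, 1, ..., 7, 7]
direct = sequence[batch_idx, pick_ids.reshape(-1)].reshape(batch_size, num_picked, hidden)

# Workaround sketch: flat row offset = batch * seq_len + position.
offsets = (np.arange(batch_size)[:, None] * seq_len + pick_ids).reshape(-1)
flat = sequence.reshape(-1, hidden)
gathered = np.take(flat, offsets, axis=0).reshape(batch_size, num_picked, hidden)

print(gathered.shape)  # (8, 2, 768)
```

The design relies only on a reshape and a single-axis take, which avoids passing a tuple of index arrays as the key.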