deepjavalibrary / djl

An Engine-Agnostic Deep Learning Framework in Java
https://djl.ai
Apache License 2.0
4.01k stars, 639 forks

Pick index not working for multidimensional arrays #2494

Open tipame opened 1 year ago

tipame commented 1 year ago

Description

NDArray#get fails when given a pick index on a multidimensional array. Executing the following code:

NDArray target = manager.arange(6).reshape(3, 2);
NDArray index = manager.create(new long[] {0, 2});
NDArray result = target.get(new NDIndex().addPickDim(index));

Expected Behavior

I expected an NDArray of shape (2, 2), as described in the Javadoc for NDIndex#addPickDim: [[0, 1], [4, 5]]

Error Message

java.lang.IllegalArgumentException: expand shape failed! Cannot expand from (2)to (3, 2)

at ai.djl.pytorch.jni.JniUtils.pick(JniUtils.java:618)
at ai.djl.pytorch.jni.JniUtils.indexAdv(JniUtils.java:464)
at ai.djl.pytorch.engine.PtNDArrayIndexer.get(PtNDArrayIndexer.java:74)
at ai.djl.ndarray.NDArray.get(NDArray.java:523)
at ai.djl.ndarray.NDArray.get(NDArray.java:512)

KexinFeng commented 1 year ago

No, the definition of addPickDim is aligned with MXNet's pick operator: https://mxnet.apache.org/versions/1.6/api/r/docs/api/mx.nd.pick.html. So the output of the following code

NDArray target = manager.arange(6).reshape(3, 2);
NDArray pickIndex = manager.create(new long[] {0, 2}, new Shape(1, 2));
NDArray result = target.get(new NDIndex().addPickDim(pickIndex));

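should be [[ 0, 5],]: one element is picked per column, so column 0 takes row 0 (value 0) and column 1 takes row 2 (value 5). This feature is not used often, though.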
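To make the pick semantics concrete, here is a minimal plain-Java sketch with no DJL dependency. It assumes the pick runs along axis 0, which is what the [[ 0, 5],] output above implies, and that the index has shape (1, columns). The class and method names (PickAxis0Sketch, pickAxis0) are hypothetical, and this is not DJL's implementation.

```java
import java.util.Arrays;

// Hypothetical plain-Java sketch (not DJL's implementation) of an MXNet-style
// pick along axis 0 with keepdims: result[0][j] = target[index[0][j]][j].
class PickAxis0Sketch {

    static long[][] pickAxis0(long[][] target, long[][] index) {
        int cols = target[0].length;
        // Assumption: exactly one row index per column, i.e. index shape (1, cols)
        if (index.length != 1 || index[0].length != cols) {
            throw new IllegalArgumentException("index must have shape (1, " + cols + ")");
        }
        long[][] result = new long[1][cols];
        for (int j = 0; j < cols; j++) {
            // For column j, take the element in the row named by the index
            result[0][j] = target[(int) index[0][j]][j];
        }
        return result;
    }

    public static void main(String[] args) {
        // target = arange(6).reshape(3, 2) -> [[0, 1], [2, 3], [4, 5]]
        long[][] target = {{0, 1}, {2, 3}, {4, 5}};
        long[][] pickIndex = {{0, 2}};
        System.out.println(Arrays.deepToString(pickAxis0(target, pickIndex))); // [[0, 5]]
    }
}
```

The explicit shape check reflects the pick contract sketched here: the index supplies one row per column rather than a list of whole rows.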

To get [[0, 1], [4, 5]], use array indexing instead:

NDArray index = manager.create(new long[] {0, 2});
NDArray ret = target.get(index);
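For contrast, here is a plain-Java sketch of the row selection that target.get(index) is described as producing here: each entry of index selects a whole row along axis 0, giving [[0, 1], [4, 5]]. The names RowIndexSketch and takeRows are hypothetical illustrations, not DJL APIs.

```java
import java.util.Arrays;

// Hypothetical plain-Java sketch (not DJL's implementation) of indexing with an
// integer array along axis 0: result[i] is a copy of row target[index[i]].
class RowIndexSketch {

    static long[][] takeRows(long[][] target, long[] index) {
        long[][] result = new long[index.length][];
        for (int i = 0; i < index.length; i++) {
            // Copy the whole selected row so the result does not alias target
            result[i] = target[(int) index[i]].clone();
        }
        return result;
    }

    public static void main(String[] args) {
        // target = arange(6).reshape(3, 2) -> [[0, 1], [2, 3], [4, 5]]
        long[][] target = {{0, 1}, {2, 3}, {4, 5}};
        long[] index = {0, 2};
        System.out.println(Arrays.deepToString(takeRows(target, index))); // [[0, 1], [4, 5]]
    }
}
```

Unlike the pick, every index entry here selects a full row, so the output has shape (2, 2).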

Check out: https://github.com/deepjavalibrary/djl/blob/866be61a0cd8a75b98a23efef9dbf6cf13fac910/integration/src/main/java/ai/djl/integration/tests/ndarray/NDIndexTest.java#L153-L161