Closed by siddvenk.
@siddvenk Thanks for spotting this issue! I found the root cause. It worked in version 0.17.0 because in that version
NDArray keys = xTile.get((manager.eye(nTrain)).reshape(new Shape(nTrain, -1)));
internally called take (see PR), which is also supported in MXNet (see PR). In later versions, it switched back to indexing with NDIndex. To use the take feature, take now has to be called explicitly.
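For readers unfamiliar with take: in NumPy, take without an axis views the source as a flattened, row-major array and gathers one element per integer index, so the output has the index's shape. The plain-Java sketch below assumes DJL's take follows the same linear-indexing idea; the class and helper names (LinearTakeSketch, take, diagonalIndices) are hypothetical and this is not DJL code. It shows how the diagonal of an n x n matrix can be pulled out with explicit integer indices 0, n+1, 2(n+1), and so on.

```java
import java.util.Arrays;

/**
 * Hypothetical plain-Java sketch of NumPy-style linear "take":
 * the source is viewed as flattened in row-major order and each
 * integer index selects one element. Not DJL's implementation.
 */
class LinearTakeSketch {

    // Gather flat[idx[i]] for every index; the output has the index's length.
    static float[] take(float[][] x, int[] idx) {
        int cols = x[0].length;
        float[] out = new float[idx.length];
        for (int i = 0; i < idx.length; i++) {
            int flat = idx[i];
            out[i] = x[flat / cols][flat % cols]; // row-major position -> (row, col)
        }
        return out;
    }

    // Linear indices of the main diagonal of an n x n matrix: 0, n+1, 2(n+1), ...
    static int[] diagonalIndices(int n) {
        int[] idx = new int[n];
        for (int i = 0; i < n; i++) {
            idx[i] = i * (n + 1);
        }
        return idx;
    }

    public static void main(String[] args) {
        int n = 5;
        float[][] xTile = new float[n][n];
        for (int r = 0; r < n; r++) {
            for (int c = 0; c < n; c++) {
                xTile[r][c] = r * 10 + c; // distinct values make the selection visible
            }
        }
        System.out.println(Arrays.toString(take(xTile, diagonalIndices(n))));
    }
}
```

With n = 5 and xTile[r][c] = 10r + c, this prints [0.0, 11.0, 22.0, 33.0, 44.0]. The point of the design is that take works on integer positions, so the caller chooses the indices explicitly instead of relying on a mask.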
I will also add type conversion for indexing with NDIndex in the current version.
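The type conversion mentioned above can be pictured with a minimal plain-Java sketch, assuming that eye returns float32 values (its default data type) and that mask indexing accepts only a boolean index. The DType enum, the Array2D class, and the eye, toBoolean, and get helpers are hypothetical stand-ins, not DJL's NDIndex logic. Passing the float32 identity directly throws IllegalArgumentException, while converting it to a boolean mask first selects the diagonal of xTile in row-major order.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Hypothetical sketch, not DJL code: boolean-mask indexing that rejects a
 * float32 index, plus an explicit conversion step that makes it valid.
 */
class MaskIndexSketch {

    enum DType { FLOAT32, BOOLEAN }

    // A tiny 2-D array that remembers its element type, as an NDArray does.
    static final class Array2D {
        final float[][] data;
        final DType dtype;

        Array2D(float[][] data, DType dtype) {
            this.data = data;
            this.dtype = dtype;
        }
    }

    // Stand-in for eye(n): identity matrix stored as float32 (assumed default type).
    static Array2D eye(int n) {
        float[][] d = new float[n][n];
        for (int i = 0; i < n; i++) {
            d[i][i] = 1f;
        }
        return new Array2D(d, DType.FLOAT32);
    }

    // The "type conversion": every non-zero value becomes true (stored as 1).
    static Array2D toBoolean(Array2D a) {
        float[][] d = new float[a.data.length][];
        for (int r = 0; r < a.data.length; r++) {
            d[r] = new float[a.data[r].length];
            for (int c = 0; c < a.data[r].length; c++) {
                d[r][c] = a.data[r][c] != 0f ? 1f : 0f;
            }
        }
        return new Array2D(d, DType.BOOLEAN);
    }

    // Mask selection: keep x[r][c] wherever the mask is true, in row-major order.
    static float[] get(float[][] x, Array2D mask) {
        if (mask.dtype != DType.BOOLEAN) {
            throw new IllegalArgumentException("mask must be boolean, got " + mask.dtype);
        }
        List<Float> picked = new ArrayList<>();
        for (int r = 0; r < x.length; r++) {
            for (int c = 0; c < x[r].length; c++) {
                if (mask.data[r][c] != 0f) {
                    picked.add(x[r][c]);
                }
            }
        }
        float[] out = new float[picked.size()];
        for (int i = 0; i < out.length; i++) {
            out[i] = picked.get(i);
        }
        return out;
    }

    public static void main(String[] args) {
        int n = 5;
        float[][] xTile = new float[n][n];
        for (int r = 0; r < n; r++) {
            for (int c = 0; c < n; c++) {
                xTile[r][c] = r * 10 + c; // distinct values make the selection visible
            }
        }
        Array2D mask = eye(n);
        try {
            get(xTile, mask); // float32 mask: rejected by the dtype check
        } catch (IllegalArgumentException e) {
            System.out.println("rejected: " + e.getMessage());
        }
        System.out.println(Arrays.toString(get(xTile, toBoolean(mask))));
    }
}
```

With n = 5, the converted mask yields [0.0, 11.0, 22.0, 33.0, 44.0], the diagonal elements, as a 1-D result; any reshape to another shape would be a separate step.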
Awesome, thanks for figuring out the issue, @KexinFeng!
Description
There seems to be backward-incompatible behavior in the NDArray.get(NDArray) method. In DJL v0.17.0 the following code works as expected, but in DJL v0.18.0 it throws an IllegalArgumentException at the xTile.get(...) line.
Expected Behavior
The above code should work and return an NDArray of shape (1, 5) containing the diagonal elements of xTile.
Error Message
How to Reproduce?
Here's the code I'm running. To reproduce the issue, run it against either master or the v0.18 tag. If you run it against the v0.17 tag, it works as expected.
I added this code to a file in the examples/inference module and ran it via
./gradlew run -Dmain=ai.djl.examples.inference.NDArrayIndexingBug
Steps to reproduce
Run
./gradlew run -Dmain=ai.djl.examples.inference.NDArrayIndexingBug
from the examples directory.
Working on v0.17: same as above, but check out tags/v0.17.0 first.
What have you tried to solve it?
It seems the logic here explains why this throws an error: https://github.com/deepjavalibrary/djl/blob/e547f7144dbc4862f8081556a8aa9a0f757d4e9b/api/src/main/java/ai/djl/ndarray/index/NDIndex.java#L356-L362 However, this logic was roughly the same in v0.17, where it worked fine.
I'm not sure what changed, but maybe we should investigate whether we create NDArrays with different data types (such as int) in some default cases, such as eye?
Environment Info
Please run the command
./gradlew debugEnv
from the root directory of DJL (if necessary, clone DJL first). It will output information about your system, environment, and installation that can help us debug your issue. Paste the output of the command below: