deepjavalibrary / djl

An Engine-Agnostic Deep Learning Framework in Java
https://djl.ai
Apache License 2.0

NDArray.get(NDArray) breaking behavior in 0.18.0 #1799

Closed siddvenk closed 2 years ago

siddvenk commented 2 years ago

Description

There seems to be backwards-incompatible behavior in the NDArray.get(NDArray) method. In DJL v0.17.0 the following code works as expected, but in DJL v0.18.0 it throws an IllegalArgumentException at the xTile.get(...) line.

try (NDManager manager = NDManager.newBaseManager()) {
    int nTrain = 5;
    NDArray xTrain = manager.randomUniform(0, 1, new Shape(nTrain)).mul(5).sort();
    NDArray xTile = xTrain.tile(new long[] {nTrain, 1});
    // Following line works as expected in v0.17 and returns the diagonal elements
    // Throws IllegalArgumentException in v0.18
    NDArray keys = xTile.get((manager.eye(nTrain)).reshape(new Shape(nTrain, -1)));
}

Expected Behavior

The above code should work and return an NDArray of shape (1, 5) containing the diagonal elements of xTile.
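For reference, the selection the reporter expects can be sketched in plain Java as boolean-mask indexing on a 2-D array. This is not DJL code and does not touch the NDArray API: MaskSelect, selectByMask, and identityMask are hypothetical names, and the NumPy-style row-major ordering of the selected elements is an assumption made for illustration.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Hypothetical plain-Java illustration of boolean-mask indexing (not DJL code). */
public final class MaskSelect {

    private MaskSelect() {}

    /**
     * Returns the elements of data whose corresponding mask entry is true,
     * in row-major order (assumed NumPy-style boolean indexing semantics).
     */
    public static float[] selectByMask(float[][] data, boolean[][] mask) {
        if (data.length != mask.length) {
            throw new IllegalArgumentException("row count mismatch");
        }
        List<Float> picked = new ArrayList<>();
        for (int i = 0; i < data.length; i++) {
            if (data[i].length != mask[i].length) {
                throw new IllegalArgumentException("column count mismatch in row " + i);
            }
            for (int j = 0; j < data[i].length; j++) {
                if (mask[i][j]) {
                    picked.add(data[i][j]);
                }
            }
        }
        float[] out = new float[picked.size()];
        for (int k = 0; k < out.length; k++) {
            out[k] = picked.get(k);
        }
        return out;
    }

    /** Builds an n x n identity mask, analogous to a boolean eye(n). */
    public static boolean[][] identityMask(int n) {
        boolean[][] mask = new boolean[n][n];
        for (int i = 0; i < n; i++) {
            mask[i][i] = true;
        }
        return mask;
    }

    public static void main(String[] args) {
        // Every row of xTile is a copy of xTrain, like tile(new long[] {nTrain, 1}).
        float[] xTrain = {0.5f, 1.2f, 2.0f, 3.3f, 4.8f};
        float[][] xTile = new float[xTrain.length][];
        for (int i = 0; i < xTile.length; i++) {
            xTile[i] = xTrain.clone();
        }
        float[] diag = selectByMask(xTile, identityMask(xTrain.length));
        System.out.println(Arrays.toString(diag));
    }
}
```

Because every row of xTile is a copy of xTrain, the diagonal selected this way is simply xTrain itself (xTile[i][i] == xTrain[i]).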

Error Message

[ERROR] - {}
java.lang.IllegalArgumentException: Unknown argument: ND: (5, 5) cpu() float32
[[1., 0., 0., 0., 0.],
 [0., 1., 0., 0., 0.],
 [0., 0., 1., 0., 0.],
 [0., 0., 0., 1., 0.],
 [0., 0., 0., 0., 1.],
]

        at ai.djl.ndarray.index.NDIndex.addIndexItem(NDIndex.java:384) ~[api-0.19.0-SNAPSHOT.jar:?]
        at ai.djl.ndarray.index.NDIndex.addIndices(NDIndex.java:220) ~[api-0.19.0-SNAPSHOT.jar:?]
        at ai.djl.ndarray.index.NDIndex.<init>(NDIndex.java:130) ~[api-0.19.0-SNAPSHOT.jar:?]
        at ai.djl.ndarray.NDArray.get(NDArray.java:533) ~[api-0.19.0-SNAPSHOT.jar:?]
        at ai.djl.examples.bugs.NDArrayIndexingBug.main(NDArrayIndexingBug.java:32) ~[main/:?]

How to Reproduce?

Here's the code I'm running. To reproduce the issue, run it against either master or the v0.18.0 tag. If you run it against the v0.17.0 tag, it works as expected.

I added this code to a file in the examples/inference module and ran it via ./gradlew run -Dmain=ai.djl.examples.inference.NDArrayIndexingBug

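package ai.djl.examples.inference;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/** Minimal reproducer for the NDArray.get(NDArray) regression in 0.18.0. */
public final class NDArrayIndexingBug {

    private static final Logger logger = LoggerFactory.getLogger(NDArrayIndexingBug.class);

    private NDArrayIndexingBug() {}

    public static void main(String[] args) {
        try (NDManager manager = NDManager.newBaseManager()) {
            int nTrain = 5;
            // Sorted random values in [0, 5)
            NDArray xTrain = manager.randomUniform(0, 1, new Shape(nTrain)).mul(5).sort();
            // Repeat xTrain as nTrain identical rows: shape (nTrain, nTrain)
            NDArray xTile = xTrain.tile(new long[] {nTrain, 1});
            // Index with a float32 identity matrix; throws IllegalArgumentException in 0.18.0
            NDArray keys = xTile.get((manager.eye(nTrain)).reshape(new Shape(nTrain, -1)));
            logger.info("keys: {}", keys);
        } catch (IllegalArgumentException e) {
            // Pass the exception as the last argument so SLF4J logs its stack trace
            logger.error("Indexing with an NDArray failed", e);
        }
    }
}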

Steps to reproduce

  1. Check out the DJL repo locally.
  2. Run git checkout tags/v0.18.0.
  3. Add the above code to a new file in examples/inference.
  4. From the examples directory, run ./gradlew run -Dmain=ai.djl.examples.inference.NDArrayIndexingBug

To confirm that it works on v0.17, follow the same steps but check out tags/v0.17.0 instead.

What have you tried to solve it?

The logic here seems to explain why this throws an error: https://github.com/deepjavalibrary/djl/blob/e547f7144dbc4862f8081556a8aa9a0f757d4e9b/api/src/main/java/ai/djl/ndarray/index/NDIndex.java#L356-L362. However, this logic was roughly the same in v0.17, where the code worked fine.

I'm not sure what changed, but maybe we need to investigate whether some default cases, such as eye, create NDArrays with a different data type (like int)?
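The data-type hypothesis can be illustrated with a deliberately simplified, hypothetical dispatcher. It does not reproduce DJL's NDIndex logic linked above; IndexDispatch, DType, IndexKind, and classify are invented names. It only shows how choosing an index interpretation from the dtype alone makes a float32 0/1 matrix, like the output of eye in the error message, fall through to an "Unknown argument" error.

```java
/** Hypothetical, simplified index dispatcher; NOT the actual DJL NDIndex implementation. */
public final class IndexDispatch {

    /** Element types an index array might carry (illustrative subset). */
    public enum DType { BOOLEAN, INT32, INT64, FLOAT32 }

    /** How an index array would be interpreted. */
    public enum IndexKind { BOOLEAN_MASK, INTEGER_INDICES }

    private IndexDispatch() {}

    /** Chooses an interpretation from the dtype alone; float arrays are rejected. */
    public static IndexKind classify(DType dtype) {
        switch (dtype) {
            case BOOLEAN:
                return IndexKind.BOOLEAN_MASK;
            case INT32:
            case INT64:
                return IndexKind.INTEGER_INDICES;
            default:
                // A float32 eye(n) would land here in this sketch.
                throw new IllegalArgumentException("Unknown argument: index of type " + dtype);
        }
    }

    public static void main(String[] args) {
        System.out.println(classify(DType.BOOLEAN));
        try {
            classify(DType.FLOAT32);
        } catch (IllegalArgumentException e) {
            System.out.println("rejected: " + e.getMessage());
        }
    }
}
```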

Environment Info

Please run the command ./gradlew debugEnv from the root directory of DJL (if necessary, clone DJL first). It will output information about your system, environment, and installation that can help us debug your issue. Paste the output of the command below:

----------- System Properties -----------
gopherProxySet: false
awt.toolkit: sun.lwawt.macosx.LWCToolkit
java.specification.version: 11
sun.cpu.isalist: 
sun.jnu.encoding: UTF-8
java.class.path: /Volumes/workplace/djl/integration/build/classes/java/main:/Volumes/workplace/djl/integration/build/resources/main:/Users/siddhave/.gradle/caches/modules-2/files-2.1/commons-cli/commons-cli/1.5.0/dc98be5d5390230684a092589d70ea76a147925c/commons-cli-1.5.0.jar:/Users/siddhave/.gradle/caches/modules-2/files-2.1/org.apache.logging.log4j/log4j-slf4j-impl/2.17.2/183f7c95fc981f3e97d008b363341343508848e/log4j-slf4j-impl-2.17.2.jar:/Volumes/workplace/djl/basicdataset/build/libs/basicdataset-0.19.0-SNAPSHOT.jar:/Volumes/workplace/djl/model-zoo/build/libs/model-zoo-0.19.0-SNAPSHOT.jar:/Volumes/workplace/djl/testing/build/libs/testing-0.19.0-SNAPSHOT.jar:/Users/siddhave/.gradle/caches/modules-2/files-2.1/org.testng/testng/7.5/1416a607fae667c14e390b484e8d02b5824c0674/testng-7.5.jar:/Volumes/workplace/djl/engines/mxnet/mxnet-model-zoo/build/libs/mxnet-model-zoo-0.19.0-SNAPSHOT.jar:/Volumes/workplace/djl/engines/pytorch/pytorch-model-zoo/build/libs/pytorch-model-zoo-0.19.0-SNAPSHOT.jar:/Volumes/workplace/djl/engines/pytorch/pytorch-jni/build/libs/pytorch-jni-1.11.0-0.19.0-SNAPSHOT.jar:/Volumes/workplace/djl/engines/tensorflow/tensorflow-model-zoo/build/libs/tensorflow-model-zoo-0.19.0-SNAPSHOT.jar:/Volumes/workplace/djl/engines/ml/xgboost/build/libs/xgboost-0.19.0-SNAPSHOT.jar:/Volumes/workplace/djl/engines/mxnet/mxnet-engine/build/libs/mxnet-engine-0.19.0-SNAPSHOT.jar:/Volumes/workplace/djl/engines/pytorch/pytorch-engine/build/libs/pytorch-engine-0.19.0-SNAPSHOT.jar:/Volumes/workplace/djl/engines/tensorflow/tensorflow-engine/build/libs/tensorflow-engine-0.19.0-SNAPSHOT.jar:/Volumes/workplace/djl/api/build/libs/api-0.19.0-SNAPSHOT.jar:/Users/siddhave/.gradle/caches/modules-2/files-2.1/org.slf4j/slf4j-api/1.7.36/6c62681a2f655b49963a5983b8b0950a6120ae14/slf4j-api-1.7.36.jar:/Users/siddhave/.gradle/caches/modules-2/files-2.1/org.apache.logging.log4j/log4j-core/2.17.2/fa43ba4467f5300b16d1e0742934149bfc5ac564/log4j-core-2.17.2.jar:/Users/siddhave/.gradle/caches/modules-2/files-2.1/org.apache.logging.log4j/log4j-api/2.17.2/f42d6afa111b4dec5d2aea0fe2197240749a4ea6/log4j-api-2.17.2.jar:/Users/siddhave/.gradle/caches/modules-2/files-2.1/org.apache.commons/commons-csv/1.9.0/b59d8f64cd0b83ee1c04ff1748de2504457018c1/commons-csv-1.9.0.jar:/Users/siddhave/.gradle/caches/modules-2/files-2.1/com.google.code.findbugs/jsr305/3.0.1/f7be08ec23c21485b9b5a1cf1654c2ec8c58168d/jsr305-3.0.1.jar:/Users/siddhave/.gradle/caches/modules-2/files-2.1/com.beust/jcommander/1.78/a3927de9bd6f351429bcf763712c9890629d8f51/jcommander-1.78.jar:/Users/siddhave/.gradle/caches/modules-2/files-2.1/org.webjars/jquery/3.5.1/2392938e374f561c27c53872bdc9b6b351b6ba34/jquery-3.5.1.jar:/Users/siddhave/.gradle/caches/modules-2/files-2.1/ml.dmlc/xgboost4j_2.12/1.6.1/da6824d8e57dc3cf4f873bd926ca5a4c7f914603/xgboost4j_2.12-1.6.1.jar:/Users/siddhave/.gradle/caches/modules-2/files-2.1/commons-logging/commons-logging/1.2/4bfc12adfe4842bf07b657f0369c4cb522955686/commons-logging-1.2.jar:/Users/siddhave/.gradle/caches/modules-2/files-2.1/com.google.code.gson/gson/2.9.0/8a1167e089096758b49f9b34066ef98b2f4b37aa/gson-2.9.0.jar:/Users/siddhave/.gradle/caches/modules-2/files-2.1/net.java.dev.jna/jna/5.11.0/27770efb6329f092f895c7329662d1aa8ee8c0ac/jna-5.11.0.jar:/Users/siddhave/.gradle/caches/modules-2/files-2.1/org.apache.commons/commons-compress/1.21/4ec95b60d4e86b5c95a0e919cb172a0af98011ef/commons-compress-1.21.jar:/Volumes/workplace/djl/engines/tensorflow/tensorflow-api/build/libs/tensorflow-api-0.19.0-SNAPSHOT.jar:/Users/siddhave/.gradle/caches/modules-2/files-2.1/org.tensorflow/tensorflow-core-api/0.4.0/2ac35ca087607cce0e5419953cc1ef0c3a5edaea/tensorflow-core-api-0.4.0.jar:/Users/siddhave/.gradle/caches/modules-2/files-2.1/org.bytedeco/javacpp/1.5.6/1f18a820aadd943577b0b372554f9e35e1232e25/javacpp-1.5.6.jar:/Users/siddhave/.gradle/caches/modules-2/files-2.1/com.google.protobuf/protobuf-java/3.19.2/e958ce38f96b612d3819ff1c753d4d70609aea74/protobuf-java-3.19.2.jar:/Users/siddhave/.gradle/caches/modules-2/files-2.1/org.tensorflow/ndarray/0.3.3/1b6d8cc3e3762f6e465b884580d9fc17ab7aeb4/ndarray-0.3.3.jar
java.vm.vendor: Amazon.com Inc.
sun.arch.data.model: 64
user.variant: 
java.vendor.url: https://aws.amazon.com/corretto/
user.timezone: America/Los_Angeles
os.name: Mac OS X
java.vm.specification.version: 11
sun.java.launcher: SUN_STANDARD
user.country: US
sun.boot.library.path: /Library/Java/JavaVirtualMachines/amazon-corretto-11.jdk/Contents/Home/lib:/Library/Java/JavaVirtualMachines/amazon-corretto-11.jdk/Contents/Home/lib
sun.java.command: ai.djl.integration.util.DebugEnvironment
jdk.debug: release
sun.cpu.endian: little
user.home: /Users/siddhave
org.gradle.appname: gradlew
user.language: en
java.specification.vendor: Oracle Corporation
java.version.date: 2022-01-18
java.home: /Library/Java/JavaVirtualMachines/amazon-corretto-11.jdk/Contents/Home
ai.djl.logging.level: debug
org.gradle.internal.http.connectionTimeout: 60000
file.separator: /
java.vm.compressedOopsMode: Zero based
line.separator: 

java.specification.name: Java Platform API Specification
java.vm.specification.vendor: Oracle Corporation
java.awt.graphicsenv: sun.awt.CGraphicsEnvironment
sun.management.compiler: HotSpot 64-Bit Tiered Compilers
java.runtime.version: 11.0.14+9-LTS
user.name: siddhave
path.separator: :
os.version: 12.3.1
java.runtime.name: OpenJDK Runtime Environment
file.encoding: UTF-8
java.vm.name: OpenJDK 64-Bit Server VM
java.vendor.version: Corretto-11.0.14.9.1
java.vendor.url.bug: https://github.com/corretto/corretto-11/issues/
java.io.tmpdir: /var/folders/n4/3sjz_43j19sf37w1h6fy1_w40000gr/T/
org.gradle.internal.http.socketTimeout: 120000
java.version: 11.0.14
user.dir: /Volumes/workplace/djl/integration
os.arch: x86_64
java.vm.specification.name: Java Virtual Machine Specification
java.awt.printerjob: sun.lwawt.macosx.CPrinterJob
sun.os.patch.level: unknown
java.library.path: /Users/siddhave/Library/Java/Extensions:/Library/Java/Extensions:/Network/Library/Java/Extensions:/System/Library/Java/Extensions:/usr/lib/java:.
java.vm.info: mixed mode
java.vendor: Amazon.com Inc.
java.vm.version: 11.0.14+9-LTS
sun.io.unicode.encoding: UnicodeBig
library.jansi.path: /Users/siddhave/.gradle/native/jansi/1.18/osx
java.class.version: 55.0
org.gradle.internal.publish.checksums.insecure: true

--------- Environment Variables ---------
PATH: /opt/homebrew/bin:/opt/homebrew/sbin:/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin:/Users/siddhave/.toolbox/bin:/Users/siddhave/Library/Python/3.8/bin
MANPATH: /opt/homebrew/share/man::
APP_NAME_97995: Gradle
JAVA_HOME: /Library/Java/JavaVirtualMachines/amazon-corretto-11.jdk/Contents/Home
TERM: xterm-256color
LANG: en_US.UTF-8
HOMEBREW_PREFIX: /opt/homebrew
JAVA_MAIN_CLASS_98025: ai.djl.integration.util.DebugEnvironment
JAVA_MAIN_CLASS_97995: org.gradle.wrapper.GradleWrapperMain
APP_ICON_97995: /Volumes/workplace/djl/media/gradle.icns
LOGNAME: siddhave
HOMEBREW_REPOSITORY: /opt/homebrew
PWD: /Users/siddhave/workplace/djl
TERM_PROGRAM_VERSION: 444
XPC_SERVICE_NAME: 0
INFOPATH: /opt/homebrew/share/info:
__CFBundleIdentifier: com.apple.Terminal
SHELL: /bin/zsh
TERM_PROGRAM: Apple_Terminal
SECURITYSESSIONID: 186b1
HOMEBREW_CELLAR: /opt/homebrew/Cellar
OLDPWD: /Users/siddhave/workplace/djl
USER: siddhave
LaunchInstanceID: FC007782-4E31-4786-8C3E-364A398BDAE2
TMPDIR: /var/folders/n4/3sjz_43j19sf37w1h6fy1_w40000gr/T/
SSH_AUTH_SOCK: /private/tmp/com.apple.launchd.FEHiqBC00j/Listeners
XPC_FLAGS: 0x0
TERM_SESSION_ID: CE1A41BB-1885-4DFF-B1AC-2A5920C13B14
__CF_USER_TEXT_ENCODING: 0x1F8:0x0:0x0
SHLVL: 1
HOME: /Users/siddhave

-------------- Directories --------------
temp directory: /var/folders/n4/3sjz_43j19sf37w1h6fy1_w40000gr/T
DJL cache directory: /Users/siddhave/.djl.ai
Engine cache directory: /Users/siddhave/.djl.ai

------------------ CUDA -----------------
[DEBUG] - cudart library not found.
GPU Count: 0

----------------- Engines ---------------
DJL version: 0.19.0
Default Engine: MXNet
[DEBUG] - Using cache dir: /Users/siddhave/.djl.ai/mxnet/1.9.0-mkl-osx-x86_64
[DEBUG] - Loading mxnet library from: /Users/siddhave/.djl.ai/mxnet/1.9.0-mkl-osx-x86_64/libmxnet.dylib
Default Device: cpu()
PyTorch: 2
MXNet: 0
XGBoost: 10
TensorFlow: 3

--------------- Hardware --------------
Available processors (cores): 10
Byte Order: LITTLE_ENDIAN
Free memory (bytes): 513546752
Maximum memory (bytes): 8589934592
Total memory available to JVM (bytes): 536870912
Heap committed: 536870912
Heap nonCommitted: 31064064
GCC: 
Apple clang version 13.1.6 (clang-1316.0.21.2.5)
Target: x86_64-apple-darwin21.4.0
Thread model: posix
InstalledDir: /Library/Developer/CommandLineTools/usr/bin
KexinFeng commented 2 years ago

@siddvenk Thanks for spotting this issue! I found the root cause. It worked in v0.17.0 because NDArray keys = xTile.get((manager.eye(nTrain)).reshape(new Shape(nTrain, -1))); internally called take (see PR), which is also supported in MXNet (see PR). Later versions switched back to indexing with NDIndex, so to use the take feature, take now has to be called explicitly.

I will also add type conversion for indexing with NDIndex in the current version.
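The type conversion mentioned above can be sketched in plain Java as coercing a numeric 0/1 index array into a boolean mask before it is interpreted. This is a hypothetical illustration, not the change that landed in DJL: MaskCoercion and toBooleanMask are invented names, and rejecting values other than 0 and 1 (so that an integer-index array is not silently reinterpreted as a mask) is a design assumption of this sketch. In user code, a comparable workaround would presumably be converting the mask explicitly, for example with NDArray.toType(DataType.BOOLEAN, false), though that is untested here.

```java
/** Hypothetical sketch of coercing a numeric 0/1 index array into a boolean mask. */
public final class MaskCoercion {

    private MaskCoercion() {}

    /**
     * Converts a float matrix into a boolean mask: 1 means selected, 0 means not selected.
     * Any other value is rejected (an assumption of this sketch, not DJL behavior).
     */
    public static boolean[][] toBooleanMask(float[][] values) {
        boolean[][] mask = new boolean[values.length][];
        for (int i = 0; i < values.length; i++) {
            mask[i] = new boolean[values[i].length];
            for (int j = 0; j < values[i].length; j++) {
                float v = values[i][j];
                if (v != 0f && v != 1f) {
                    throw new IllegalArgumentException(
                            "value " + v + " at (" + i + ", " + j + ") is not 0 or 1");
                }
                mask[i][j] = v == 1f;
            }
        }
        return mask;
    }

    public static void main(String[] args) {
        float[][] eye = {{1f, 0f}, {0f, 1f}};
        boolean[][] mask = toBooleanMask(eye);
        System.out.println(mask[0][0] + " " + mask[0][1]);
    }
}
```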

siddvenk commented 2 years ago

Awesome, thanks for figuring out the issue @KexinFeng !