deepjavalibrary / djl

An Engine-Agnostic Deep Learning Framework in Java
https://djl.ai
Apache License 2.0

Model load failed #1663

Closed: freemanliu closed this issue 2 years ago

freemanliu commented 2 years ago

Description

This is reproducible with the following test case:

    @Test
    fun testModelLoad() {
        var model = Model.newInstance("model")
        model.block = Mlp(2, 1, intArrayOf(10))
        model.newTrainer(DefaultTrainingConfig(Loss.l2Loss())).use { trainer ->
            trainer.initialize(Shape(2))
            val manager = model.ndManager
            val input = manager.ones(Shape(1, 2), DataType.FLOAT32)
            val label = manager.create(floatArrayOf(0.5f))
            val trainingDs = ArrayDataset.Builder().setData(input)
                .optLabels(label).setSampling(1, false).build()
            EasyTrain.fit(trainer, 100, trainingDs, trainingDs)
            model.save(Path.of("/tmp"), "predictorAndTrainer")
        }
        val model2 = Model.newInstance("model")
        model2.load(Path.of("/tmp"), "model")
        val p2 = model2.newPredictor(NoopTranslator())
        NDManager.newBaseManager().use { manager ->
            println(p2.predict(NDList(manager.ones(Shape(1, 2)))))
        }
    }

Here are the Gradle dependencies used to select the PyTorch engine:

    implementation 'ai.djl:basicdataset:0.17.0'
    implementation 'ai.djl:model-zoo:0.17.0'
    implementation 'ai.djl.pytorch:pytorch-model-zoo:0.17.0'

Further investigation shows that save() is implemented in BaseModel, while load() is implemented in PtModel. I expected save() to be implemented in PtModel as well.

Expected Behavior

model2.load succeeds.

Error Message

    model.pt file not found in: /tmp
    java.io.FileNotFoundException: model.pt file not found in: /tmp
        at ai.djl.pytorch.engine.PtModel.load(PtModel.java:74)
        at ai.djl.Model.load(Model.java:121)
        at helloworld.jdl.AppTest.testModelLoad(AppTest.kt:79)
        at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
        at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
        at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
        at java.base/java.lang.reflect.Method.invoke(Method.java:566)
        at org.junit.runners.model.FrameworkMethod$1.runReflectiveCall(FrameworkMethod.java:50)
        at org.junit.internal.runners.model.ReflectiveCallable.run(ReflectiveCallable.java:12)
        at org.junit.runners.model.FrameworkMethod.invokeExplosively(FrameworkMethod.java:47)
        at org.junit.internal.runners.statements.InvokeMethod.evaluate(InvokeMethod.java:17)
        at org.junit.runners.ParentRunner.runLeaf(ParentRunner.java:325)
        at org.junit.runners.BlockJUnit4ClassRunner.runChild(BlockJUnit4ClassRunner.java:78)
        at org.junit.runners.BlockJUnit4ClassRunner.runChild(BlockJUnit4ClassRunner.java:57)
        at org.junit.runners.ParentRunner$3.run(ParentRunner.java:290)
        at org.junit.runners.ParentRunner$1.schedule(ParentRunner.java:71)
        at org.junit.runners.ParentRunner.runChildren(ParentRunner.java:288)
        at org.junit.runners.ParentRunner.access$000(ParentRunner.java:58)
        at org.junit.runners.ParentRunner$2.evaluate(ParentRunner.java:268)
        at org.junit.runners.ParentRunner.run(ParentRunner.java:363)
        at org.gradle.api.internal.tasks.testing.junit.JUnitTestClassExecutor.runTestClass(JUnitTestClassExecutor.java:110)
        at org.gradle.api.internal.tasks.testing.junit.JUnitTestClassExecutor.execute(JUnitTestClassExecutor.java:58)
        at org.gradle.api.internal.tasks.testing.junit.JUnitTestClassExecutor.execute(JUnitTestClassExecutor.java:38)
        at org.gradle.api.internal.tasks.testing.junit.AbstractJUnitTestClassProcessor.processTestClass(AbstractJUnitTestClassProcessor.java:62)
        at org.gradle.api.internal.tasks.testing.SuiteTestClassProcessor.processTestClass(SuiteTestClassProcessor.java:51)
        at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
        at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
        at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
        at java.base/java.lang.reflect.Method.invoke(Method.java:566)
        at org.gradle.internal.dispatch.ReflectionDispatch.dispatch(ReflectionDispatch.java:36)
        at org.gradle.internal.dispatch.ReflectionDispatch.dispatch(ReflectionDispatch.java:24)
        at org.gradle.internal.dispatch.ContextClassLoaderDispatch.dispatch(ContextClassLoaderDispatch.java:33)
        at org.gradle.internal.dispatch.ProxyDispatchAdapter$DispatchingInvocationHandler.invoke(ProxyDispatchAdapter.java:94)
        at com.sun.proxy.$Proxy2.processTestClass(Unknown Source)
        at org.gradle.api.internal.tasks.testing.worker.TestWorker$2.run(TestWorker.java:176)
        at org.gradle.api.internal.tasks.testing.worker.TestWorker.executeAndMaintainThreadName(TestWorker.java:129)
        at org.gradle.api.internal.tasks.testing.worker.TestWorker.execute(TestWorker.java:100)
        at org.gradle.api.internal.tasks.testing.worker.TestWorker.execute(TestWorker.java:60)
        at org.gradle.process.internal.worker.child.ActionExecutionWorker.execute(ActionExecutionWorker.java:56)
        at org.gradle.process.internal.worker.child.SystemApplicationClassLoaderWorker.call(SystemApplicationClassLoaderWorker.java:133)
        at org.gradle.process.internal.worker.child.SystemApplicationClassLoaderWorker.call(SystemApplicationClassLoaderWorker.java:71)
        at worker.org.gradle.process.internal.worker.GradleWorkerMain.run(GradleWorkerMain.java:69)
        at worker.org.gradle.process.internal.worker.GradleWorkerMain.main(GradleWorkerMain.java:74)


Environment Info

Output of ./gradlew debugEnv, run from the root directory of the DJL repository:


> Task :integration:debugEnv
[DEBUG] - Registering EngineProvider: XGBoost
[DEBUG] - Registering EngineProvider: MXNet
[DEBUG] - Registering EngineProvider: PyTorch
[DEBUG] - Registering EngineProvider: TensorFlow
[DEBUG] - Found default engine: MXNet
----------- System Properties -----------
gopherProxySet: false
awt.toolkit: sun.lwawt.macosx.LWCToolkit
java.specification.version: 11
sun.cpu.isalist: 
sun.jnu.encoding: UTF-8
java.class.path: /Users/freeman.liu/codes/djl/integration/build/classes/java/main:/Users/freeman.liu/codes/djl/integration/build/resources/main:/Users/freeman.liu/.gradle/caches/modules-2/files-2.1/commons-cli/commons-cli/1.5.0/dc98be5d5390230684a092589d70ea76a147925c/commons-cli-1.5.0.jar:/Users/freeman.liu/.gradle/caches/modules-2/files-2.1/org.apache.logging.log4j/log4j-slf4j-impl/2.17.2/183f7c95fc981f3e97d008b363341343508848e/log4j-slf4j-impl-2.17.2.jar:/Users/freeman.liu/codes/djl/basicdataset/build/libs/basicdataset-0.18.0-SNAPSHOT.jar:/Users/freeman.liu/codes/djl/model-zoo/build/libs/model-zoo-0.18.0-SNAPSHOT.jar:/Users/freeman.liu/codes/djl/testing/build/libs/testing-0.18.0-SNAPSHOT.jar:/Users/freeman.liu/.gradle/caches/modules-2/files-2.1/org.testng/testng/7.5/1416a607fae667c14e390b484e8d02b5824c0674/testng-7.5.jar:/Users/freeman.liu/codes/djl/engines/mxnet/mxnet-model-zoo/build/libs/mxnet-model-zoo-0.18.0-SNAPSHOT.jar:/Users/freeman.liu/codes/djl/engines/pytorch/pytorch-model-zoo/build/libs/pytorch-model-zoo-0.18.0-SNAPSHOT.jar:/Users/freeman.liu/codes/djl/engines/pytorch/pytorch-jni/build/libs/pytorch-jni-1.11.0-0.18.0-SNAPSHOT.jar:/Users/freeman.liu/codes/djl/engines/tensorflow/tensorflow-model-zoo/build/libs/tensorflow-model-zoo-0.18.0-SNAPSHOT.jar:/Users/freeman.liu/codes/djl/engines/ml/xgboost/build/libs/xgboost-0.18.0-SNAPSHOT.jar:/Users/freeman.liu/codes/djl/engines/mxnet/mxnet-engine/build/libs/mxnet-engine-0.18.0-SNAPSHOT.jar:/Users/freeman.liu/codes/djl/engines/pytorch/pytorch-engine/build/libs/pytorch-engine-0.18.0-SNAPSHOT.jar:/Users/freeman.liu/codes/djl/engines/tensorflow/tensorflow-engine/build/libs/tensorflow-engine-0.18.0-SNAPSHOT.jar:/Users/freeman.liu/codes/djl/api/build/libs/api-0.18.0-SNAPSHOT.jar:/Users/freeman.liu/.gradle/caches/modules-2/files-2.1/org.slf4j/slf4j-api/1.7.36/6c62681a2f655b49963a5983b8b0950a6120ae14/slf4j-api-1.7.36.jar:/Users/freeman.liu/.gradle/caches/modules-2/files-2.1/org.apache.logging.log4j/log4j-core/2.17.2/fa43ba4467f5300b16d1e0742934149bfc5ac564/log4j-core-2.17.2.jar:/Users/freeman.liu/.gradle/caches/modules-2/files-2.1/org.apache.logging.log4j/log4j-api/2.17.2/f42d6afa111b4dec5d2aea0fe2197240749a4ea6/log4j-api-2.17.2.jar:/Users/freeman.liu/.gradle/caches/modules-2/files-2.1/org.apache.commons/commons-csv/1.9.0/b59d8f64cd0b83ee1c04ff1748de2504457018c1/commons-csv-1.9.0.jar:/Users/freeman.liu/.gradle/caches/modules-2/files-2.1/com.google.code.findbugs/jsr305/3.0.1/f7be08ec23c21485b9b5a1cf1654c2ec8c58168d/jsr305-3.0.1.jar:/Users/freeman.liu/.gradle/caches/modules-2/files-2.1/com.beust/jcommander/1.78/a3927de9bd6f351429bcf763712c9890629d8f51/jcommander-1.78.jar:/Users/freeman.liu/.gradle/caches/modules-2/files-2.1/org.webjars/jquery/3.5.1/2392938e374f561c27c53872bdc9b6b351b6ba34/jquery-3.5.1.jar:/Users/freeman.liu/.gradle/caches/modules-2/files-2.1/ml.dmlc/xgboost4j_2.12/1.6.0/4623e78f614c998b4600c1cc58441ce06d80ba49/xgboost4j_2.12-1.6.0.jar:/Users/freeman.liu/.gradle/caches/modules-2/files-2.1/commons-logging/commons-logging/1.2/4bfc12adfe4842bf07b657f0369c4cb522955686/commons-logging-1.2.jar:/Users/freeman.liu/.gradle/caches/modules-2/files-2.1/com.google.code.gson/gson/2.9.0/8a1167e089096758b49f9b34066ef98b2f4b37aa/gson-2.9.0.jar:/Users/freeman.liu/.gradle/caches/modules-2/files-2.1/net.java.dev.jna/jna/5.10.0/7cf4c87dd802db50721db66947aa237d7ad09418/jna-5.10.0.jar:/Users/freeman.liu/.gradle/caches/modules-2/files-2.1/org.apache.commons/commons-compress/1.21/4ec95b60d4e86b5c95a0e919cb172a0af98011ef/commons-compress-1.21.jar:/Users/freeman.liu/codes/djl/engines/tensorflow/tensorflow-api/build/libs/tensorflow-api-0.18.0-SNAPSHOT.jar:/Users/freeman.liu/.gradle/caches/modules-2/files-2.1/org.tensorflow/tensorflow-core-api/0.4.0/2ac35ca087607cce0e5419953cc1ef0c3a5edaea/tensorflow-core-api-0.4.0.jar:/Users/freeman.liu/.gradle/caches/modules-2/files-2.1/org.bytedeco/javacpp/1.5.6/1f18a820aadd943577b0b372554f9e35e1232e25/javacpp-1.5.6.jar:/Users/freeman.liu/.gradle/caches/modules-2/files-2.1/com.google.protobuf/protobuf-java/3.19.2/e958ce38f96b612d3819ff1c753d4d70609aea74/protobuf-java-3.19.2.jar:/Users/freeman.liu/.gradle/caches/modules-2/files-2.1/org.tensorflow/ndarray/0.3.3/1b6d8cc3e3762f6e465b884580d9fc17ab7aeb4/ndarray-0.3.3.jar
java.vm.vendor: AdoptOpenJDK
sun.arch.data.model: 64
user.variant: 
java.vendor.url: https://adoptopenjdk.net/
user.timezone: Australia/Sydney
os.name: Mac OS X
java.vm.specification.version: 11
sun.java.launcher: SUN_STANDARD
user.country: AU
sun.boot.library.path: /Library/Java/JavaVirtualMachines/adoptopenjdk-11.jdk/Contents/Home/lib:/Library/Java/JavaVirtualMachines/adoptopenjdk-11.jdk/Contents/Home/lib
sun.java.command: ai.djl.integration.util.DebugEnvironment
jdk.debug: release
sun.cpu.endian: little
user.home: /Users/freeman.liu
org.gradle.appname: gradlew
user.language: en
java.specification.vendor: Oracle Corporation
java.version.date: 2021-04-20
java.home: /Library/Java/JavaVirtualMachines/adoptopenjdk-11.jdk/Contents/Home
ai.djl.logging.level: debug
org.gradle.internal.http.connectionTimeout: 60000
file.separator: /
java.vm.compressedOopsMode: Zero based
line.separator: \n
java.specification.name: Java Platform API Specification
java.vm.specification.vendor: Oracle Corporation
java.awt.graphicsenv: sun.awt.CGraphicsEnvironment
sun.management.compiler: HotSpot 64-Bit Tiered Compilers
java.runtime.version: 11.0.11+9
user.name: freeman.liu
path.separator: :
os.version: 11.5
java.runtime.name: OpenJDK Runtime Environment
file.encoding: UTF-8
java.vm.name: OpenJDK 64-Bit Server VM
java.vendor.version: AdoptOpenJDK-11.0.11+9
java.vendor.url.bug: https://github.com/AdoptOpenJDK/openjdk-support/issues
java.io.tmpdir: /var/folders/4r/gxktgtgn2277w32wxlm56t080000gp/T/
org.gradle.internal.http.socketTimeout: 120000
java.version: 11.0.11
user.dir: /Users/freeman.liu/codes/djl/integration
os.arch: x86_64
java.vm.specification.name: Java Virtual Machine Specification
java.awt.printerjob: sun.lwawt.macosx.CPrinterJob
sun.os.patch.level: unknown
java.library.path: /Users/freeman.liu/Library/Java/Extensions:/Library/Java/Extensions:/Network/Library/Java/Extensions:/System/Library/Java/Extensions:/usr/lib/java:.
java.vm.info: mixed mode
java.vendor: AdoptOpenJDK
java.vm.version: 11.0.11+9
sun.io.unicode.encoding: UnicodeBig
library.jansi.path: /Users/freeman.liu/.gradle/native/jansi/1.18/osx
java.class.version: 55.0
org.gradle.internal.publish.checksums.insecure: true

--------- Environment Variables ---------
PATH: /usr/local/opt/node@16/bin:/Users/freeman.liu/.amplify/bin:/usr/local/opt/node@12/bin:/Users/freeman.liu/bin:/usr/local/opt/node@16/bin:/Users/freeman.liu/.amplify/bin:/Users/freeman.liu/bin:/usr/local/opt/node@16/bin:/Users/freeman.liu/.cargo/bin:/Users/freeman.liu/.amplify/bin:/usr/local/opt/node@12/bin:/Users/freeman.liu/bin:/usr/local/bin:/usr/bin:/bin:/usr/sbin:/sbin:/usr/local/go/bin:/opt/puppetlabs/pdk/bin:/Library/Apple/usr/bin:/usr/local/bin:/Users/freeman.liu/Library/Python/3.9/bin:/usr/local/bin:/Users/freeman.liu/Library/Python/3.9/bin
APP_ICON_39014: /Users/freeman.liu/codes/djl/media/gradle.icns
APP_NAME_39014: Gradle
WORKON_HOME: /Users/freeman.liu/.virtualenvs
TERM: screen-bce
LANG: en_AU.UTF-8
VIRTUALENVWRAPPER_SCRIPT: /usr/local/bin/virtualenvwrapper.sh
VIRTUALENVWRAPPER_WORKON_CD: 1
STY: 36507.ttys000.MREM277DB4AD
LOGNAME: freeman.liu
XPC_SERVICE_NAME: 0
PWD: /Users/freeman.liu/codes/djl
TERM_PROGRAM_VERSION: 440
JAVA_MAIN_CLASS_39026: ai.djl.integration.util.DebugEnvironment
__CFBundleIdentifier: com.apple.Terminal
SHELL: /usr/local/bin/bash
TERM_PROGRAM: Apple_Terminal
SECURITYSESSIONID: 186aa
OLDPWD: /Users/freeman.liu/codes/djl
VIRTUALENVWRAPPER_HOOK_DIR: /Users/freeman.liu/.virtualenvs
USER: freeman.liu
WINDOW: 7
LaunchInstanceID: 57D413AC-7DB9-4C1A-BE2D-EE29A69C8716
TMPDIR: /var/folders/4r/gxktgtgn2277w32wxlm56t080000gp/T/
SSH_AUTH_SOCK: /private/tmp/com.apple.launchd.dnFo9uppka/Listeners
XPC_FLAGS: 0x0
LIBTORCH: /Users/freeman.liu/libtorch
TERM_SESSION_ID: 73D5F5B6-416C-42D9-B8D1-A124365689AE
VIRTUALENVWRAPPER_PROJECT_FILENAME: .project
TERMCAP: SC|screen-bce|VT 100/ANSI X3.64 virtual terminal:DO=\E[%dB:LE=\E[%dD:RI=\E[%dC:UP=\E[%dA:bs:bt=\E[Z:cd=\E[J:ce=\E[K:cl=\E[H\E[J:cm=\E[%i%d;%dH:ct=\E[3g:do=^J:nd=\E[C:pt:rc=\E8:rs=\Ec:sc=\E7:st=\EH:up=\EM:le=^H:bl=^G:cr=^M:it#8:ho=\E[H:nw=\EE:ta=^I:is=\E)0:li#45:co#178:am:xn:xv:LP:sr=\EM:al=\E[L:AL=\E[%dL:cs=\E[%i%d;%dr:dl=\E[M:DL=\E[%dM:dc=\E[P:DC=\E[%dP:im=\E[4h:ei=\E[4l:mi:IC=\E[%d@:ks=\E[?1h\E=:ke=\E[?1l\E>:vi=\E[?25l:ve=\E[34h\E[?25h:vs=\E[34l:ti=\E[?1049h:te=\E[?1049l:us=\E[4m:ue=\E[24m:so=\E[3m:se=\E[23m:mb=\E[5m:md=\E[1m:mr=\E[7m:me=\E[m:ms:Co#8:pa#64:AF=\E[3%dm:AB=\E[4%dm:op=\E[39;49m:AX:vb=\Eg:G0:as=\E(0:ae=\E(B:ac=\140\140aaffggjjkkllmmnnooppqqrrssttuuvvwwxxyyzz{{||}}~~..--++,,hhII00:po=\E[5i:pf=\E[4i:Km=\E[M:k0=\E[10~:k1=\EOP:k2=\EOQ:k3=\EOR:k4=\EOS:k5=\E[15~:k6=\E[17~:k7=\E[18~:k8=\E[19~:k9=\E[20~:k;=\E[21~:F1=\E[23~:F2=\E[24~:kB=\E[Z:kh=\E[1~:@1=\E[1~:kH=\E[4~:@7=\E[4~:kN=\E[6~:kP=\E[5~:kI=\E[2~:kD=\E[3~:ku=\EOA:kd=\EOB:kr=\EOC:kl=\EOD:km:
__CF_USER_TEXT_ENCODING: 0x1F6:0x0:0xF
PROJECT_HOME: /Users/freeman.liu/dev
JAVA_MAIN_CLASS_39014: org.gradle.wrapper.GradleWrapperMain
HOME: /Users/freeman.liu
SHLVL: 2

-------------- Directories --------------
temp directory: /var/folders/4r/gxktgtgn2277w32wxlm56t080000gp/T
DJL cache directory: /Users/freeman.liu/.djl.ai
Engine cache directory: /Users/freeman.liu/.djl.ai

------------------ CUDA -----------------
[DEBUG] - cudart library not found.
GPU Count: 0

----------------- Engines ---------------
DJL version: 0.18.0
Default Engine: MXNet
[DEBUG] - Using cache dir: /Users/freeman.liu/.djl.ai/mxnet/1.9.0-mkl-osx-x86_64
[DEBUG] - Loading mxnet library from: /Users/freeman.liu/.djl.ai/mxnet/1.9.0-mkl-osx-x86_64/libmxnet.dylib
Default Device: cpu()
PyTorch: 2
MXNet: 0
XGBoost: 10
TensorFlow: 3

--------------- Hardware --------------
Available processors (cores): 12
Byte Order: LITTLE_ENDIAN
Free memory (bytes): 240247000
Maximum memory (bytes): 4294967296
Total memory available to JVM (bytes): 268435456
Heap committed: 268435456
Heap nonCommitted: 30474240
GCC: 
Apple clang version 13.0.0 (clang-1300.0.29.30)
Target: x86_64-apple-darwin20.6.0
Thread model: posix
InstalledDir: /Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/bin

BUILD SUCCESSFUL in 2s
44 actionable tasks: 1 executed, 43 up-to-date

freemanliu commented 2 years ago

Please fix this ASAP; it is a deal-breaker. I can't load the trained model, so I'll have to switch to another framework.

frankfliu commented 2 years ago

@freemanliu

Sorry for the delay.

When you train a model in DJL, the trainer saves only the model's parameters; the block definition (the network structure) is not serialized into the model directory. To load such a model, you need to set the Block manually before calling load():

        model2.block = Mlp(2, 1, intArrayOf(10))
        // the model prefix in your code was also wrong; it should be:
        model2.load(Path.of("/tmp"), "predictorAndTrainer")

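To make the "parameters only" point concrete, here is a toy analog in plain Java. It is a conceptual sketch, not DJL's API: `TinyLinear`, `saveParams` and `loadParams` are hypothetical names. The checkpoint file holds nothing but raw weight values, so the loader must first construct a layer of the right shape in code; only then can the numbers be read back into it, and a layer of the wrong shape is rejected.

```java
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;

public class ParamsOnlyCheckpoint {
    // Hypothetical toy layer: its "architecture" is just the shape (in x out).
    static final class TinyLinear {
        final int in;
        final int out;
        final float[] weights;

        TinyLinear(int in, int out) {
            this.in = in;
            this.out = out;
            this.weights = new float[in * out];
        }
    }

    // Writes only the parameter values; the layer's shape is NOT stored.
    static void saveParams(TinyLinear layer, Path file) throws IOException {
        try (DataOutputStream os = new DataOutputStream(Files.newOutputStream(file))) {
            for (float w : layer.weights) {
                os.writeFloat(w);
            }
        }
    }

    // Fills an already-constructed layer; fails if the file size does not match its shape.
    static void loadParams(TinyLinear layer, Path file) throws IOException {
        long expectedBytes = (long) layer.weights.length * Float.BYTES;
        if (Files.size(file) != expectedBytes) {
            throw new IOException("parameter file does not match the layer's shape");
        }
        try (DataInputStream is = new DataInputStream(Files.newInputStream(file))) {
            for (int i = 0; i < layer.weights.length; i++) {
                layer.weights[i] = is.readFloat();
            }
        }
    }

    public static void main(String[] args) throws IOException {
        Path file = Files.createTempFile("toy", ".params");
        TinyLinear trained = new TinyLinear(2, 10);
        Arrays.fill(trained.weights, 0.5f);
        saveParams(trained, file);

        // The structure is rebuilt in code first, then the weights are loaded into it.
        TinyLinear restored = new TinyLinear(2, 10);
        loadParams(restored, file);
        System.out.println(Arrays.equals(trained.weights, restored.weights)); // prints true
        Files.delete(file);
    }
}
```

The size check is a deliberately crude stand-in for the shape validation a real framework would do; the point is only that the file alone cannot tell the loader what to build.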
freemanliu commented 2 years ago

Hi, Frank,

Thanks for looking into this. I added the block, but the issue is the same. As I mentioned before, the load and save implementations do not match. Are you able to get this piece of code working?

Cheers, Freeman

On Thu, Jun 2, 2022 at 7:05 AM Frank Liu wrote:


-- Language? Kotlin, Typescript or Rust? All of them!

frankfliu commented 2 years ago

@freemanliu I tested your code in Java, and it works:

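    import ai.djl.MalformedModelException;
    import ai.djl.Model;
    import ai.djl.basicmodelzoo.basic.Mlp;
    import ai.djl.inference.Predictor;
    import ai.djl.ndarray.NDArray;
    import ai.djl.ndarray.NDList;
    import ai.djl.ndarray.NDManager;
    import ai.djl.ndarray.types.DataType;
    import ai.djl.ndarray.types.Shape;
    import ai.djl.nn.Block;
    import ai.djl.training.DefaultTrainingConfig;
    import ai.djl.training.EasyTrain;
    import ai.djl.training.Trainer;
    import ai.djl.training.dataset.ArrayDataset;
    import ai.djl.training.loss.Loss;
    import ai.djl.translate.NoopTranslator;
    import ai.djl.translate.TranslateException;
    import java.io.IOException;
    import java.nio.file.Files;
    import java.nio.file.Path;
    import java.nio.file.Paths;

    public class Main {
        public static void main(String[] args) throws IOException, TranslateException, MalformedModelException {
            System.setProperty("ai.djl.default_engine", "PyTorch");
            Block mlp = new Mlp(2, 1, new int[] {10});
            Model model = Model.newInstance("model");
            model.setBlock(mlp);
            Trainer trainer = model.newTrainer(new DefaultTrainingConfig(Loss.l2Loss()));
            trainer.initialize(new Shape(2));
            NDManager manager = model.getNDManager();
            NDArray input = manager.ones(new Shape(1, 2), DataType.FLOAT32);
            NDArray label = manager.create(new float[] {0.5f});
            ArrayDataset trainingDs = new ArrayDataset.Builder().setData(input)
                    .optLabels(label).setSampling(1, false).build();
            EasyTrain.fit(trainer, 100, trainingDs, trainingDs);

            Path dir = Paths.get("build/mlp");
            Files.createDirectories(dir);
            model.save(dir, "predictorAndTrainer");

            Model model2 = Model.newInstance("model");
            model2.setBlock(mlp);
            model2.load(dir, "predictorAndTrainer");

            Predictor<NDList, NDList> p2 = model2.newPredictor(new NoopTranslator());
            NDManager manager2 = NDManager.newBaseManager();
            NDList output = p2.predict(new NDList(manager2.ones(new Shape(1, 2))));
            System.out.println(output.get(0));
        }
    }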

The output is:

    ND: (1, 1) cpu() float32
    [[0.4958],
    ]

freemanliu commented 2 years ago

Hi, Frank,

Thanks a lot for that!

I added the block and it still does not work. Tracing through the code, I found the bug in ai.djl.util.Utils.getCurrentEpoch, at line 246:

    Files.walk(modelDir, 1)

This does not descend into modelDir when modelDir is a symbolic link. Adding FileVisitOption.FOLLOW_LINKS should fix it.
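The symlink behaviour described above can be reproduced with the JDK alone. The following standalone demo (not DJL code; the directory and file names are made up) creates a directory containing one file, points a symbolic link at that directory, and walks the link one level deep with and without FileVisitOption.FOLLOW_LINKS. Without the option, Files.walk reports only the link itself; with it, the walk enters the target. This is likely why the original test failed: on macOS, /tmp is itself a symbolic link to /private/tmp, whereas the non-symlinked build/mlp directory used in the Java test worked. Creating symbolic links may need extra privileges on Windows.

```java
import java.io.IOException;
import java.nio.file.FileVisitOption;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class WalkSymlinkDemo {
    // Returns the sorted file names seen by a depth-1 walk of dir, with the given options.
    static List<String> namesAtDepthOne(Path dir, FileVisitOption... options) throws IOException {
        try (Stream<Path> paths = Files.walk(dir, 1, options)) {
            return paths.map(p -> p.getFileName().toString())
                    .sorted()
                    .collect(Collectors.toList());
        }
    }

    public static void main(String[] args) throws IOException {
        Path root = Files.createTempDirectory("walk-demo");
        Path realDir = Files.createDirectory(root.resolve("real"));
        Files.createFile(realDir.resolve("model-0001.params"));
        Path link = Files.createSymbolicLink(root.resolve("link"), realDir);

        // Without FOLLOW_LINKS the symlinked start directory is not entered.
        System.out.println(namesAtDepthOne(link)); // prints [link]
        // With FOLLOW_LINKS the walk descends into the link target.
        System.out.println(namesAtDepthOne(link, FileVisitOption.FOLLOW_LINKS)); // prints [link, model-0001.params]
    }
}
```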

Cheers,

Freeman

On Thu, Jun 2, 2022 at 1:40 PM Frank Liu wrote:


frankfliu commented 2 years ago

@freemanliu

Since we only look one level into the directory, FileVisitOption.FOLLOW_LINKS should work here. Would you mind raising a PR to improve this?
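Applied to the epoch lookup, the suggestion amounts to passing the option to the same depth-1 walk. The sketch below is a hypothetical stand-in for Utils.getCurrentEpoch, not the code merged in #1692; it assumes parameter files are named <prefix>-NNNN.params with a four-digit epoch, and returns -1 when none are found (a choice made for this sketch). Because the walk never goes deeper than one level, following links only lets it enter the single directory the start path points to; links found among the children are reported but not descended into, so cycles cannot arise.

```java
import java.io.IOException;
import java.nio.file.FileVisitOption;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Stream;

public class EpochLookup {
    // Hypothetical helper: highest N among files named "<prefix>-NNNN.params"
    // directly inside modelDir, or -1 if there are none.
    static int latestEpoch(Path modelDir, String prefix) throws IOException {
        Pattern pattern = Pattern.compile(Pattern.quote(prefix) + "-(\\d{4})\\.params");
        // Depth 1 plus FOLLOW_LINKS: a symlinked modelDir (e.g. /tmp on macOS) is entered,
        // but nothing below its immediate children is visited.
        try (Stream<Path> paths = Files.walk(modelDir, 1, FileVisitOption.FOLLOW_LINKS)) {
            return paths.map(p -> pattern.matcher(p.getFileName().toString()))
                    .filter(Matcher::matches)
                    .mapToInt(m -> Integer.parseInt(m.group(1)))
                    .max()
                    .orElse(-1);
        }
    }

    public static void main(String[] args) throws IOException {
        Path dir = Files.createTempDirectory("epochs");
        Files.createFile(dir.resolve("predictorAndTrainer-0001.params"));
        Files.createFile(dir.resolve("predictorAndTrainer-0002.params"));
        Path link = Files.createSymbolicLink(dir.resolveSibling(dir.getFileName() + "-link"), dir);
        System.out.println(latestEpoch(link, "predictorAndTrainer")); // prints 2
        System.out.println(latestEpoch(link, "model")); // prints -1 (wrong prefix)
    }
}
```

The second call also mirrors the prefix mix-up from the original test: asking for "model" when the files were saved as "predictorAndTrainer" finds nothing.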

frankfliu commented 2 years ago

Fixed by #1692