Closed freemanliu closed 2 years ago
Please fix this ASAP. This is a deal breaker: I can't load the trained model, and I'll have to switch to another framework.
@freemanliu
Sorry for the delay.
When you train a model in DJL, the trainer saves only the model's parameters. The block information is not serialized in the model directory. To load such a model, you need to set the Block manually before you load the model:
model2.block = Mlp(2, 1, intArrayOf(10))
// The model prefix in your code was also wrong; it should be:
model2.load(Path.of("/tmp"), "predictorAndTrainer")
Hi, Frank,
Thanks for looking into this. I added the block, but the issue is the same. As I said before, the load and save implementations do not match. Are you able to make this piece of code work?
Cheers, Freeman
@freemanliu I tested your code in Java, and it works:
public static void main(String[] args) throws IOException, TranslateException, MalformedModelException {
    System.setProperty("ai.djl.default_engine", "PyTorch");

    // Train a small MLP on a single sample.
    Block mlp = new Mlp(2, 1, new int[] {10});
    Model model = Model.newInstance("model");
    model.setBlock(mlp);
    Trainer trainer = model.newTrainer(new DefaultTrainingConfig(Loss.l2Loss()));
    trainer.initialize(new Shape(2));
    NDManager manager = model.getNDManager();
    NDArray input = manager.ones(new Shape(1, 2), DataType.FLOAT32);
    NDArray label = manager.create(new float[] {0.5f});
    ArrayDataset trainingDs = new ArrayDataset.Builder().setData(input)
            .optLabels(label).setSampling(1, false).build();
    EasyTrain.fit(trainer, 100, trainingDs, trainingDs);

    // Save the trained parameters under the prefix "predictorAndTrainer".
    Path dir = Paths.get("build/mlp");
    Files.createDirectories(dir);
    model.save(dir, "predictorAndTrainer");

    // Set the block first, then load with the same prefix used for saving.
    Model model2 = Model.newInstance("model");
    model2.setBlock(mlp);
    model2.load(dir, "predictorAndTrainer");
    Predictor<NDList, NDList> p2 = model2.newPredictor(new NoopTranslator());
    NDManager manager2 = NDManager.newBaseManager();
    NDList output = p2.predict(new NDList(manager2.ones(new Shape(1, 2))));
    System.out.println(output.get(0));
}
The output is:
ND: (1, 1) cpu() float32
[[0.4958],
]
Hi, Frank,
Thanks a lot for that!
I added the block and it still does not work. Stepping through the code, I found the bug in ai.djl.util.Utils.getCurrentEpoch at line 246:
Files.walk(modelDir, 1)
This call does not descend into modelDir if it is a symlink. Adding FileVisitOption.FOLLOW_LINKS should fix it.
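The symlink behavior described above can be reproduced with the JDK alone. The following is a minimal sketch, not DJL code: the class name, helper names, and the parameter file name are made up for illustration. It builds a directory containing one file, points a symlink at it, and counts what Files.walk visits with and without FOLLOW_LINKS.

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.file.FileVisitOption;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.stream.Stream;

public class SymlinkWalkDemo {

    /** Counts the entries Files.walk visits (start path included) down to depth 1. */
    public static long countEntries(Path start, FileVisitOption... options) {
        try (Stream<Path> entries = Files.walk(start, 1, options)) {
            return entries.count();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /** Creates a directory holding one file and returns a symlink pointing at that directory. */
    public static Path createLinkedModelDir() {
        try {
            Path root = Files.createTempDirectory("walk-demo");
            Path realDir = Files.createDirectory(root.resolve("real"));
            // Hypothetical parameter file name, used only to give the directory some content.
            Files.createFile(realDir.resolve("predictorAndTrainer-0100.params"));
            return Files.createSymbolicLink(root.resolve("link"), realDir);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) {
        Path link = createLinkedModelDir();
        // Without FOLLOW_LINKS the symlink is treated as a plain entry and is not opened.
        System.out.println("without FOLLOW_LINKS: " + countEntries(link));
        // With FOLLOW_LINKS the walk resolves the link and lists the directory contents.
        System.out.println("with FOLLOW_LINKS:    " + countEntries(link, FileVisitOption.FOLLOW_LINKS));
    }
}
```

On a POSIX system the first count is 1 (only the link itself) and the second is 2 (the link plus the file inside it). This likely explains why a model directory such as /tmp fails on systems where /tmp is itself a symlink. Creating symlinks may require extra privileges on Windows.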
Cheers,
Freeman
@freemanliu
Since we only look one level into the directory, FileVisitOption.FOLLOW_LINKS should work here. Would you mind raising a PR to improve this?
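As a rough illustration of a one-level scan that follows links, here is a hedged sketch of an epoch lookup. It is not the code from Utils.getCurrentEpoch or from the eventual PR; the class and method names are hypothetical, and the "<prefix>-NNNN.params" file naming is an assumption made for this example.

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.file.FileVisitOption;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Stream;

public class EpochScanSketch {

    /**
     * Returns the highest epoch among files named "<prefix>-NNNN.params" directly
     * inside modelDir, or -1 if none match. The naming scheme is an assumption.
     */
    public static int latestEpoch(Path modelDir, String prefix) {
        Pattern pattern = Pattern.compile(Pattern.quote(prefix) + "-(\\d{4})\\.params");
        // Depth 1 keeps the scan shallow; FOLLOW_LINKS lets modelDir itself be a symlink.
        try (Stream<Path> entries = Files.walk(modelDir, 1, FileVisitOption.FOLLOW_LINKS)) {
            return entries
                    .filter(p -> p.getFileName() != null) // a filesystem root has no file name
                    .map(p -> pattern.matcher(p.getFileName().toString()))
                    .filter(Matcher::matches)
                    .mapToInt(m -> Integer.parseInt(m.group(1)))
                    .max()
                    .orElse(-1);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /** Creates sample checkpoint files and returns a symlink to the directory holding them. */
    public static Path createSampleDir() {
        try {
            Path root = Files.createTempDirectory("epoch-demo");
            Path realDir = Files.createDirectory(root.resolve("real"));
            Files.createFile(realDir.resolve("predictorAndTrainer-0001.params"));
            Files.createFile(realDir.resolve("predictorAndTrainer-0100.params"));
            Files.createFile(realDir.resolve("notes.txt"));
            return Files.createSymbolicLink(root.resolve("link"), realDir);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) {
        Path link = createSampleDir();
        System.out.println(latestEpoch(link, "predictorAndTrainer"));
    }
}
```

Because the depth stays at 1, following links here only affects whether the start directory is resolved and does not open the door to deep or cyclic traversals.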
Fixed by #1692
Description
This is reproducible with the following test case:
@Test
fun testModelLoad() {
    val model = Model.newInstance("model")
    model.block = Mlp(2, 1, intArrayOf(10))
    model.newTrainer(DefaultTrainingConfig(Loss.l2Loss())).use { trainer ->
        trainer.initialize(Shape(2))
        val manager = model.ndManager
        val input = manager.ones(Shape(1, 2), DataType.FLOAT32)
        val label = manager.create(floatArrayOf(0.5f))
        val trainingDs = ArrayDataset.Builder().setData(input)
            .optLabels(label).setSampling(1, false).build()
        EasyTrain.fit(trainer, 100, trainingDs, trainingDs)
        model.save(Path.of("/tmp"), "predictorAndTrainer")
    }
    val model2 = Model.newInstance("model")
    model2.load(Path.of("/tmp"), "model")
    val p2 = model2.newPredictor(NoopTranslator())
    NDManager.newBaseManager().use { manager ->
        println(p2.predict(NDList(manager.ones(Shape(1, 2)))))
    }
}
Here are the Gradle dependencies for the PyTorch engine:
implementation 'ai.djl:basicdataset:0.17.0'
implementation 'ai.djl:model-zoo:0.17.0'
implementation 'ai.djl.pytorch:pytorch-model-zoo:0.17.0'
Further investigation shows that save() is implemented in BaseModel, while load() is implemented in PtModel. I expected save() to be implemented in PtModel as well.
Expected Behavior
model.load succeeds.
Error Message
model.pt file not found in: /tmp
java.io.FileNotFoundException: model.pt file not found in: /tmp
    at ai.djl.pytorch.engine.PtModel.load(PtModel.java:74)
    at ai.djl.Model.load(Model.java:121)
    at helloworld.jdl.AppTest.testModelLoad(AppTest.kt:79)
    at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.base/java.lang.reflect.Method.invoke(Method.java:566)
    at org.junit.runners.model.FrameworkMethod$1.runReflectiveCall(FrameworkMethod.java:50)
    at org.junit.internal.runners.model.ReflectiveCallable.run(ReflectiveCallable.java:12)
    at org.junit.runners.model.FrameworkMethod.invokeExplosively(FrameworkMethod.java:47)
    at org.junit.internal.runners.statements.InvokeMethod.evaluate(InvokeMethod.java:17)
    at org.junit.runners.ParentRunner.runLeaf(ParentRunner.java:325)
    at org.junit.runners.BlockJUnit4ClassRunner.runChild(BlockJUnit4ClassRunner.java:78)
    at org.junit.runners.BlockJUnit4ClassRunner.runChild(BlockJUnit4ClassRunner.java:57)
    at org.junit.runners.ParentRunner$3.run(ParentRunner.java:290)
    at org.junit.runners.ParentRunner$1.schedule(ParentRunner.java:71)
    at org.junit.runners.ParentRunner.runChildren(ParentRunner.java:288)
    at org.junit.runners.ParentRunner.access$000(ParentRunner.java:58)
    at org.junit.runners.ParentRunner$2.evaluate(ParentRunner.java:268)
    at org.junit.runners.ParentRunner.run(ParentRunner.java:363)
    at org.gradle.api.internal.tasks.testing.junit.JUnitTestClassExecutor.runTestClass(JUnitTestClassExecutor.java:110)
    at org.gradle.api.internal.tasks.testing.junit.JUnitTestClassExecutor.execute(JUnitTestClassExecutor.java:58)
    at org.gradle.api.internal.tasks.testing.junit.JUnitTestClassExecutor.execute(JUnitTestClassExecutor.java:38)
    at org.gradle.api.internal.tasks.testing.junit.AbstractJUnitTestClassProcessor.processTestClass(AbstractJUnitTestClassProcessor.java:62)
    at org.gradle.api.internal.tasks.testing.SuiteTestClassProcessor.processTestClass(SuiteTestClassProcessor.java:51)
    at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.base/java.lang.reflect.Method.invoke(Method.java:566)
    at org.gradle.internal.dispatch.ReflectionDispatch.dispatch(ReflectionDispatch.java:36)
    at org.gradle.internal.dispatch.ReflectionDispatch.dispatch(ReflectionDispatch.java:24)
    at org.gradle.internal.dispatch.ContextClassLoaderDispatch.dispatch(ContextClassLoaderDispatch.java:33)
    at org.gradle.internal.dispatch.ProxyDispatchAdapter$DispatchingInvocationHandler.invoke(ProxyDispatchAdapter.java:94)
    at com.sun.proxy.$Proxy2.processTestClass(Unknown Source)
    at org.gradle.api.internal.tasks.testing.worker.TestWorker$2.run(TestWorker.java:176)
    at org.gradle.api.internal.tasks.testing.worker.TestWorker.executeAndMaintainThreadName(TestWorker.java:129)
    at org.gradle.api.internal.tasks.testing.worker.TestWorker.execute(TestWorker.java:100)
    at org.gradle.api.internal.tasks.testing.worker.TestWorker.execute(TestWorker.java:60)
    at org.gradle.process.internal.worker.child.ActionExecutionWorker.execute(ActionExecutionWorker.java:56)
    at org.gradle.process.internal.worker.child.SystemApplicationClassLoaderWorker.call(SystemApplicationClassLoaderWorker.java:133)
    at org.gradle.process.internal.worker.child.SystemApplicationClassLoaderWorker.call(SystemApplicationClassLoaderWorker.java:71)
    at worker.org.gradle.process.internal.worker.GradleWorkerMain.run(GradleWorkerMain.java:69)
    at worker.org.gradle.process.internal.worker.GradleWorkerMain.main(GradleWorkerMain.java:74)