Brucia323 closed this issue 1 year ago.
Please provide a test case where you think something is going wrong. That code looks ok to me.
p.nextToken() can return null when you reach the end of the input.
I get this error when I call EasyTrain.fit() from DJL. I'm not sure whether Jackson is to blame, but in debug mode I found that the exception is thrown from Jackson's code, so I may also need help from the DJL side.
My dataset is built from a JSON string like this:
// The complete code is available at the end
TablesawDataset dataset =
        TablesawDataset.builder()
                .setReadOptions(JsonReadOptions.builderFromString(json).build())
                .addNumericFeature("amount")
                .addNumericLabel("date")
                .setSampling(2, false)
                .build();
This is the part of the call stack, found in debug mode, where the exception is thrown:
_loadMore:276, ReaderBasedJsonParser (com.fasterxml.jackson.core.json)
_skipWSOrEnd:2519, ReaderBasedJsonParser (com.fasterxml.jackson.core.json)
nextToken:698, ReaderBasedJsonParser (com.fasterxml.jackson.core.json)
_readTreeAndClose:4759, ObjectMapper (com.fasterxml.jackson.databind)
readTree:3113, ObjectMapper (com.fasterxml.jackson.databind)
read:35, JsonReader (tech.tablesaw.io.json)
read:16, JsonReader (tech.tablesaw.io.json)
usingOptions:152, DataFrameReader (tech.tablesaw.io)
prepare:52, TablesawDataset (ai.djl.tablesaw)
prepare:63, Dataset (ai.djl.training.dataset)
getData:129, RandomAccessDataset (ai.djl.training.dataset)
getData:98, RandomAccessDataset (ai.djl.training.dataset)
iterateDataset:152, Trainer (ai.djl.training)
fit:54, EasyTrain (ai.djl.training)
lstm:82, DeepJavaLibrary (io.zcy.todo.djl)
lstm:17, DeepJavaLibraryTest (io.zcy.todo.djl)
I'm using Spring Boot 3.0.6, and here are all the dependencies:
implementation("org.springframework.boot:spring-boot-starter-actuator")
implementation("org.springframework.boot:spring-boot-starter-data-r2dbc")
implementation("org.springframework.boot:spring-boot-starter-validation")
implementation("org.springframework.boot:spring-boot-starter-webflux")
compileOnly("org.projectlombok:lombok")
developmentOnly("org.springframework.boot:spring-boot-devtools")
runtimeOnly("org.postgresql:postgresql")
runtimeOnly("org.postgresql:r2dbc-postgresql")
annotationProcessor("org.springframework.boot:spring-boot-configuration-processor")
annotationProcessor("org.projectlombok:lombok")
testImplementation("org.springframework.boot:spring-boot-starter-test")
testImplementation("io.projectreactor:reactor-test")
implementation("org.springframework.security:spring-security-crypto")
implementation("org.slf4j:slf4j-api")
implementation("com.auth0:java-jwt:4.3.0")
implementation("ai.djl.spring:djl-spring-boot-starter-mxnet-auto:0.20")
implementation("ai.djl.spring:djl-spring-boot-starter-autoconfigure:0.20")
implementation("ai.djl.tablesaw:tablesaw:0.22.1")
implementation("tech.tablesaw:tablesaw-core:0.43.1")
implementation("tech.tablesaw:tablesaw-json:0.43.1")
implementation("com.fasterxml.jackson.datatype:jackson-datatype-jsr310:2.15.0")
Here is my code:
package io.zcy.todo.djl;
import ai.djl.Model;
import ai.djl.engine.Engine;
import ai.djl.metric.Metrics;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.Blocks;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.core.Linear;
import ai.djl.nn.norm.BatchNorm;
import ai.djl.nn.recurrent.LSTM;
import ai.djl.tablesaw.TablesawDataset;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingResult;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.listener.SaveModelTrainingListener;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import com.fasterxml.jackson.databind.ObjectMapper;
import jakarta.annotation.Resource;
import java.io.IOException;
import java.util.*;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;
import tech.tablesaw.io.json.JsonReadOptions;
@Component
@Slf4j
public final class DeepJavaLibrary {
    public TrainingResult lstm() throws IOException, TranslateException {
        String json = "[{\"date\":\"2023-05-04\",\"amount\":30},{\"date\":\"2023-05-05\",\"amount\":40},{\"date\":\"2023-05-06\",\"amount\":90},{\"date\":\"2023-05-07\",\"amount\":80}]";
        TablesawDataset dataset =
                TablesawDataset.builder()
                        .setReadOptions(JsonReadOptions.builderFromString(json).build())
                        .addNumericFeature("amount")
                        .addNumericLabel("date")
                        .setSampling(2, false)
                        .build();
        dataset.prepare(new ProgressBar());
        try (Model model = Model.newInstance("lstm")) {
            model.setBlock(getLSTMModel());
            DefaultTrainingConfig config = setupTrainingConfig();
            try (Trainer trainer = model.newTrainer(config)) {
                trainer.setMetrics(new Metrics());
                Shape shape = new Shape(32, 1, dataset.size(), 2);
                trainer.initialize(shape);
                EasyTrain.fit(trainer, 1, dataset, dataset);
                return trainer.getTrainingResult();
            }
        }
    }
    private static Block getLSTMModel() {
        SequentialBlock block = new SequentialBlock();
        block.addSingleton(
                input -> {
                    Shape inputShape = input.getShape();
                    long batchSize = inputShape.get(0);
                    long channel = inputShape.get(3);
                    long time = inputShape.size() / (batchSize * channel);
                    return input.reshape(new Shape(batchSize, time, channel));
                });
        block.add(
                new LSTM.Builder()
                        .setStateSize(64)
                        .setNumLayers(1)
                        .optDropRate(0)
                        .optReturnState(false)
                        .build());
        block.add(BatchNorm.builder().optEpsilon(1e-5f).optMomentum(0.9f).build());
        block.add(Blocks.batchFlattenBlock());
        block.add(Linear.builder().setUnits(10).build());
        return block;
    }
    private static DefaultTrainingConfig setupTrainingConfig() {
        String outputDir = "/build/model";
        SaveModelTrainingListener listener = new SaveModelTrainingListener(outputDir);
        listener.setSaveModelCallback(
                trainer -> {
                    TrainingResult result = trainer.getTrainingResult();
                    Model model = trainer.getModel();
                    float accuracy = result.getValidateEvaluation("Accuracy");
                    model.setProperty("Accuracy", String.format("%.5f", accuracy));
                    model.setProperty("Loss", String.format("%.5f", result.getValidateLoss()));
                });
        return new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
                .addEvaluator(new Accuracy())
                .optDevices(Engine.getInstance().getDevices(1))
                .addTrainingListeners(TrainingListener.Defaults.logging(outputDir))
                .addTrainingListeners(listener);
    }
}
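A side note on the save-model callback: the suppressed NullPointerException in the console output says getValidateEvaluation("Accuracy") returned null, so assigning that Float to a primitive float auto-unboxes null and throws. The following is a minimal plain-Java sketch of a null-safe alternative; it does not use DJL, and the class NullSafeMetric, its format method, and the Map standing in for the evaluation results are all hypothetical illustrations.

```java
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;

/** Hypothetical helper: formats a possibly missing boxed metric without unboxing null. */
final class NullSafeMetric {

    /** Returns the value with five decimals, or "N/A" when the metric was never recorded. */
    static String format(Float value) {
        // Keep the boxed Float; "float f = value;" would throw an NPE when value is null.
        // Locale.ROOT keeps the decimal separator a '.' regardless of the JVM's locale.
        return value == null ? "N/A" : String.format(Locale.ROOT, "%.5f", value);
    }

    public static void main(String[] args) {
        // Stand-in for validation results that are empty because training failed early.
        Map<String, Float> evaluations = new HashMap<>();
        System.out.println(format(evaluations.get("Accuracy"))); // prints "N/A"
        evaluations.put("Accuracy", 0.75f);
        System.out.println(format(evaluations.get("Accuracy"))); // prints "0.75000"
    }
}
```

In the callback, the same null check would go around result.getValidateEvaluation("Accuracy") before the value is passed to model.setProperty(...).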
This is the error reported from the console:
java.io.IOException: Stream closed
tech.tablesaw.io.RuntimeIOException: java.io.IOException: Stream closed
at app//tech.tablesaw.io.json.JsonReader.read(JsonReader.java:37)
at app//tech.tablesaw.io.json.JsonReader.read(JsonReader.java:16)
at app//tech.tablesaw.io.DataFrameReader.usingOptions(DataFrameReader.java:152)
at app//ai.djl.tablesaw.TablesawDataset.prepare(TablesawDataset.java:52)
at app//ai.djl.training.dataset.Dataset.prepare(Dataset.java:63)
at app//ai.djl.training.dataset.RandomAccessDataset.getData(RandomAccessDataset.java:129)
at app//ai.djl.training.dataset.RandomAccessDataset.getData(RandomAccessDataset.java:98)
at app//ai.djl.training.Trainer.iterateDataset(Trainer.java:152)
at app//ai.djl.training.EasyTrain.fit(EasyTrain.java:54)
at app//io.zcy.todo.djl.DeepJavaLibrary.lstm(Unknown Source)
at app//io.zcy.todo.djl.DeepJavaLibraryTest.lstm(DeepJavaLibraryTest.java:17)
at java.base@17.0.6/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at java.base@17.0.6/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:77)
at java.base@17.0.6/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.base@17.0.6/java.lang.reflect.Method.invoke(Method.java:568)
at app//org.junit.platform.commons.util.ReflectionUtils.invokeMethod(ReflectionUtils.java:727)
at app//org.junit.jupiter.engine.execution.MethodInvocation.proceed(MethodInvocation.java:60)
at app//org.junit.jupiter.engine.execution.InvocationInterceptorChain$ValidatingInvocation.proceed(InvocationInterceptorChain.java:131)
at app//org.junit.jupiter.engine.extension.TimeoutExtension.intercept(TimeoutExtension.java:156)
at app//org.junit.jupiter.engine.extension.TimeoutExtension.interceptTestableMethod(TimeoutExtension.java:147)
at app//org.junit.jupiter.engine.extension.TimeoutExtension.interceptTestMethod(TimeoutExtension.java:86)
at app//org.junit.jupiter.engine.execution.InterceptingExecutableInvoker$ReflectiveInterceptorCall.lambda$ofVoidMethod$0(InterceptingExecutableInvoker.java:103)
at app//org.junit.jupiter.engine.execution.InterceptingExecutableInvoker.lambda$invoke$0(InterceptingExecutableInvoker.java:93)
at app//org.junit.jupiter.engine.execution.InvocationInterceptorChain$InterceptedInvocation.proceed(InvocationInterceptorChain.java:106)
at app//org.junit.jupiter.engine.execution.InvocationInterceptorChain.proceed(InvocationInterceptorChain.java:64)
at app//org.junit.jupiter.engine.execution.InvocationInterceptorChain.chainAndInvoke(InvocationInterceptorChain.java:45)
at app//org.junit.jupiter.engine.execution.InvocationInterceptorChain.invoke(InvocationInterceptorChain.java:37)
at app//org.junit.jupiter.engine.execution.InterceptingExecutableInvoker.invoke(InterceptingExecutableInvoker.java:92)
at app//org.junit.jupiter.engine.execution.InterceptingExecutableInvoker.invoke(InterceptingExecutableInvoker.java:86)
at app//org.junit.jupiter.engine.descriptor.TestMethodTestDescriptor.lambda$invokeTestMethod$7(TestMethodTestDescriptor.java:217)
at app//org.junit.platform.engine.support.hierarchical.ThrowableCollector.execute(ThrowableCollector.java:73)
at app//org.junit.jupiter.engine.descriptor.TestMethodTestDescriptor.invokeTestMethod(TestMethodTestDescriptor.java:213)
at app//org.junit.jupiter.engine.descriptor.TestMethodTestDescriptor.execute(TestMethodTestDescriptor.java:138)
at app//org.junit.jupiter.engine.descriptor.TestMethodTestDescriptor.execute(TestMethodTestDescriptor.java:68)
at app//org.junit.platform.engine.support.hierarchical.NodeTestTask.lambda$executeRecursively$6(NodeTestTask.java:151)
at app//org.junit.platform.engine.support.hierarchical.ThrowableCollector.execute(ThrowableCollector.java:73)
at app//org.junit.platform.engine.support.hierarchical.NodeTestTask.lambda$executeRecursively$8(NodeTestTask.java:141)
at app//org.junit.platform.engine.support.hierarchical.Node.around(Node.java:137)
at app//org.junit.platform.engine.support.hierarchical.NodeTestTask.lambda$executeRecursively$9(NodeTestTask.java:139)
at app//org.junit.platform.engine.support.hierarchical.ThrowableCollector.execute(ThrowableCollector.java:73)
at app//org.junit.platform.engine.support.hierarchical.NodeTestTask.executeRecursively(NodeTestTask.java:138)
at app//org.junit.platform.engine.support.hierarchical.NodeTestTask.execute(NodeTestTask.java:95)
at java.base@17.0.6/java.util.ArrayList.forEach(ArrayList.java:1511)
at app//org.junit.platform.engine.support.hierarchical.SameThreadHierarchicalTestExecutorService.invokeAll(SameThreadHierarchicalTestExecutorService.java:41)
at app//org.junit.platform.engine.support.hierarchical.NodeTestTask.lambda$executeRecursively$6(NodeTestTask.java:155)
at app//org.junit.platform.engine.support.hierarchical.ThrowableCollector.execute(ThrowableCollector.java:73)
at app//org.junit.platform.engine.support.hierarchical.NodeTestTask.lambda$executeRecursively$8(NodeTestTask.java:141)
at app//org.junit.platform.engine.support.hierarchical.Node.around(Node.java:137)
at app//org.junit.platform.engine.support.hierarchical.NodeTestTask.lambda$executeRecursively$9(NodeTestTask.java:139)
at app//org.junit.platform.engine.support.hierarchical.ThrowableCollector.execute(ThrowableCollector.java:73)
at app//org.junit.platform.engine.support.hierarchical.NodeTestTask.executeRecursively(NodeTestTask.java:138)
at app//org.junit.platform.engine.support.hierarchical.NodeTestTask.execute(NodeTestTask.java:95)
at java.base@17.0.6/java.util.ArrayList.forEach(ArrayList.java:1511)
at app//org.junit.platform.engine.support.hierarchical.SameThreadHierarchicalTestExecutorService.invokeAll(SameThreadHierarchicalTestExecutorService.java:41)
at app//org.junit.platform.engine.support.hierarchical.NodeTestTask.lambda$executeRecursively$6(NodeTestTask.java:155)
at app//org.junit.platform.engine.support.hierarchical.ThrowableCollector.execute(ThrowableCollector.java:73)
at app//org.junit.platform.engine.support.hierarchical.NodeTestTask.lambda$executeRecursively$8(NodeTestTask.java:141)
at app//org.junit.platform.engine.support.hierarchical.Node.around(Node.java:137)
at app//org.junit.platform.engine.support.hierarchical.NodeTestTask.lambda$executeRecursively$9(NodeTestTask.java:139)
at app//org.junit.platform.engine.support.hierarchical.ThrowableCollector.execute(ThrowableCollector.java:73)
at app//org.junit.platform.engine.support.hierarchical.NodeTestTask.executeRecursively(NodeTestTask.java:138)
at app//org.junit.platform.engine.support.hierarchical.NodeTestTask.execute(NodeTestTask.java:95)
at app//org.junit.platform.engine.support.hierarchical.SameThreadHierarchicalTestExecutorService.submit(SameThreadHierarchicalTestExecutorService.java:35)
at app//org.junit.platform.engine.support.hierarchical.HierarchicalTestExecutor.execute(HierarchicalTestExecutor.java:57)
at app//org.junit.platform.engine.support.hierarchical.HierarchicalTestEngine.execute(HierarchicalTestEngine.java:54)
at org.junit.platform.launcher.core.EngineExecutionOrchestrator.execute(EngineExecutionOrchestrator.java:107)
at org.junit.platform.launcher.core.EngineExecutionOrchestrator.execute(EngineExecutionOrchestrator.java:88)
at org.junit.platform.launcher.core.EngineExecutionOrchestrator.lambda$execute$0(EngineExecutionOrchestrator.java:54)
at org.junit.platform.launcher.core.EngineExecutionOrchestrator.withInterceptedStreams(EngineExecutionOrchestrator.java:67)
at org.junit.platform.launcher.core.EngineExecutionOrchestrator.execute(EngineExecutionOrchestrator.java:52)
at org.junit.platform.launcher.core.DefaultLauncher.execute(DefaultLauncher.java:114)
at org.junit.platform.launcher.core.DefaultLauncher.execute(DefaultLauncher.java:86)
at org.junit.platform.launcher.core.DefaultLauncherSession$DelegatingLauncher.execute(DefaultLauncherSession.java:86)
at org.junit.platform.launcher.core.SessionPerRequestLauncher.execute(SessionPerRequestLauncher.java:53)
at org.gradle.api.internal.tasks.testing.junitplatform.JUnitPlatformTestClassProcessor$CollectAllTestClassesExecutor.processAllTestClasses(JUnitPlatformTestClassProcessor.java:99)
at org.gradle.api.internal.tasks.testing.junitplatform.JUnitPlatformTestClassProcessor$CollectAllTestClassesExecutor.access$000(JUnitPlatformTestClassProcessor.java:79)
at org.gradle.api.internal.tasks.testing.junitplatform.JUnitPlatformTestClassProcessor.stop(JUnitPlatformTestClassProcessor.java:75)
at org.gradle.api.internal.tasks.testing.SuiteTestClassProcessor.stop(SuiteTestClassProcessor.java:62)
at java.base@17.0.6/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at java.base@17.0.6/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:77)
at java.base@17.0.6/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.base@17.0.6/java.lang.reflect.Method.invoke(Method.java:568)
at org.gradle.internal.dispatch.ReflectionDispatch.dispatch(ReflectionDispatch.java:36)
at org.gradle.internal.dispatch.ReflectionDispatch.dispatch(ReflectionDispatch.java:24)
at org.gradle.internal.dispatch.ContextClassLoaderDispatch.dispatch(ContextClassLoaderDispatch.java:33)
at org.gradle.internal.dispatch.ProxyDispatchAdapter$DispatchingInvocationHandler.invoke(ProxyDispatchAdapter.java:94)
at jdk.proxy1/jdk.proxy1.$Proxy2.stop(Unknown Source)
at org.gradle.api.internal.tasks.testing.worker.TestWorker$3.run(TestWorker.java:193)
at org.gradle.api.internal.tasks.testing.worker.TestWorker.executeAndMaintainThreadName(TestWorker.java:129)
at org.gradle.api.internal.tasks.testing.worker.TestWorker.execute(TestWorker.java:100)
at org.gradle.api.internal.tasks.testing.worker.TestWorker.execute(TestWorker.java:60)
at org.gradle.process.internal.worker.child.ActionExecutionWorker.execute(ActionExecutionWorker.java:56)
at org.gradle.process.internal.worker.child.SystemApplicationClassLoaderWorker.call(SystemApplicationClassLoaderWorker.java:113)
at org.gradle.process.internal.worker.child.SystemApplicationClassLoaderWorker.call(SystemApplicationClassLoaderWorker.java:65)
at app//worker.org.gradle.process.internal.worker.GradleWorkerMain.run(GradleWorkerMain.java:69)
at app//worker.org.gradle.process.internal.worker.GradleWorkerMain.main(GradleWorkerMain.java:74)
Suppressed: java.lang.NullPointerException: Cannot invoke "java.lang.Float.floatValue()" because the return value of "ai.djl.training.TrainingResult.getValidateEvaluation(String)" is null
at io.zcy.todo.djl.DeepJavaLibrary.lambda$setupTrainingConfig$6(DeepJavaLibrary.java:118)
at ai.djl.training.listener.SaveModelTrainingListener.saveModel(SaveModelTrainingListener.java:151)
at ai.djl.training.listener.SaveModelTrainingListener.onTrainingEnd(SaveModelTrainingListener.java:90)
at ai.djl.training.Trainer.lambda$close$2(Trainer.java:331)
at java.base/java.util.ArrayList.forEach(ArrayList.java:1511)
at ai.djl.training.Trainer.notifyListeners(Trainer.java:285)
at ai.djl.training.Trainer.close(Trainer.java:331)
at io.zcy.todo.djl.DeepJavaLibrary.lstm(Unknown Source)
... 86 more
Caused by: java.io.IOException: Stream closed
at java.base/java.io.StringReader.ensureOpen(StringReader.java:57)
at java.base/java.io.StringReader.read(StringReader.java:97)
at com.fasterxml.jackson.core.json.ReaderBasedJsonParser._loadMore(ReaderBasedJsonParser.java:276)
at com.fasterxml.jackson.core.json.ReaderBasedJsonParser._skipWSOrEnd(ReaderBasedJsonParser.java:2519)
at com.fasterxml.jackson.core.json.ReaderBasedJsonParser.nextToken(ReaderBasedJsonParser.java:698)
at com.fasterxml.jackson.databind.ObjectMapper._readTreeAndClose(ObjectMapper.java:4759)
at com.fasterxml.jackson.databind.ObjectMapper.readTree(ObjectMapper.java:3113)
at tech.tablesaw.io.json.JsonReader.read(JsonReader.java:35)
... 95 more
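The innermost Caused by frames show the read failing in java.io.StringReader.ensureOpen, which is what a StringReader does when it is read after close() has been called on it. Here is a minimal standard-library sketch of that behavior, independent of DJL, Tablesaw and Jackson; the class and method names are illustrative only, and the explicit close() merely stands in for whatever closed the reader in the real code.

```java
import java.io.IOException;
import java.io.StringReader;

/** Illustrative only: a StringReader cannot be read again after close(). */
final class StreamClosedDemo {

    /** Reads the remaining characters of the reader into a String. */
    static String readAll(StringReader reader) throws IOException {
        StringBuilder sb = new StringBuilder();
        int c;
        while ((c = reader.read()) != -1) {
            sb.append((char) c);
        }
        return sb.toString();
    }

    /** Returns the message of the IOException raised by re-reading a closed reader, or null if none. */
    static String readAfterClose(String json) {
        StringReader reader = new StringReader(json);
        try {
            readAll(reader);  // first read succeeds
            reader.close();   // stands in for a parser closing its input source
            readAll(reader);  // second read: ensureOpen() throws IOException
            return null;
        } catch (IOException e) {
            return e.getMessage();
        }
    }

    public static void main(String[] args) throws IOException {
        System.out.println(readAll(new StringReader("[{\"amount\":30}]")));
        System.out.println(readAfterClose("[{\"amount\":30}]")); // prints "Stream closed"
    }
}
```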
The Jackson team is not going to debug this for you. If you want to report a Jackson issue, we need a standalone test case.
This line doesn't even compile in Java because of the unescaped " characters:
String json = "[{"date":"2023-05-04","amount":30},{"date":"2023-05-05","amount":40},{"date":"2023-05-06","amount":90},{"date":"2023-05-07","amount":80}]";
Sorry, I think I see what is happening now: this bug is not Jackson's problem.
https://github.com/FasterXML/jackson-databind/blob/67103c2881aa506ebceb25824becac5b80a4f86a/src/main/java/com/fasterxml/jackson/databind/ObjectMapper.java#L4854
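Reading the two traces together suggests the following sequence. This is an inference from the frames, not something confirmed in the thread. The explicit dataset.prepare(new ProgressBar()) call reads the JSON once. The frame name ObjectMapper._readTreeAndClose indicates that the reader is closed once the tree has been read. EasyTrain.fit() then calls Dataset.prepare() again through RandomAccessDataset.getData(), and that second read hits the already-closed StringReader. Since the same dataset is also passed as the validation set, it may be prepared more than once even without the explicit call.

A common general remedy for a reader-backed source is to keep the raw text and open a new Reader for every read, or to make preparation run only once. The sketch below shows the first idea in plain Java. RereadableSource is a hypothetical class, not part of DJL, Tablesaw or Jackson.

```java
import java.io.IOException;
import java.io.Reader;
import java.io.StringReader;
import java.io.UncheckedIOException;
import java.util.function.Supplier;

/**
 * Hypothetical sketch, not a DJL or Tablesaw API: a source that can be read
 * any number of times because it opens a new Reader for every read.
 */
final class RereadableSource {
    private final Supplier<Reader> readerFactory;

    RereadableSource(String json) {
        // Store how to open the text, not one shared Reader instance.
        this.readerFactory = () -> new StringReader(json);
    }

    /** Reads the whole source through a fresh Reader, which is closed afterwards. */
    String read() {
        try (Reader reader = readerFactory.get()) {
            StringBuilder sb = new StringBuilder();
            char[] buf = new char[256];
            int n;
            while ((n = reader.read(buf)) != -1) {
                sb.append(buf, 0, n);
            }
            return sb.toString();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) {
        RereadableSource source = new RereadableSource("[{\"date\":\"2023-05-04\",\"amount\":30}]");
        // Closing the first reader does not affect the second read.
        System.out.println(source.read().equals(source.read())); // prints "true"
    }
}
```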