Closed androuino closed 2 years ago
@androuino If I understand correctly, you have multiple bounding boxes in a single image:
NDList l = new NDList(label.reshape(new Shape(1).addAll(label.getShape())));
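In your code, the number of labels = number of images * number of bounding boxes per image.
Here is what is in your dataset:
1st record -> 1st image: 1st bounding box of the 1st image
2nd record -> 2nd image: 2nd bounding box of the 1st image
So after the first record, the images and labels no longer line up.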
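In code, the idea is one record per image whose label is an (N, 5) array holding all N boxes of that image. Below is a minimal plain-Java sketch, assuming flat annotation rows of the form {imageName, class, four box values}; PerImageLabels is a hypothetical helper, not a DJL class.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical helper: builds one label array per image instead of one record per box. */
final class PerImageLabels {
    private PerImageLabels() {}

    /**
     * Groups rows of the form {imageName, class, v1, v2, v3, v4} so that each image
     * maps to a (numBoxes x 5) array whose rows are [class, v1, v2, v3, v4].
     * Insertion order of images is preserved.
     */
    static Map<String, float[][]> group(List<String[]> rows) {
        Map<String, List<float[]>> byImage = new LinkedHashMap<>();
        for (String[] row : rows) {
            float[] box = new float[5];
            for (int i = 0; i < 5; i++) {
                box[i] = Float.parseFloat(row[i + 1]); // skip the image name in row[0]
            }
            byImage.computeIfAbsent(row[0], k -> new ArrayList<>()).add(box);
        }
        Map<String, float[][]> result = new LinkedHashMap<>();
        for (Map.Entry<String, List<float[]>> e : byImage.entrySet()) {
            result.put(e.getKey(), e.getValue().toArray(new float[0][]));
        }
        return result;
    }
}
```

Each record i then pairs image i with label array i, and the one-line reshape quoted above turns the whole (N, 5) array into (1, N, 5) instead of wrapping a single box.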
Here is the code in the CocoDetection dataset, which is similar to what you need: https://github.com/awslabs/djl/blob/master/basicdataset/src/main/java/ai/djl/basicdataset/CocoDetection.java#L137-L156
Thanks @frankfliu for the response. However, running the CocoTest gives me this error:
java.lang.IllegalStateException: Expected BEGIN_OBJECT but was BEGIN_ARRAY at line 1 column 1369436 path $.annotations[0].bbox
com.google.gson.JsonSyntaxException: java.lang.IllegalStateException: Expected BEGIN_OBJECT but was BEGIN_ARRAY at line 1 column 1369436 path $.annotations[0].bbox
at com.google.gson.internal.bind.ReflectiveTypeAdapterFactory$Adapter.read(ReflectiveTypeAdapterFactory.java:226)
at com.google.gson.internal.bind.ReflectiveTypeAdapterFactory$1.read(ReflectiveTypeAdapterFactory.java:131)
at com.google.gson.internal.bind.ReflectiveTypeAdapterFactory$Adapter.read(ReflectiveTypeAdapterFactory.java:222)
at com.google.gson.internal.bind.TypeAdapterRuntimeTypeWrapper.read(TypeAdapterRuntimeTypeWrapper.java:41)
at com.google.gson.internal.bind.CollectionTypeAdapterFactory$Adapter.read(CollectionTypeAdapterFactory.java:82)
at com.google.gson.internal.bind.CollectionTypeAdapterFactory$Adapter.read(CollectionTypeAdapterFactory.java:61)
at com.google.gson.internal.bind.ReflectiveTypeAdapterFactory$1.read(ReflectiveTypeAdapterFactory.java:131)
at com.google.gson.internal.bind.ReflectiveTypeAdapterFactory$Adapter.read(ReflectiveTypeAdapterFactory.java:222)
at com.google.gson.Gson.fromJson(Gson.java:932)
at com.google.gson.Gson.fromJson(Gson.java:870)
at ai.djl.basicdataset.CocoUtils.prepare(CocoUtils.java:56)
at ai.djl.basicdataset.CocoDetection.prepare(CocoDetection.java:109)
at ai.djl.training.dataset.Dataset.prepare(Dataset.java:40)
at ai.djl.training.dataset.RandomAccessDataset.getData(RandomAccessDataset.java:83)
at ai.djl.training.Trainer.iterateDataset(Trainer.java:130)
at ai.djl.basicdataset.CocoTest.testCocoRemote(CocoTest.java:40)
at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.base/java.lang.reflect.Method.invoke(Method.java:566)
at org.testng.internal.MethodInvocationHelper.invokeMethod(MethodInvocationHelper.java:134)
at org.testng.internal.TestInvoker.invokeMethod(TestInvoker.java:597)
at org.testng.internal.TestInvoker.invokeTestMethod(TestInvoker.java:173)
at org.testng.internal.MethodRunner.runInSequence(MethodRunner.java:46)
at org.testng.internal.TestInvoker$MethodInvocationAgent.invoke(TestInvoker.java:816)
at org.testng.internal.TestInvoker.invokeTestMethods(TestInvoker.java:146)
at org.testng.internal.TestMethodWorker.invokeTestMethods(TestMethodWorker.java:146)
at org.testng.internal.TestMethodWorker.run(TestMethodWorker.java:128)
at java.base/java.util.ArrayList.forEach(ArrayList.java:1540)
at org.testng.TestRunner.privateRun(TestRunner.java:766)
at org.testng.TestRunner.run(TestRunner.java:587)
at org.testng.SuiteRunner.runTest(SuiteRunner.java:384)
at org.testng.SuiteRunner.runSequentially(SuiteRunner.java:378)
at org.testng.SuiteRunner.privateRun(SuiteRunner.java:337)
at org.testng.SuiteRunner.run(SuiteRunner.java:286)
at org.testng.SuiteRunnerWorker.runSuite(SuiteRunnerWorker.java:53)
at org.testng.SuiteRunnerWorker.run(SuiteRunnerWorker.java:96)
at org.testng.TestNG.runSuitesSequentially(TestNG.java:1187)
at org.testng.TestNG.runSuitesLocally(TestNG.java:1109)
at org.testng.TestNG.runSuites(TestNG.java:1039)
at org.testng.TestNG.run(TestNG.java:1007)
at org.gradle.api.internal.tasks.testing.testng.TestNGTestClassProcessor.runTests(TestNGTestClassProcessor.java:141)
at org.gradle.api.internal.tasks.testing.testng.TestNGTestClassProcessor.stop(TestNGTestClassProcessor.java:90)
at org.gradle.api.internal.tasks.testing.SuiteTestClassProcessor.stop(SuiteTestClassProcessor.java:61)
at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.base/java.lang.reflect.Method.invoke(Method.java:566)
at org.gradle.internal.dispatch.ReflectionDispatch.dispatch(ReflectionDispatch.java:36)
at org.gradle.internal.dispatch.ReflectionDispatch.dispatch(ReflectionDispatch.java:24)
at org.gradle.internal.dispatch.ContextClassLoaderDispatch.dispatch(ContextClassLoaderDispatch.java:33)
at org.gradle.internal.dispatch.ProxyDispatchAdapter$DispatchingInvocationHandler.invoke(ProxyDispatchAdapter.java:94)
at com.sun.proxy.$Proxy5.stop(Unknown Source)
at org.gradle.api.internal.tasks.testing.worker.TestWorker.stop(TestWorker.java:133)
at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.base/java.lang.reflect.Method.invoke(Method.java:566)
at org.gradle.internal.dispatch.ReflectionDispatch.dispatch(ReflectionDispatch.java:36)
at org.gradle.internal.dispatch.ReflectionDispatch.dispatch(ReflectionDispatch.java:24)
at org.gradle.internal.remote.internal.hub.MessageHubBackedObjectConnection$DispatchWrapper.dispatch(MessageHubBackedObjectConnection.java:182)
at org.gradle.internal.remote.internal.hub.MessageHubBackedObjectConnection$DispatchWrapper.dispatch(MessageHubBackedObjectConnection.java:164)
at org.gradle.internal.remote.internal.hub.MessageHub$Handler.run(MessageHub.java:414)
at org.gradle.internal.concurrent.ExecutorPolicy$CatchAndRecordFailures.onExecute(ExecutorPolicy.java:64)
at org.gradle.internal.concurrent.ManagedExecutorImpl$1.run(ManagedExecutorImpl.java:48)
at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128)
at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628)
at org.gradle.internal.concurrent.ThreadFactoryImpl$ManagedThreadRunnable.run(ThreadFactoryImpl.java:56)
at java.base/java.lang.Thread.run(Thread.java:834)
Caused by: java.lang.IllegalStateException: Expected BEGIN_OBJECT but was BEGIN_ARRAY at line 1 column 1369436 path $.annotations[0].bbox
at com.google.gson.stream.JsonReader.beginObject(JsonReader.java:386)
at com.google.gson.internal.bind.ReflectiveTypeAdapterFactory$Adapter.read(ReflectiveTypeAdapterFactory.java:215)
... 68 more
I did not change any code in the CocoDetection class; I only downloaded the COCO dataset to run the test.
Regarding your pointer to the CocoDetection code (https://github.com/awslabs/djl/blob/master/basicdataset/src/main/java/ai/djl/basicdataset/CocoDetection.java#L137-L156): I checked the actual CocoDetection class and modified my code accordingly, so it is now similar to CocoDetection:
usagePath = root.resolve(usagePath);
Path indexFile = usagePath.resolve("index.file");
try (Reader reader = Files.newBufferedReader(indexFile)) {
Type mapType = new TypeToken<Map<String, List<String[]>>>() {}.getType();
Map<String, List<String[]>> metadata = JsonUtils.GSON.fromJson(reader, mapType);
for (Map.Entry<String, List<String[]>> entry : metadata.entrySet()) {
List<double[]> labelOfImage = getLabels(entry.getValue());
if (!labelOfImage.isEmpty()) {
imagePaths.add(usagePath.resolve(entry.getKey()));
labels.add(labelOfImage.toArray(new double[0][]));
}
}
}
Then these are the helper functions, which are similar to the ones in the CocoDetection class:
private double[] convertRecToList(String[] anno) {
double[] list = new double[5];
list[1] = Double.parseDouble(anno[1]);
list[2] = Double.parseDouble(anno[2]);
list[3] = Double.parseDouble(anno[3]);
list[4] = Double.parseDouble(anno[4]);
return list;
}
private List<double[]> getLabels(List<String[]> arr) {
List<double[]> label = new ArrayList<>();
for (String[] item : arr) {
double[] list = convertRecToList(item);
// add the category label
// map the original one to incremental index
list[0] = Double.parseDouble(item[0]);
label.add(list);
}
return label;
}
But I am getting this error:
MXNet engine call failed: MXNetError: Check failed: in_type->at(i) == mshadow::default_type_flag || in_type->at(i) == -1: Unsupported data type 1
Stack trace:
File "../include/mxnet/operator.h", line 228
ai.djl.engine.EngineException: MXNet engine call failed: MXNetError: Check failed: in_type->at(i) == mshadow::default_type_flag || in_type->at(i) == -1: Unsupported data type 1
Stack trace:
File "../include/mxnet/operator.h", line 228
at ai.djl.mxnet.jna.JnaUtils.checkCall(JnaUtils.java:1808)
at ai.djl.mxnet.jna.JnaUtils.imperativeInvoke(JnaUtils.java:502)
at ai.djl.mxnet.jna.FunctionInfo.invoke(FunctionInfo.java:91)
at ai.djl.mxnet.jna.FunctionInfo.invoke(FunctionInfo.java:75)
at ai.djl.mxnet.engine.MxNDManager.invoke(MxNDManager.java:288)
at ai.djl.mxnet.engine.MxNDArrayEx.multiBoxTarget(MxNDArrayEx.java:934)
at ai.djl.modality.cv.MultiBoxTarget.target(MultiBoxTarget.java:74)
at ai.djl.training.loss.SingleShotDetectionLoss.inputForComponent(SingleShotDetectionLoss.java:53)
at ai.djl.training.loss.AbstractCompositeLoss.evaluate(AbstractCompositeLoss.java:66)
at ai.djl.training.EasyTrain.trainBatch(EasyTrain.java:82)
at ai.djl.training.EasyTrain.fit(EasyTrain.java:45)
at ai.djl.examples.training.TrainCustomModel.runTraining(TrainCustomModel.java:77)
at ai.djl.examples.training.TrainCustomModelTest.testTrainingCustomModel(TrainCustomModelTest.java:31)
at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.base/java.lang.reflect.Method.invoke(Method.java:566)
at org.testng.internal.MethodInvocationHelper.invokeMethod(MethodInvocationHelper.java:134)
at org.testng.internal.TestInvoker.invokeMethod(TestInvoker.java:597)
at org.testng.internal.TestInvoker.invokeTestMethod(TestInvoker.java:173)
at org.testng.internal.MethodRunner.runInSequence(MethodRunner.java:46)
at org.testng.internal.TestInvoker$MethodInvocationAgent.invoke(TestInvoker.java:816)
at org.testng.internal.TestInvoker.invokeTestMethods(TestInvoker.java:146)
at org.testng.internal.TestMethodWorker.invokeTestMethods(TestMethodWorker.java:146)
at org.testng.internal.TestMethodWorker.run(TestMethodWorker.java:128)
at java.base/java.util.ArrayList.forEach(ArrayList.java:1540)
at org.testng.TestRunner.privateRun(TestRunner.java:766)
at org.testng.TestRunner.run(TestRunner.java:587)
at org.testng.SuiteRunner.runTest(SuiteRunner.java:384)
at org.testng.SuiteRunner.runSequentially(SuiteRunner.java:378)
at org.testng.SuiteRunner.privateRun(SuiteRunner.java:337)
at org.testng.SuiteRunner.run(SuiteRunner.java:286)
at org.testng.SuiteRunnerWorker.runSuite(SuiteRunnerWorker.java:53)
at org.testng.SuiteRunnerWorker.run(SuiteRunnerWorker.java:96)
at org.testng.TestNG.runSuitesSequentially(TestNG.java:1187)
at org.testng.TestNG.runSuitesLocally(TestNG.java:1109)
at org.testng.TestNG.runSuites(TestNG.java:1039)
at org.testng.TestNG.run(TestNG.java:1007)
at org.gradle.api.internal.tasks.testing.testng.TestNGTestClassProcessor.runTests(TestNGTestClassProcessor.java:141)
at org.gradle.api.internal.tasks.testing.testng.TestNGTestClassProcessor.stop(TestNGTestClassProcessor.java:90)
at org.gradle.api.internal.tasks.testing.SuiteTestClassProcessor.stop(SuiteTestClassProcessor.java:61)
at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.base/java.lang.reflect.Method.invoke(Method.java:566)
at org.gradle.internal.dispatch.ReflectionDispatch.dispatch(ReflectionDispatch.java:36)
at org.gradle.internal.dispatch.ReflectionDispatch.dispatch(ReflectionDispatch.java:24)
at org.gradle.internal.dispatch.ContextClassLoaderDispatch.dispatch(ContextClassLoaderDispatch.java:33)
at org.gradle.internal.dispatch.ProxyDispatchAdapter$DispatchingInvocationHandler.invoke(ProxyDispatchAdapter.java:94)
at com.sun.proxy.$Proxy5.stop(Unknown Source)
at org.gradle.api.internal.tasks.testing.worker.TestWorker.stop(TestWorker.java:133)
at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.base/java.lang.reflect.Method.invoke(Method.java:566)
at org.gradle.internal.dispatch.ReflectionDispatch.dispatch(ReflectionDispatch.java:36)
at org.gradle.internal.dispatch.ReflectionDispatch.dispatch(ReflectionDispatch.java:24)
at org.gradle.internal.remote.internal.hub.MessageHubBackedObjectConnection$DispatchWrapper.dispatch(MessageHubBackedObjectConnection.java:182)
at org.gradle.internal.remote.internal.hub.MessageHubBackedObjectConnection$DispatchWrapper.dispatch(MessageHubBackedObjectConnection.java:164)
at org.gradle.internal.remote.internal.hub.MessageHub$Handler.run(MessageHub.java:414)
at org.gradle.internal.concurrent.ExecutorPolicy$CatchAndRecordFailures.onExecute(ExecutorPolicy.java:64)
at org.gradle.internal.concurrent.ManagedExecutorImpl$1.run(ManagedExecutorImpl.java:48)
at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128)
at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628)
at org.gradle.internal.concurrent.ThreadFactoryImpl$ManagedThreadRunnable.run(ThreadFactoryImpl.java:56)
at java.base/java.lang.Thread.run(Thread.java:834)
Sorry, I may not have understood what you were suggesting in your last comment, but please help me get past this error. Thank you in advance.
@androuino I confirmed that CocoTest is failing; we will take a look and fix the CocoDetection dataset bug.
Training SSD on COCO is not as straightforward as on Pikachu; we will try to create an example.
Thanks for the response @frankfliu. I confirmed that CocoDetection has been fixed. However, have you taken a look at the error I am getting?
MXNet engine call failed: MXNetError: Check failed: in_type->at(i) == mshadow::default_type_flag || in_type->at(i) == -1: Unsupported data type 1
Stack trace:
File "../include/mxnet/operator.h", line 228
So, for the moment, is training on a custom dataset with multiple bounding boxes and classes annotated in one image simply not working, or not yet supported, when following the Pikachu example?
@androuino It looks like you are passing data to an operator that doesn't support the float64 data type (in MXNet's mshadow type flags, data type 1 is float64). Do you know where you use float64?
Hi @stu1130, this is the class I made, which is a combination of the PikachuDetection and CocoDetection classes: https://gist.github.com/androuino/00095ba5be3d10cab765bd2447d236cf
And this is my TrainCustomModel class: https://gist.github.com/androuino/6e8e014e3e70b107f4ebb9cf6c9387ac
I don't see anywhere that I am using float64. My dataset structure is very similar to the Pikachu dataset, and so are the annotation values:
{"IMG_5237.jpg":[["0","0.085625","0.07473776223776224","0.043750000000000004","0.030594405594405596"],
["0","0.35125","0.28496503496503495","0.055","0.04020979020979021"],
["1","0.26875","0.3618881118881119","0.0575","0.033216783216783216"],
["0","0.7975","0.4602272727272727","0.0575","0.03583916083916084"],
["0","0.08750000000000001","0.6831293706293706","0.065","0.039335664335664336"],
["2","0.114375","0.5607517482517482","0.07875","0.028846153846153848"],
["2","0.115","0.40384615384615385","0.085","0.027972027972027972"],
["2","0.1125","0.3395979020979021","0.08","0.028846153846153848"],
["2","0.11375","0.23251748251748253","0.085","0.02972027972027972"],
["2","0.106875","0.04020979020979021","0.08125","0.026223776223776224"]], ...}
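A note on the sample rows above: in each row the third and fourth numbers are much smaller than the first two, which suggests (an assumption; the thread does not state the format) a box width and height rather than corner coordinates. The Pikachu labels, by contrast, are corner coordinates [class, xmin, ymin, xmax, ymax] (in the Pikachu sample later in this thread, 0.6047 < 0.6948 and 0.4020 < 0.5354), and MultiBoxTarget, as far as I know, expects that corner form. If these rows are centre-based (YOLO style), a conversion like this hypothetical helper would be needed when building the labels:

```java
/**
 * Hypothetical converter, assuming each row is [class, xCenter, yCenter, width, height]
 * with coordinates normalized to [0, 1]. Returns [class, xmin, ymin, xmax, ymax].
 */
final class BoxFormat {
    private BoxFormat() {}

    static float[] centerToCorners(float[] row) {
        float cls = row[0];
        float xc = row[1];
        float yc = row[2];
        float halfW = row[3] / 2f;
        float halfH = row[4] / 2f;
        // Corners are the centre shifted by half the box size in each direction.
        return new float[] {cls, xc - halfW, yc - halfH, xc + halfW, yc + halfH};
    }
}
```

If the values are top-left-based instead, the conversion is simply xmin = x, ymin = y, xmax = x + width, ymax = y + height.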
@androuino Do you have the complete error stack trace? The labels in the Pikachu dataset are float, but the labels in CocoDetection are double. I suspect this is where you use float64. Could you cast the label to float32 with the toType method?
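Besides calling toType on the NDArray, the cast can also be done in plain Java before the label array is created, so the values reaching the NDArray are already float (float32) rather than double (float64, the "data type 1" MXNet rejects). A minimal sketch; LabelPrecision is a hypothetical helper:

```java
/** Hypothetical helper: narrows double labels to float before they are turned into an NDArray. */
final class LabelPrecision {
    private LabelPrecision() {}

    static float[][] toFloat(double[][] labels) {
        float[][] out = new float[labels.length][];
        for (int i = 0; i < labels.length; i++) {
            out[i] = new float[labels[i].length];
            for (int j = 0; j < labels[i].length; j++) {
                out[i][j] = (float) labels[i][j]; // narrowing cast; rounds to nearest float
            }
        }
        return out;
    }
}
```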
@stu1130 So far this is the only stack trace that I could get:
MXNet engine call failed: MXNetError: Check failed: in_type->at(i) == mshadow::default_type_flag || in_type->at(i) == -1: Unsupported data type 1
Stack trace:
File "../include/mxnet/operator.h", line 228
ai.djl.engine.EngineException: MXNet engine call failed: MXNetError: Check failed: in_type->at(i) == mshadow::default_type_flag || in_type->at(i) == -1: Unsupported data type 1
Stack trace:
File "../include/mxnet/operator.h", line 228
at ai.djl.mxnet.jna.JnaUtils.checkCall(JnaUtils.java:1808)
at ai.djl.mxnet.jna.JnaUtils.imperativeInvoke(JnaUtils.java:502)
at ai.djl.mxnet.jna.FunctionInfo.invoke(FunctionInfo.java:91)
at ai.djl.mxnet.jna.FunctionInfo.invoke(FunctionInfo.java:75)
at ai.djl.mxnet.engine.MxNDManager.invoke(MxNDManager.java:288)
at ai.djl.mxnet.engine.MxNDArrayEx.multiBoxTarget(MxNDArrayEx.java:934)
at ai.djl.modality.cv.MultiBoxTarget.target(MultiBoxTarget.java:74)
at ai.djl.training.loss.SingleShotDetectionLoss.inputForComponent(SingleShotDetectionLoss.java:53)
at ai.djl.training.loss.AbstractCompositeLoss.evaluate(AbstractCompositeLoss.java:66)
at ai.djl.training.EasyTrain.trainBatch(EasyTrain.java:82)
at ai.djl.training.EasyTrain.fit(EasyTrain.java:45)
at ai.djl.examples.training.TrainCustomModel.runTraining(TrainCustomModel.java:77)
at ai.djl.examples.training.TrainCustomModelTest.testTrainingCustomModel(TrainCustomModelTest.java:31)
... (remaining TestNG and Gradle frames are the same as in the trace above)
I will try to cast the label to float32 and get back to you. Thanks.
@stu1130 I tried your suggestion, and now I am getting the error below. This is what the functions look like after casting the label to float:
private float[] convertRecToFloatList(String[] anno) {
float[] list = new float[5];
list[1] = Float.parseFloat(anno[1]);
list[2] = Float.parseFloat(anno[2]);
list[3] = Float.parseFloat(anno[3]);
list[4] = Float.parseFloat(anno[4]);
return list;
}
private List<float[]> getLabelsAsFloat(List<String[]> arr) {
List<float[]> label = new ArrayList<>();
for (String[] item : arr) {
float[] box = convertRecToFloatList(item);
logger.info(Arrays.toString(box));
// add the category label
// map the original one to incremental index
float[] list = new float[5];
System.arraycopy(box, 1, list, 1, 4);
list[0] = Float.parseFloat(item[0]);
label.add(list);
}
return label;
}
Stack trace:
MXNet engine call failed: TBlob.get_with_shape: Check failed: this->shape_.Size() == static_cast<size_t>(shape.Size()) (65 vs. 285) : new and old shape do not match total elements
Stack trace:
File "../include/mxnet/./tensor_blob.h", line 311
ai.djl.engine.EngineException: MXNet engine call failed: TBlob.get_with_shape: Check failed: this->shape_.Size() == static_cast<size_t>(shape.Size()) (65 vs. 285) : new and old shape do not match total elements
Stack trace:
File "../include/mxnet/./tensor_blob.h", line 311
at ai.djl.mxnet.jna.JnaUtils.checkCall(JnaUtils.java:1808)
at ai.djl.mxnet.jna.JnaUtils.syncCopyToCPU(JnaUtils.java:475)
at ai.djl.mxnet.engine.MxNDArray.toByteBuffer(MxNDArray.java:280)
at ai.djl.ndarray.NDArray.toLongArray(NDArray.java:300)
at ai.djl.ndarray.NDArray.getLong(NDArray.java:558)
at ai.djl.training.evaluator.AbstractAccuracy.lambda$updateAccumulator$1(AbstractAccuracy.java:85)
at java.base/java.util.concurrent.ConcurrentHashMap.compute(ConcurrentHashMap.java:1932)
at ai.djl.training.evaluator.AbstractAccuracy.updateAccumulator(AbstractAccuracy.java:85)
at ai.djl.training.listener.EvaluatorTrainingListener.updateEvaluators(EvaluatorTrainingListener.java:147)
at ai.djl.training.listener.EvaluatorTrainingListener.onTrainingBatch(EvaluatorTrainingListener.java:114)
at ai.djl.training.EasyTrain.lambda$trainBatch$1(EasyTrain.java:92)
at java.base/java.util.ArrayList.forEach(ArrayList.java:1540)
at ai.djl.training.Trainer.notifyListeners(Trainer.java:263)
at ai.djl.training.EasyTrain.trainBatch(EasyTrain.java:92)
at ai.djl.training.EasyTrain.fit(EasyTrain.java:45)
at ai.djl.examples.training.TrainCustomModel.runTraining(TrainCustomModel.java:81)
at ai.djl.examples.training.TrainCustomModelTest.testTrainingCustomModel(TrainCustomModelTest.java:31)
at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.base/java.lang.reflect.Method.invoke(Method.java:566)
at org.testng.internal.MethodInvocationHelper.invokeMethod(MethodInvocationHelper.java:134)
at org.testng.internal.TestInvoker.invokeMethod(TestInvoker.java:597)
at org.testng.internal.TestInvoker.invokeTestMethod(TestInvoker.java:173)
at org.testng.internal.MethodRunner.runInSequence(MethodRunner.java:46)
at org.testng.internal.TestInvoker$MethodInvocationAgent.invoke(TestInvoker.java:816)
at org.testng.internal.TestInvoker.invokeTestMethods(TestInvoker.java:146)
at org.testng.internal.TestMethodWorker.invokeTestMethods(TestMethodWorker.java:146)
at org.testng.internal.TestMethodWorker.run(TestMethodWorker.java:128)
at java.base/java.util.ArrayList.forEach(ArrayList.java:1540)
at org.testng.TestRunner.privateRun(TestRunner.java:766)
at org.testng.TestRunner.run(TestRunner.java:587)
at org.testng.SuiteRunner.runTest(SuiteRunner.java:384)
at org.testng.SuiteRunner.runSequentially(SuiteRunner.java:378)
at org.testng.SuiteRunner.privateRun(SuiteRunner.java:337)
at org.testng.SuiteRunner.run(SuiteRunner.java:286)
at org.testng.SuiteRunnerWorker.runSuite(SuiteRunnerWorker.java:53)
at org.testng.SuiteRunnerWorker.run(SuiteRunnerWorker.java:96)
at org.testng.TestNG.runSuitesSequentially(TestNG.java:1187)
at org.testng.TestNG.runSuitesLocally(TestNG.java:1109)
at org.testng.TestNG.runSuites(TestNG.java:1039)
at org.testng.TestNG.run(TestNG.java:1007)
at org.gradle.api.internal.tasks.testing.testng.TestNGTestClassProcessor.runTests(TestNGTestClassProcessor.java:141)
at org.gradle.api.internal.tasks.testing.testng.TestNGTestClassProcessor.stop(TestNGTestClassProcessor.java:90)
at org.gradle.api.internal.tasks.testing.SuiteTestClassProcessor.stop(SuiteTestClassProcessor.java:61)
at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.base/java.lang.reflect.Method.invoke(Method.java:566)
at org.gradle.internal.dispatch.ReflectionDispatch.dispatch(ReflectionDispatch.java:36)
at org.gradle.internal.dispatch.ReflectionDispatch.dispatch(ReflectionDispatch.java:24)
at org.gradle.internal.dispatch.ContextClassLoaderDispatch.dispatch(ContextClassLoaderDispatch.java:33)
at org.gradle.internal.dispatch.ProxyDispatchAdapter$DispatchingInvocationHandler.invoke(ProxyDispatchAdapter.java:94)
at com.sun.proxy.$Proxy5.stop(Unknown Source)
at org.gradle.api.internal.tasks.testing.worker.TestWorker.stop(TestWorker.java:133)
at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.base/java.lang.reflect.Method.invoke(Method.java:566)
at org.gradle.internal.dispatch.ReflectionDispatch.dispatch(ReflectionDispatch.java:36)
at org.gradle.internal.dispatch.ReflectionDispatch.dispatch(ReflectionDispatch.java:24)
at org.gradle.internal.remote.internal.hub.MessageHubBackedObjectConnection$DispatchWrapper.dispatch(MessageHubBackedObjectConnection.java:182)
at org.gradle.internal.remote.internal.hub.MessageHubBackedObjectConnection$DispatchWrapper.dispatch(MessageHubBackedObjectConnection.java:164)
at org.gradle.internal.remote.internal.hub.MessageHub$Handler.run(MessageHub.java:414)
at org.gradle.internal.concurrent.ExecutorPolicy$CatchAndRecordFailures.onExecute(ExecutorPolicy.java:64)
at org.gradle.internal.concurrent.ManagedExecutorImpl$1.run(ManagedExecutorImpl.java:48)
at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128)
at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628)
at org.gradle.internal.concurrent.ThreadFactoryImpl$ManagedThreadRunnable.run(ThreadFactoryImpl.java:56)
at java.base/java.lang.Thread.run(Thread.java:834)
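The numbers in this shape mismatch are both multiples of five: 65 = 13 * 5 and 285 = 57 * 5, which is consistent with label arrays of five-value boxes whose box counts differ between images. That would also fit the training running once the batch size is 1, since images with different box counts cannot be stacked into one batch as-is. A common workaround (an assumption here, not something confirmed in this thread; GluonCV's SSD script pads validation labels with pad_val=-1) is to pad every image's labels to a fixed number of rows filled with -1, a class id that, to my understanding, MultiBoxTarget treats as padding. A hypothetical sketch:

```java
import java.util.Arrays;

/**
 * Hypothetical helper: pads an image's (numBoxes x 5) label array to maxBoxes rows,
 * filling the extra rows with -1 so that every image in a batch has the same label shape.
 */
final class LabelPadding {
    private LabelPadding() {}

    static float[][] pad(float[][] labels, int maxBoxes) {
        if (labels.length > maxBoxes) {
            throw new IllegalArgumentException(
                    "image has " + labels.length + " boxes, more than maxBoxes=" + maxBoxes);
        }
        float[][] padded = new float[maxBoxes][];
        for (int i = 0; i < maxBoxes; i++) {
            if (i < labels.length) {
                padded[i] = labels[i].clone(); // keep real boxes unchanged
            } else {
                float[] filler = new float[5];
                Arrays.fill(filler, -1f); // padding row: class -1 and dummy coordinates
                padded[i] = filler;
            }
        }
        return padded;
    }
}
```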
@stu1130 I have an update on the training attempt: it runs when I change the arguments like this:
args = new String[] {"-e", "8", "-m", "1", "-b", "1"};
Training runs; however, the boundingBoxError for both training and validation doesn't seem to improve. These are my results over 8 epochs of training:
[INFO ] - Load MXNet Engine Version 1.7.0 in 0.227 ms.
[INFO ] - Epoch 1 finished.
[INFO ] - Train: classAccuracy: 0.33, boundingBoxError: 0.00E+00, SingleShotDetectionLoss: 1.68
[INFO ] - Validate: classAccuracy: 0.78, boundingBoxError: 0.00E+00, SingleShotDetectionLoss: 0.90
[INFO ] - Epoch 2 finished.
[INFO ] - Train: classAccuracy: 0.61, boundingBoxError: 0.00E+00, SingleShotDetectionLoss: 0.90
[INFO ] - Validate: classAccuracy: 0.87, boundingBoxError: 0.00E+00, SingleShotDetectionLoss: 0.81
[INFO ] - Epoch 3 finished.
[INFO ] - Train: classAccuracy: 0.86, boundingBoxError: 0.00E+00, SingleShotDetectionLoss: 0.50
[INFO ] - Validate: classAccuracy: 0.93, boundingBoxError: 0.00E+00, SingleShotDetectionLoss: 0.77
[INFO ] - Epoch 4 finished.
[INFO ] - Train: classAccuracy: 0.94, boundingBoxError: 0.00E+00, SingleShotDetectionLoss: 0.29
[INFO ] - Validate: classAccuracy: 0.96, boundingBoxError: 0.00E+00, SingleShotDetectionLoss: 0.72
[INFO ] - Epoch 5 finished.
[INFO ] - Train: classAccuracy: 0.99, boundingBoxError: 0.00E+00, SingleShotDetectionLoss: 0.17
[INFO ] - Validate: classAccuracy: 0.99, boundingBoxError: 0.00E+00, SingleShotDetectionLoss: 0.66
[INFO ] - Epoch 6 finished.
[INFO ] - Train: classAccuracy: 1.00, boundingBoxError: 0.00E+00, SingleShotDetectionLoss: 0.09
[INFO ] - Validate: classAccuracy: 1.00, boundingBoxError: 0.00E+00, SingleShotDetectionLoss: 0.58
[INFO ] - Epoch 7 finished.
[INFO ] - Train: classAccuracy: 1.00, boundingBoxError: 0.00E+00, SingleShotDetectionLoss: 0.06
[INFO ] - Validate: classAccuracy: 1.00, boundingBoxError: 0.00E+00, SingleShotDetectionLoss: 0.49
[INFO ] - Epoch 8 finished.
[INFO ] - Train: classAccuracy: 1.00, boundingBoxError: 0.00E+00, SingleShotDetectionLoss: 0.03
[INFO ] - Validate: classAccuracy: 1.00, boundingBoxError: 0.00E+00, SingleShotDetectionLoss: 0.38
[INFO ] - forward P50: 849.174 ms, P90: 878.440 ms
[INFO ] - training-metrics P50: 0.005 ms, P90: 0.034 ms
[INFO ] - backward P50: 6.235 ms, P90: 11.146 ms
[INFO ] - step P50: 25.011 ms, P90: 52.742 ms
[INFO ] - epoch P50: 6.627 s, P90: 8.692 s
Should I be concerned about the boundingBoxError? Is it normal, or is something not right? Thanks for the help.
Update: I tried to test the trained model, but it doesn't detect any objects.
It looks like it's working now with these argument settings: args = new String[] {"-e", "8", "-b", "1"};
"-m 1" (a.k.a. "max-batches") means the model is trained on at most 1 batch per epoch, which is usually only used as a sanity test.
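Put as arithmetic: an epoch normally has ceil(images / batch size) batches, and -m caps that count. With "-b 1" and "-m 1", each of the 8 epochs trains on a single image, which is consistent with the earlier run producing a model that detects nothing. The helper below only illustrates that arithmetic; it is hypothetical, not DJL's implementation, and treating a non-positive cap as "no cap" is an assumption.

```java
/** Hypothetical illustration of how a max-batches cap limits the work done per epoch. */
final class EpochBudget {
    private EpochBudget() {}

    /**
     * Batches trained in one epoch: ceil(numSamples / batchSize), capped at maxBatches
     * when maxBatches > 0 (a non-positive value means no cap in this sketch).
     */
    static int batchesPerEpoch(int numSamples, int batchSize, int maxBatches) {
        int full = (numSamples + batchSize - 1) / batchSize; // integer ceiling division
        return maxBatches > 0 ? Math.min(full, maxBatches) : full;
    }
}
```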
I see, thanks for the info @stu1130. However, how can I determine when it is a good time to stop training? Which values should I look at during training? Thanks.
@stu1130, could you also explain the values in Pikachu's index.file?
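The thread does not answer this question. One common heuristic (a generic technique, not a DJL API) is to watch the validation loss, here the SingleShotDetectionLoss on the "Validate" lines, and stop once it has not improved for a few epochs (the "patience"). A hypothetical sketch; the patience and minimum-improvement values are assumptions you would tune:

```java
/**
 * Hypothetical patience-based early-stopping check: stop once the validation loss
 * has not improved by more than minDelta for `patience` consecutive epochs.
 */
final class EarlyStopping {
    private EarlyStopping() {}

    /** Returns the 1-based epoch at which training would stop, or -1 if it never triggers. */
    static int stopEpoch(double[] validationLosses, int patience, double minDelta) {
        double best = Double.POSITIVE_INFINITY;
        int epochsWithoutImprovement = 0;
        for (int epoch = 0; epoch < validationLosses.length; epoch++) {
            if (validationLosses[epoch] < best - minDelta) {
                best = validationLosses[epoch]; // new best: reset the counter
                epochsWithoutImprovement = 0;
            } else if (++epochsWithoutImprovement >= patience) {
                return epoch + 1;
            }
        }
        return -1;
    }
}
```

By this rule, the validation losses in the 8-epoch log above (0.90 falling to 0.38) never trigger a stop, i.e. that run had not yet plateaued.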
"img_0.jpg": [4.0, 5.0, 512.0, 512.0, 0.0, 0.604744553565979, 0.40195202827453613, 0.6948338747024536, 0.5354305505752563],
I suppose that 512.0 is the image's height and width. What about 4.0 and 5.0? What are these values? Thanks.
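Assuming the Pikachu index.file reuses the MXNet/GluonCV .lst layout for detection (my reading, not confirmed in this thread): 4.0 would be the header length A (A and B themselves plus two extra header values), 5.0 the number of values per object B, the two 512.0 values the image width and height, and each following group of five values one object as [class, xmin, ymin, xmax, ymax] in normalized coordinates. A hypothetical decoder under that assumption:

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical decoder, assuming the .lst-style detection layout:
 * [A = header length, B = values per object, width, height, object 1, object 2, ...].
 */
final class DetectionRecord {
    final int width;
    final int height;
    final List<float[]> objects = new ArrayList<>(); // each: [class, xmin, ymin, xmax, ymax]

    DetectionRecord(float[] values) {
        int headerLength = (int) values[0]; // 4 in the Pikachu sample: A, B, width, height
        int objectWidth = (int) values[1]; // 5 in the Pikachu sample: class + 4 coordinates
        width = (int) values[2]; // assumed: first extra header value is the width
        height = (int) values[3]; // assumed: second extra header value is the height
        for (int start = headerLength; start + objectWidth <= values.length; start += objectWidth) {
            float[] obj = new float[objectWidth];
            System.arraycopy(values, start, obj, 0, objectWidth);
            objects.add(obj);
        }
    }
}
```

For the sample row above, this yields a 512 x 512 image with a single object of class 0.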
@androuino The Pikachu example is a simplified version; it doesn't require image augmentation since it only has one class. However, training a proper SSD model requires many image augmentations when loading a Record from the dataset. You can find the Python code here: https://github.com/dmlc/gluon-cv/blob/master/gluoncv/data/transforms/presets/ssd.py#L98
The full Python training code can be found here: https://github.com/dmlc/gluon-cv/blob/master/scripts/detection/ssd/train_ssd.py
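As a small illustration of what such an augmentation does to the labels (a sketch, not a port of the gluon-cv code, whose training transform includes a random horizontal flip among other steps): a horizontal flip mirrors each box's x coordinates. This assumes rows of [class, xmin, ymin, xmax, ymax] normalized to [0, 1] with no -1 padding rows; FlipAugment is hypothetical, and the image pixels must be flipped with the same random decision, which is not shown.

```java
/**
 * Hypothetical label-side half of a horizontal flip augmentation, assuming rows of
 * [class, xmin, ymin, xmax, ymax] normalized to [0, 1] and no padding rows.
 */
final class FlipAugment {
    private FlipAugment() {}

    static float[][] flipLabelsHorizontally(float[][] labels) {
        float[][] out = new float[labels.length][];
        for (int i = 0; i < labels.length; i++) {
            float[] row = labels[i].clone(); // class and y coordinates are unchanged
            // Mirror around the vertical centre line: the old right edge becomes the new left edge.
            row[1] = 1f - labels[i][3];
            row[3] = 1f - labels[i][1];
            out[i] = row;
        }
        return out;
    }
}
```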
Closing due to inactivity. Please reopen if you still require help.
Question
Hi, my question is this: I have successfully run object-detection training following the TrainPikachu class, but I am getting this result in the console:
As you can see, classAccuracy and boundingBoxError are empty or have no value at all.
I then trained the model for 8 epochs and tried to test it, but I am getting no detections.
I have 3 classes to identify and my Shape is:
Shape inputShape = new Shape(arguments.getBatchSize(), 3, 800, 1144);
My index.file is like this:
This means I have multiple bounding boxes annotated in a single image. The first value is the class, and the second through last values are the bounding box coordinates.
To get the Record, I am using the same method as TrainPikachu:
And this is how I prepare the dataset, closely following the PikachuDetection class:
Please let me know if I am doing something very different from the TrainPikachu example that would explain why my trained model doesn't detect any objects. Thank you in advance.