deepjavalibrary / djl

An Engine-Agnostic Deep Learning Framework in Java
https://djl.ai
Apache License 2.0

Training Obj. Detection with custom datasets #190

Closed androuino closed 2 years ago

androuino commented 4 years ago

Question

Hi, I have successfully run object detection training by following the TrainPikachu class, but I am getting this result in the console: [screenshot: Screen Shot 2020-10-05 at 11 04 16]

As you will notice, classAccuracy and boundingBoxError are empty; they have no value at all.

I then trained the model for 8 epochs and tried to test it, but I am getting no detections.

I have 3 classes to identify and my Shape is: Shape inputShape = new Shape(arguments.getBatchSize(), 3, 800, 1144);

My index.file looks like this:

{"IMG_5237.jpg":[["0","0.085625","0.07473776223776224","0.043750000000000004","0.030594405594405596"],
["0","0.35125","0.28496503496503495","0.055","0.04020979020979021"],
["1","0.26875","0.3618881118881119","0.0575","0.033216783216783216"],
["0","0.7975","0.4602272727272727","0.0575","0.03583916083916084"],
["0","0.08750000000000001","0.6831293706293706","0.065","0.039335664335664336"],
["2","0.114375","0.5607517482517482","0.07875","0.028846153846153848"],
["2","0.115","0.40384615384615385","0.085","0.027972027972027972"],
["2","0.1125","0.3395979020979021","0.08","0.028846153846153848"],
["2","0.11375","0.23251748251748253","0.085","0.02972027972027972"],
["2","0.106875","0.04020979020979021","0.08125","0.026223776223776224"]], ...}

This means I have multiple bounding boxes annotated in a single image. In each entry, the first element is the class and the remaining four elements are the bounding box values.
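A row layout like the one described above can be checked before it is used as a label. The following is only a minimal sketch in plain Java: the class name `AnnotationRow` and its `parse` method are hypothetical, and the assumption that the four box values are normalized to the range 0 to 1 is inferred from the sample rows, not stated in the thread.

```java
import java.util.Arrays;

/** Hypothetical helper (not part of DJL): validates one index-file row. */
final class AnnotationRow {
    private AnnotationRow() {}

    /**
     * Parses a row of the form [class, v1, v2, v3, v4] into a float[5].
     * Throws IllegalArgumentException if the row is malformed.
     */
    static float[] parse(String[] row) {
        if (row == null || row.length != 5) {
            throw new IllegalArgumentException("expected exactly 5 fields");
        }
        float[] out = new float[5];
        for (int i = 0; i < 5; i++) {
            // NumberFormatException is a subclass of IllegalArgumentException
            out[i] = Float.parseFloat(row[i]);
        }
        // The class must be a non-negative whole number.
        if (!(out[0] >= 0f) || out[0] != (int) out[0]) {
            throw new IllegalArgumentException("invalid class value: " + row[0]);
        }
        // Assumption: box values are normalized to [0, 1].
        for (int i = 1; i < 5; i++) {
            if (!(out[i] >= 0f && out[i] <= 1f)) {
                throw new IllegalArgumentException("box value out of range: " + row[i]);
            }
        }
        return out;
    }

    public static void main(String[] args) {
        String[] row = {"1", "0.26875", "0.3618881118881119", "0.0575", "0.033216783216783216"};
        System.out.println(Arrays.toString(parse(row)));
    }
}
```

Running a check like this over every row of the index file makes malformed annotations fail early instead of silently producing a model that detects nothing.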

To get the Record, I use the same method as TrainPikachu:

@Override
protected Record get(NDManager manager, long index) throws IOException {
    int idx = Math.toIntExact(index);
    NDList d = new NDList(ImageFactory.getInstance()
            .fromFile(imagePaths.get(idx))
            .toNDArray(manager, flag));
    NDArray label = manager.create(labels.get(idx));
    NDList l = new NDList(label.reshape(new Shape(1).addAll(label.getShape())));
    return new Record(d, l);
}

This is how I prepare the dataset, which is close to the PikachuDetection class:

try (Reader reader = Files.newBufferedReader(indexFile)) {
    Type mapType = new TypeToken<Map<String, List<String[]>>>() {}.getType();
    Map<String, List<String[]>> metadata = JsonUtils.GSON.fromJson(reader, mapType);
    for (Map.Entry<String, List<String[]>> entry : metadata.entrySet()) {
        String imgName = entry.getKey();
        for (String[] item : entry.getValue()) {
            float[] labelArray = new float[5];
            // Class label
            labelArray[0] = Float.parseFloat(item[0]);

            // Bounding box labels
            labelArray[1] = Float.parseFloat(item[1]);
            labelArray[2] = Float.parseFloat(item[2]);
            labelArray[3] = Float.parseFloat(item[3]);
            labelArray[4] = Float.parseFloat(item[4]);
            labels.add(labelArray);
        }
        imagePaths.add(usagePath.resolve(imgName));
    }
}

Please let me know if I am doing something very different from the TrainPikachu example, or why my trained model doesn't detect any objects. Thank you in advance.

frankfliu commented 4 years ago

@androuino If I understand correctly, you have multiple bounding boxes in a single image.

In your code, the number of labels = number of images * number of bounding boxes per image, so labels and images no longer line up. Here is what is in your dataset:

1st record -> 1st image: 1st bounding box
2nd record -> 2nd image: 2nd bounding box of the 1st image

Here is the code in CocoDetection dataset which is similar to what you need: https://github.com/awslabs/djl/blob/master/basicdataset/src/main/java/ai/djl/basicdataset/CocoDetection.java#L137-L156
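The mismatch described in this comment can be avoided by storing one label block per image rather than one label per bounding box. The sketch below uses only the Java standard library; the class `PerImageLabels` and its method `fromMetadata` are hypothetical names for illustration, and this is not the code from CocoDetection.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical container: entry i of labels belongs to entry i of imageNames. */
final class PerImageLabels {
    final List<String> imageNames = new ArrayList<>();
    final List<float[][]> labels = new ArrayList<>();

    /** Builds one float[numBoxes][5] block per image from the parsed index file. */
    static PerImageLabels fromMetadata(Map<String, List<String[]>> metadata) {
        PerImageLabels result = new PerImageLabels();
        for (Map.Entry<String, List<String[]>> entry : metadata.entrySet()) {
            List<String[]> rows = entry.getValue();
            if (rows.isEmpty()) {
                continue; // skip images without annotations
            }
            float[][] boxes = new float[rows.size()][5];
            for (int i = 0; i < rows.size(); i++) {
                String[] row = rows.get(i);
                for (int j = 0; j < 5; j++) {
                    boxes[i][j] = Float.parseFloat(row[j]);
                }
            }
            // Both lists grow together, so index i always refers to the same image.
            result.imageNames.add(entry.getKey());
            result.labels.add(boxes);
        }
        return result;
    }

    public static void main(String[] args) {
        Map<String, List<String[]>> metadata = new LinkedHashMap<>();
        List<String[]> rows = new ArrayList<>();
        rows.add(new String[] {"0", "0.085625", "0.07473776223776224", "0.043750000000000004", "0.030594405594405596"});
        rows.add(new String[] {"1", "0.26875", "0.3618881118881119", "0.0575", "0.033216783216783216"});
        metadata.put("IMG_5237.jpg", rows);

        PerImageLabels data = fromMetadata(metadata);
        System.out.println(data.imageNames.size() + " image(s), "
                + data.labels.get(0).length + " box(es) in the first image");
    }
}
```

With this layout, a `get(index)` method can look up the image and its label block with the same index, which is the property the original per-box list lost.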

androuino commented 4 years ago

Thanks @frankfliu for the response. However, running the CocoTest gives me this error:

java.lang.IllegalStateException: Expected BEGIN_OBJECT but was BEGIN_ARRAY at line 1 column 1369436 path $.annotations[0].bbox
com.google.gson.JsonSyntaxException: java.lang.IllegalStateException: Expected BEGIN_OBJECT but was BEGIN_ARRAY at line 1 column 1369436 path $.annotations[0].bbox
    at com.google.gson.internal.bind.ReflectiveTypeAdapterFactory$Adapter.read(ReflectiveTypeAdapterFactory.java:226)
    at com.google.gson.internal.bind.ReflectiveTypeAdapterFactory$1.read(ReflectiveTypeAdapterFactory.java:131)
    at com.google.gson.internal.bind.ReflectiveTypeAdapterFactory$Adapter.read(ReflectiveTypeAdapterFactory.java:222)
    at com.google.gson.internal.bind.TypeAdapterRuntimeTypeWrapper.read(TypeAdapterRuntimeTypeWrapper.java:41)
    at com.google.gson.internal.bind.CollectionTypeAdapterFactory$Adapter.read(CollectionTypeAdapterFactory.java:82)
    at com.google.gson.internal.bind.CollectionTypeAdapterFactory$Adapter.read(CollectionTypeAdapterFactory.java:61)
    at com.google.gson.internal.bind.ReflectiveTypeAdapterFactory$1.read(ReflectiveTypeAdapterFactory.java:131)
    at com.google.gson.internal.bind.ReflectiveTypeAdapterFactory$Adapter.read(ReflectiveTypeAdapterFactory.java:222)
    at com.google.gson.Gson.fromJson(Gson.java:932)
    at com.google.gson.Gson.fromJson(Gson.java:870)
    at ai.djl.basicdataset.CocoUtils.prepare(CocoUtils.java:56)
    at ai.djl.basicdataset.CocoDetection.prepare(CocoDetection.java:109)
    at ai.djl.training.dataset.Dataset.prepare(Dataset.java:40)
    at ai.djl.training.dataset.RandomAccessDataset.getData(RandomAccessDataset.java:83)
    at ai.djl.training.Trainer.iterateDataset(Trainer.java:130)
    at ai.djl.basicdataset.CocoTest.testCocoRemote(CocoTest.java:40)
    at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.base/java.lang.reflect.Method.invoke(Method.java:566)
    at org.testng.internal.MethodInvocationHelper.invokeMethod(MethodInvocationHelper.java:134)
    at org.testng.internal.TestInvoker.invokeMethod(TestInvoker.java:597)
    at org.testng.internal.TestInvoker.invokeTestMethod(TestInvoker.java:173)
    at org.testng.internal.MethodRunner.runInSequence(MethodRunner.java:46)
    at org.testng.internal.TestInvoker$MethodInvocationAgent.invoke(TestInvoker.java:816)
    at org.testng.internal.TestInvoker.invokeTestMethods(TestInvoker.java:146)
    at org.testng.internal.TestMethodWorker.invokeTestMethods(TestMethodWorker.java:146)
    at org.testng.internal.TestMethodWorker.run(TestMethodWorker.java:128)
    at java.base/java.util.ArrayList.forEach(ArrayList.java:1540)
    at org.testng.TestRunner.privateRun(TestRunner.java:766)
    at org.testng.TestRunner.run(TestRunner.java:587)
    at org.testng.SuiteRunner.runTest(SuiteRunner.java:384)
    at org.testng.SuiteRunner.runSequentially(SuiteRunner.java:378)
    at org.testng.SuiteRunner.privateRun(SuiteRunner.java:337)
    at org.testng.SuiteRunner.run(SuiteRunner.java:286)
    at org.testng.SuiteRunnerWorker.runSuite(SuiteRunnerWorker.java:53)
    at org.testng.SuiteRunnerWorker.run(SuiteRunnerWorker.java:96)
    at org.testng.TestNG.runSuitesSequentially(TestNG.java:1187)
    at org.testng.TestNG.runSuitesLocally(TestNG.java:1109)
    at org.testng.TestNG.runSuites(TestNG.java:1039)
    at org.testng.TestNG.run(TestNG.java:1007)
    at org.gradle.api.internal.tasks.testing.testng.TestNGTestClassProcessor.runTests(TestNGTestClassProcessor.java:141)
    at org.gradle.api.internal.tasks.testing.testng.TestNGTestClassProcessor.stop(TestNGTestClassProcessor.java:90)
    at org.gradle.api.internal.tasks.testing.SuiteTestClassProcessor.stop(SuiteTestClassProcessor.java:61)
    at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.base/java.lang.reflect.Method.invoke(Method.java:566)
    at org.gradle.internal.dispatch.ReflectionDispatch.dispatch(ReflectionDispatch.java:36)
    at org.gradle.internal.dispatch.ReflectionDispatch.dispatch(ReflectionDispatch.java:24)
    at org.gradle.internal.dispatch.ContextClassLoaderDispatch.dispatch(ContextClassLoaderDispatch.java:33)
    at org.gradle.internal.dispatch.ProxyDispatchAdapter$DispatchingInvocationHandler.invoke(ProxyDispatchAdapter.java:94)
    at com.sun.proxy.$Proxy5.stop(Unknown Source)
    at org.gradle.api.internal.tasks.testing.worker.TestWorker.stop(TestWorker.java:133)
    at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.base/java.lang.reflect.Method.invoke(Method.java:566)
    at org.gradle.internal.dispatch.ReflectionDispatch.dispatch(ReflectionDispatch.java:36)
    at org.gradle.internal.dispatch.ReflectionDispatch.dispatch(ReflectionDispatch.java:24)
    at org.gradle.internal.remote.internal.hub.MessageHubBackedObjectConnection$DispatchWrapper.dispatch(MessageHubBackedObjectConnection.java:182)
    at org.gradle.internal.remote.internal.hub.MessageHubBackedObjectConnection$DispatchWrapper.dispatch(MessageHubBackedObjectConnection.java:164)
    at org.gradle.internal.remote.internal.hub.MessageHub$Handler.run(MessageHub.java:414)
    at org.gradle.internal.concurrent.ExecutorPolicy$CatchAndRecordFailures.onExecute(ExecutorPolicy.java:64)
    at org.gradle.internal.concurrent.ManagedExecutorImpl$1.run(ManagedExecutorImpl.java:48)
    at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128)
    at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628)
    at org.gradle.internal.concurrent.ThreadFactoryImpl$ManagedThreadRunnable.run(ThreadFactoryImpl.java:56)
    at java.base/java.lang.Thread.run(Thread.java:834)
Caused by: java.lang.IllegalStateException: Expected BEGIN_OBJECT but was BEGIN_ARRAY at line 1 column 1369436 path $.annotations[0].bbox
    at com.google.gson.stream.JsonReader.beginObject(JsonReader.java:386)
    at com.google.gson.internal.bind.ReflectiveTypeAdapterFactory$Adapter.read(ReflectiveTypeAdapterFactory.java:215)
    ... 68 more

I did not change any code in the CocoDetection class; I only downloaded the COCO dataset to run the test.

androuino commented 4 years ago

Here is the code in CocoDetection dataset which is similar to what you need: https://github.com/awslabs/djl/blob/master/basicdataset/src/main/java/ai/djl/basicdataset/CocoDetection.java#L137-L156

I have checked the actual CocoDetection class and modified my code accordingly, so it is now a bit similar to CocoDetection:

usagePath = root.resolve(usagePath);
Path indexFile = usagePath.resolve("index.file");
try (Reader reader = Files.newBufferedReader(indexFile)) {
    Type mapType = new TypeToken<Map<String, List<String[]>>>() {}.getType();
    Map<String, List<String[]>> metadata = JsonUtils.GSON.fromJson(reader, mapType);
    for (Map.Entry<String, List<String[]>> entry : metadata.entrySet()) {
        List<double[]> labelOfImage = getLabels(entry.getValue());
        if (!labelOfImage.isEmpty()) {
            imagePaths.add(usagePath.resolve(entry.getKey()));
            labels.add(labelOfImage.toArray(new double[0][]));
        }
    }
}

Then these are the functions that are similar to the ones in the CocoDetection class:

private double[] convertRecToList(String[] anno) {
    double[] list = new double[5];
    list[1] = Double.parseDouble(anno[1]);
    list[2] = Double.parseDouble(anno[2]);
    list[3] = Double.parseDouble(anno[3]);
    list[4] = Double.parseDouble(anno[4]);
    return list;
}

private List<double[]> getLabels(List<String[]> arr) {
    List<double[]> label = new ArrayList<>();
    for (String[] item : arr) {
        double[] list = convertRecToList(item);
        // add the category label
        // map the original one to incremental index
        list[0] = Double.parseDouble(item[0]);
        label.add(list);
    }
    return label;
}

But I am getting this error:

MXNet engine call failed: MXNetError: Check failed: in_type->at(i) == mshadow::default_type_flag || in_type->at(i) == -1: Unsupported data type 1
Stack trace:
  File "../include/mxnet/operator.h", line 228

ai.djl.engine.EngineException: MXNet engine call failed: MXNetError: Check failed: in_type->at(i) == mshadow::default_type_flag || in_type->at(i) == -1: Unsupported data type 1
Stack trace:
  File "../include/mxnet/operator.h", line 228

    at ai.djl.mxnet.jna.JnaUtils.checkCall(JnaUtils.java:1808)
    at ai.djl.mxnet.jna.JnaUtils.imperativeInvoke(JnaUtils.java:502)
    at ai.djl.mxnet.jna.FunctionInfo.invoke(FunctionInfo.java:91)
    at ai.djl.mxnet.jna.FunctionInfo.invoke(FunctionInfo.java:75)
    at ai.djl.mxnet.engine.MxNDManager.invoke(MxNDManager.java:288)
    at ai.djl.mxnet.engine.MxNDArrayEx.multiBoxTarget(MxNDArrayEx.java:934)
    at ai.djl.modality.cv.MultiBoxTarget.target(MultiBoxTarget.java:74)
    at ai.djl.training.loss.SingleShotDetectionLoss.inputForComponent(SingleShotDetectionLoss.java:53)
    at ai.djl.training.loss.AbstractCompositeLoss.evaluate(AbstractCompositeLoss.java:66)
    at ai.djl.training.EasyTrain.trainBatch(EasyTrain.java:82)
    at ai.djl.training.EasyTrain.fit(EasyTrain.java:45)
    at ai.djl.examples.training.TrainCustomModel.runTraining(TrainCustomModel.java:77)
    at ai.djl.examples.training.TrainCustomModelTest.testTrainingCustomModel(TrainCustomModelTest.java:31)
    at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.base/java.lang.reflect.Method.invoke(Method.java:566)
    at org.testng.internal.MethodInvocationHelper.invokeMethod(MethodInvocationHelper.java:134)
    at org.testng.internal.TestInvoker.invokeMethod(TestInvoker.java:597)
    at org.testng.internal.TestInvoker.invokeTestMethod(TestInvoker.java:173)
    at org.testng.internal.MethodRunner.runInSequence(MethodRunner.java:46)
    at org.testng.internal.TestInvoker$MethodInvocationAgent.invoke(TestInvoker.java:816)
    at org.testng.internal.TestInvoker.invokeTestMethods(TestInvoker.java:146)
    at org.testng.internal.TestMethodWorker.invokeTestMethods(TestMethodWorker.java:146)
    at org.testng.internal.TestMethodWorker.run(TestMethodWorker.java:128)
    at java.base/java.util.ArrayList.forEach(ArrayList.java:1540)
    at org.testng.TestRunner.privateRun(TestRunner.java:766)
    at org.testng.TestRunner.run(TestRunner.java:587)
    at org.testng.SuiteRunner.runTest(SuiteRunner.java:384)
    at org.testng.SuiteRunner.runSequentially(SuiteRunner.java:378)
    at org.testng.SuiteRunner.privateRun(SuiteRunner.java:337)
    at org.testng.SuiteRunner.run(SuiteRunner.java:286)
    at org.testng.SuiteRunnerWorker.runSuite(SuiteRunnerWorker.java:53)
    at org.testng.SuiteRunnerWorker.run(SuiteRunnerWorker.java:96)
    at org.testng.TestNG.runSuitesSequentially(TestNG.java:1187)
    at org.testng.TestNG.runSuitesLocally(TestNG.java:1109)
    at org.testng.TestNG.runSuites(TestNG.java:1039)
    at org.testng.TestNG.run(TestNG.java:1007)
    at org.gradle.api.internal.tasks.testing.testng.TestNGTestClassProcessor.runTests(TestNGTestClassProcessor.java:141)
    at org.gradle.api.internal.tasks.testing.testng.TestNGTestClassProcessor.stop(TestNGTestClassProcessor.java:90)
    at org.gradle.api.internal.tasks.testing.SuiteTestClassProcessor.stop(SuiteTestClassProcessor.java:61)
    at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.base/java.lang.reflect.Method.invoke(Method.java:566)
    at org.gradle.internal.dispatch.ReflectionDispatch.dispatch(ReflectionDispatch.java:36)
    at org.gradle.internal.dispatch.ReflectionDispatch.dispatch(ReflectionDispatch.java:24)
    at org.gradle.internal.dispatch.ContextClassLoaderDispatch.dispatch(ContextClassLoaderDispatch.java:33)
    at org.gradle.internal.dispatch.ProxyDispatchAdapter$DispatchingInvocationHandler.invoke(ProxyDispatchAdapter.java:94)
    at com.sun.proxy.$Proxy5.stop(Unknown Source)
    at org.gradle.api.internal.tasks.testing.worker.TestWorker.stop(TestWorker.java:133)
    at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.base/java.lang.reflect.Method.invoke(Method.java:566)
    at org.gradle.internal.dispatch.ReflectionDispatch.dispatch(ReflectionDispatch.java:36)
    at org.gradle.internal.dispatch.ReflectionDispatch.dispatch(ReflectionDispatch.java:24)
    at org.gradle.internal.remote.internal.hub.MessageHubBackedObjectConnection$DispatchWrapper.dispatch(MessageHubBackedObjectConnection.java:182)
    at org.gradle.internal.remote.internal.hub.MessageHubBackedObjectConnection$DispatchWrapper.dispatch(MessageHubBackedObjectConnection.java:164)
    at org.gradle.internal.remote.internal.hub.MessageHub$Handler.run(MessageHub.java:414)
    at org.gradle.internal.concurrent.ExecutorPolicy$CatchAndRecordFailures.onExecute(ExecutorPolicy.java:64)
    at org.gradle.internal.concurrent.ManagedExecutorImpl$1.run(ManagedExecutorImpl.java:48)
    at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128)
    at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628)
    at org.gradle.internal.concurrent.ThreadFactoryImpl$ManagedThreadRunnable.run(ThreadFactoryImpl.java:56)
    at java.base/java.lang.Thread.run(Thread.java:834)

Sorry, I may not have understood what you suggested in your last comment, but please help me get past this error. Thank you in advance.

frankfliu commented 4 years ago

@androuino I confirmed that CocoTest is failing. We will take a look and fix the CocoDetection dataset bug.

Training SSD with COCO is not as straightforward as with Pikachu; we will try to create an example.

androuino commented 4 years ago

Thanks for the response @frankfliu. I confirmed that CocoDetection has been fixed. However, have you taken a look at the error I am getting?

MXNet engine call failed: MXNetError: Check failed: in_type->at(i) == mshadow::default_type_flag || in_type->at(i) == -1: Unsupported data type 1
Stack trace:
  File "../include/mxnet/operator.h", line 228

So as of now, does training on a custom dataset with multiple bounding boxes and classes annotated in one image simply not work, or is it not yet supported, when following the Pikachu example?

stu1130 commented 4 years ago

@androuino It looks like you are using an operator that doesn't support the float64 data type. Do you know where you use float64?

androuino commented 4 years ago

Hi @stu1130, this is the class that I made, which is a kind of combination of the PikachuDetection class and the CocoDetection class: https://gist.github.com/androuino/00095ba5be3d10cab765bd2447d236cf Then this is my TrainCustomModel class: https://gist.github.com/androuino/6e8e014e3e70b107f4ebb9cf6c9387ac I don't see or notice where I am using float64. My dataset structure is pretty much the same as the Pikachu dataset, as are the annotation values.

{"IMG_5237.jpg":[["0","0.085625","0.07473776223776224","0.043750000000000004","0.030594405594405596"],
["0","0.35125","0.28496503496503495","0.055","0.04020979020979021"],
["1","0.26875","0.3618881118881119","0.0575","0.033216783216783216"],
["0","0.7975","0.4602272727272727","0.0575","0.03583916083916084"],
["0","0.08750000000000001","0.6831293706293706","0.065","0.039335664335664336"],
["2","0.114375","0.5607517482517482","0.07875","0.028846153846153848"],
["2","0.115","0.40384615384615385","0.085","0.027972027972027972"],
["2","0.1125","0.3395979020979021","0.08","0.028846153846153848"],
["2","0.11375","0.23251748251748253","0.085","0.02972027972027972"],
["2","0.106875","0.04020979020979021","0.08125","0.026223776223776224"]], ...}

stu1130 commented 4 years ago

@androuino Do you have the complete error stack trace? The label of the Pikachu dataset is float, but the label in CocoDetection is double. I suspect this is where you use float64. Could you cast the label to float32 with the toType method?
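One way to follow this suggestion without touching the engine is to narrow the labels to `float` before they are turned into an NDArray. The sketch below is plain Java with a hypothetical helper class `LabelPrecision`; it assumes, as this comment implies, that labels built from `double[][]` end up as float64 and labels built from `float[][]` end up as float32. Converting an existing NDArray with `toType` would be an alternative, but that call is not shown here.

```java
import java.util.Arrays;

/** Hypothetical helper (not part of DJL): narrows label values from double to float. */
final class LabelPrecision {
    private LabelPrecision() {}

    /** Returns a float copy of src; each row keeps its original length. */
    static float[][] toFloat(double[][] src) {
        float[][] out = new float[src.length][];
        for (int i = 0; i < src.length; i++) {
            out[i] = new float[src[i].length];
            for (int j = 0; j < src[i].length; j++) {
                out[i][j] = (float) src[i][j]; // narrowing cast, loses precision beyond float range
            }
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] labelOfImage = {
            {0, 0.085625, 0.07473776223776224, 0.043750000000000004, 0.030594405594405596},
            {1, 0.26875, 0.3618881118881119, 0.0575, 0.033216783216783216}
        };
        float[][] asFloat = toFloat(labelOfImage);
        System.out.println(Arrays.deepToString(asFloat));
    }
}
```

The precision lost by the cast is far below what normalized bounding box coordinates need, so the narrowing should not affect training in practice.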

androuino commented 4 years ago

@stu1130 So far this is the only stack trace that I could get:

MXNet engine call failed: MXNetError: Check failed: in_type->at(i) == mshadow::default_type_flag || in_type->at(i) == -1: Unsupported data type 1
Stack trace:
  File "../include/mxnet/operator.h", line 228

ai.djl.engine.EngineException: MXNet engine call failed: MXNetError: Check failed: in_type->at(i) == mshadow::default_type_flag || in_type->at(i) == -1: Unsupported data type 1
Stack trace:
  File "../include/mxnet/operator.h", line 228

  at ai.djl.mxnet.jna.JnaUtils.checkCall(JnaUtils.java:1808)
  at ai.djl.mxnet.jna.JnaUtils.imperativeInvoke(JnaUtils.java:502)
  at ai.djl.mxnet.jna.FunctionInfo.invoke(FunctionInfo.java:91)
  at ai.djl.mxnet.jna.FunctionInfo.invoke(FunctionInfo.java:75)
  at ai.djl.mxnet.engine.MxNDManager.invoke(MxNDManager.java:288)
  at ai.djl.mxnet.engine.MxNDArrayEx.multiBoxTarget(MxNDArrayEx.java:934)
  at ai.djl.modality.cv.MultiBoxTarget.target(MultiBoxTarget.java:74)
  at ai.djl.training.loss.SingleShotDetectionLoss.inputForComponent(SingleShotDetectionLoss.java:53)
  at ai.djl.training.loss.AbstractCompositeLoss.evaluate(AbstractCompositeLoss.java:66)
  at ai.djl.training.EasyTrain.trainBatch(EasyTrain.java:82)
  at ai.djl.training.EasyTrain.fit(EasyTrain.java:45)
  at ai.djl.examples.training.TrainCustomModel.runTraining(TrainCustomModel.java:77)
  at ai.djl.examples.training.TrainCustomModelTest.testTrainingCustomModel(TrainCustomModelTest.java:31)
  at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
  at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
  at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
  at java.base/java.lang.reflect.Method.invoke(Method.java:566)
  at org.testng.internal.MethodInvocationHelper.invokeMethod(MethodInvocationHelper.java:134)
  at org.testng.internal.TestInvoker.invokeMethod(TestInvoker.java:597)
  at org.testng.internal.TestInvoker.invokeTestMethod(TestInvoker.java:173)
  at org.testng.internal.MethodRunner.runInSequence(MethodRunner.java:46)
  at org.testng.internal.TestInvoker$MethodInvocationAgent.invoke(TestInvoker.java:816)
  at org.testng.internal.TestInvoker.invokeTestMethods(TestInvoker.java:146)
  at org.testng.internal.TestMethodWorker.invokeTestMethods(TestMethodWorker.java:146)
  at org.testng.internal.TestMethodWorker.run(TestMethodWorker.java:128)
  at java.base/java.util.ArrayList.forEach(ArrayList.java:1540)
  at org.testng.TestRunner.privateRun(TestRunner.java:766)
  at org.testng.TestRunner.run(TestRunner.java:587)
  at org.testng.SuiteRunner.runTest(SuiteRunner.java:384)
  at org.testng.SuiteRunner.runSequentially(SuiteRunner.java:378)
  at org.testng.SuiteRunner.privateRun(SuiteRunner.java:337)
  at org.testng.SuiteRunner.run(SuiteRunner.java:286)
  at org.testng.SuiteRunnerWorker.runSuite(SuiteRunnerWorker.java:53)
  at org.testng.SuiteRunnerWorker.run(SuiteRunnerWorker.java:96)
  at org.testng.TestNG.runSuitesSequentially(TestNG.java:1187)
  at org.testng.TestNG.runSuitesLocally(TestNG.java:1109)
  at org.testng.TestNG.runSuites(TestNG.java:1039)
  at org.testng.TestNG.run(TestNG.java:1007)
  at org.gradle.api.internal.tasks.testing.testng.TestNGTestClassProcessor.runTests(TestNGTestClassProcessor.java:141)
  at org.gradle.api.internal.tasks.testing.testng.TestNGTestClassProcessor.stop(TestNGTestClassProcessor.java:90)
  at org.gradle.api.internal.tasks.testing.SuiteTestClassProcessor.stop(SuiteTestClassProcessor.java:61)
  at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
  at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
  at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
  at java.base/java.lang.reflect.Method.invoke(Method.java:566)
  at org.gradle.internal.dispatch.ReflectionDispatch.dispatch(ReflectionDispatch.java:36)
  at org.gradle.internal.dispatch.ReflectionDispatch.dispatch(ReflectionDispatch.java:24)
  at org.gradle.internal.dispatch.ContextClassLoaderDispatch.dispatch(ContextClassLoaderDispatch.java:33)
  at org.gradle.internal.dispatch.ProxyDispatchAdapter$DispatchingInvocationHandler.invoke(ProxyDispatchAdapter.java:94)
  at com.sun.proxy.$Proxy5.stop(Unknown Source)
  at org.gradle.api.internal.tasks.testing.worker.TestWorker.stop(TestWorker.java:133)
  at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
  at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
  at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
  at java.base/java.lang.reflect.Method.invoke(Method.java:566)
  at org.gradle.internal.dispatch.ReflectionDispatch.dispatch(ReflectionDispatch.java:36)
  at org.gradle.internal.dispatch.ReflectionDispatch.dispatch(ReflectionDispatch.java:24)
  at org.gradle.internal.remote.internal.hub.MessageHubBackedObjectConnection$DispatchWrapper.dispatch(MessageHubBackedObjectConnection.java:182)
  at org.gradle.internal.remote.internal.hub.MessageHubBackedObjectConnection$DispatchWrapper.dispatch(MessageHubBackedObjectConnection.java:164)
  at org.gradle.internal.remote.internal.hub.MessageHub$Handler.run(MessageHub.java:414)
  at org.gradle.internal.concurrent.ExecutorPolicy$CatchAndRecordFailures.onExecute(ExecutorPolicy.java:64)
  at org.gradle.internal.concurrent.ManagedExecutorImpl$1.run(ManagedExecutorImpl.java:48)
  at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128)
  at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628)
  at org.gradle.internal.concurrent.ThreadFactoryImpl$ManagedThreadRunnable.run(ThreadFactoryImpl.java:56)
  at java.base/java.lang.Thread.run(Thread.java:834)

I will try to cast the label to float32 and get back to you. Thanks.

androuino commented 4 years ago

@stu1130 I tried your suggestion. This is what the functions look like after I cast the label to float, followed by the error I am now getting:

private float[] convertRecToFloatList(String[] anno) {
    float[] list = new float[5];
    list[1] = Float.parseFloat(anno[1]);
    list[2] = Float.parseFloat(anno[2]);
    list[3] = Float.parseFloat(anno[3]);
    list[4] = Float.parseFloat(anno[4]);
    return list;
}

private List<float[]> getLabelsAsFloat(List<String[]> arr) {
    List<float[]> label = new ArrayList<>();
    for (String[] item : arr) {
        float[] box = convertRecToFloatList(item);
        logger.info(Arrays.toString(box));
        // add the category label
        // map the original one to incremental index
        float[] list = new float[5];
        System.arraycopy(box, 1, list, 1, 4);
        list[0] = Float.parseFloat(item[0]);
        label.add(list);
    }
    return label;
}

Stack trace:

MXNet engine call failed: TBlob.get_with_shape: Check failed: this->shape_.Size() == static_cast<size_t>(shape.Size()) (65 vs. 285) : new and old shape do not match total elements
Stack trace:
  File "../include/mxnet/./tensor_blob.h", line 311

ai.djl.engine.EngineException: MXNet engine call failed: TBlob.get_with_shape: Check failed: this->shape_.Size() == static_cast<size_t>(shape.Size()) (65 vs. 285) : new and old shape do not match total elements
Stack trace:
  File "../include/mxnet/./tensor_blob.h", line 311

    at ai.djl.mxnet.jna.JnaUtils.checkCall(JnaUtils.java:1808)
    at ai.djl.mxnet.jna.JnaUtils.syncCopyToCPU(JnaUtils.java:475)
    at ai.djl.mxnet.engine.MxNDArray.toByteBuffer(MxNDArray.java:280)
    at ai.djl.ndarray.NDArray.toLongArray(NDArray.java:300)
    at ai.djl.ndarray.NDArray.getLong(NDArray.java:558)
    at ai.djl.training.evaluator.AbstractAccuracy.lambda$updateAccumulator$1(AbstractAccuracy.java:85)
    at java.base/java.util.concurrent.ConcurrentHashMap.compute(ConcurrentHashMap.java:1932)
    at ai.djl.training.evaluator.AbstractAccuracy.updateAccumulator(AbstractAccuracy.java:85)
    at ai.djl.training.listener.EvaluatorTrainingListener.updateEvaluators(EvaluatorTrainingListener.java:147)
    at ai.djl.training.listener.EvaluatorTrainingListener.onTrainingBatch(EvaluatorTrainingListener.java:114)
    at ai.djl.training.EasyTrain.lambda$trainBatch$1(EasyTrain.java:92)
    at java.base/java.util.ArrayList.forEach(ArrayList.java:1540)
    at ai.djl.training.Trainer.notifyListeners(Trainer.java:263)
    at ai.djl.training.EasyTrain.trainBatch(EasyTrain.java:92)
    at ai.djl.training.EasyTrain.fit(EasyTrain.java:45)
    at ai.djl.examples.training.TrainCustomModel.runTraining(TrainCustomModel.java:81)
    at ai.djl.examples.training.TrainCustomModelTest.testTrainingCustomModel(TrainCustomModelTest.java:31)
    at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.base/java.lang.reflect.Method.invoke(Method.java:566)
    at org.testng.internal.MethodInvocationHelper.invokeMethod(MethodInvocationHelper.java:134)
    at org.testng.internal.TestInvoker.invokeMethod(TestInvoker.java:597)
    at org.testng.internal.TestInvoker.invokeTestMethod(TestInvoker.java:173)
    at org.testng.internal.MethodRunner.runInSequence(MethodRunner.java:46)
    at org.testng.internal.TestInvoker$MethodInvocationAgent.invoke(TestInvoker.java:816)
    at org.testng.internal.TestInvoker.invokeTestMethods(TestInvoker.java:146)
    at org.testng.internal.TestMethodWorker.invokeTestMethods(TestMethodWorker.java:146)
    at org.testng.internal.TestMethodWorker.run(TestMethodWorker.java:128)
    at java.base/java.util.ArrayList.forEach(ArrayList.java:1540)
    at org.testng.TestRunner.privateRun(TestRunner.java:766)
    at org.testng.TestRunner.run(TestRunner.java:587)
    at org.testng.SuiteRunner.runTest(SuiteRunner.java:384)
    at org.testng.SuiteRunner.runSequentially(SuiteRunner.java:378)
    at org.testng.SuiteRunner.privateRun(SuiteRunner.java:337)
    at org.testng.SuiteRunner.run(SuiteRunner.java:286)
    at org.testng.SuiteRunnerWorker.runSuite(SuiteRunnerWorker.java:53)
    at org.testng.SuiteRunnerWorker.run(SuiteRunnerWorker.java:96)
    at org.testng.TestNG.runSuitesSequentially(TestNG.java:1187)
    at org.testng.TestNG.runSuitesLocally(TestNG.java:1109)
    at org.testng.TestNG.runSuites(TestNG.java:1039)
    at org.testng.TestNG.run(TestNG.java:1007)
    at org.gradle.api.internal.tasks.testing.testng.TestNGTestClassProcessor.runTests(TestNGTestClassProcessor.java:141)
    at org.gradle.api.internal.tasks.testing.testng.TestNGTestClassProcessor.stop(TestNGTestClassProcessor.java:90)
    at org.gradle.api.internal.tasks.testing.SuiteTestClassProcessor.stop(SuiteTestClassProcessor.java:61)
    at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.base/java.lang.reflect.Method.invoke(Method.java:566)
    at org.gradle.internal.dispatch.ReflectionDispatch.dispatch(ReflectionDispatch.java:36)
    at org.gradle.internal.dispatch.ReflectionDispatch.dispatch(ReflectionDispatch.java:24)
    at org.gradle.internal.dispatch.ContextClassLoaderDispatch.dispatch(ContextClassLoaderDispatch.java:33)
    at org.gradle.internal.dispatch.ProxyDispatchAdapter$DispatchingInvocationHandler.invoke(ProxyDispatchAdapter.java:94)
    at com.sun.proxy.$Proxy5.stop(Unknown Source)
    at org.gradle.api.internal.tasks.testing.worker.TestWorker.stop(TestWorker.java:133)
    at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.base/java.lang.reflect.Method.invoke(Method.java:566)
    at org.gradle.internal.dispatch.ReflectionDispatch.dispatch(ReflectionDispatch.java:36)
    at org.gradle.internal.dispatch.ReflectionDispatch.dispatch(ReflectionDispatch.java:24)
    at org.gradle.internal.remote.internal.hub.MessageHubBackedObjectConnection$DispatchWrapper.dispatch(MessageHubBackedObjectConnection.java:182)
    at org.gradle.internal.remote.internal.hub.MessageHubBackedObjectConnection$DispatchWrapper.dispatch(MessageHubBackedObjectConnection.java:164)
    at org.gradle.internal.remote.internal.hub.MessageHub$Handler.run(MessageHub.java:414)
    at org.gradle.internal.concurrent.ExecutorPolicy$CatchAndRecordFailures.onExecute(ExecutorPolicy.java:64)
    at org.gradle.internal.concurrent.ManagedExecutorImpl$1.run(ManagedExecutorImpl.java:48)
    at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128)
    at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628)
    at org.gradle.internal.concurrent.ThreadFactoryImpl$ManagedThreadRunnable.run(ThreadFactoryImpl.java:56)
    at java.base/java.lang.Thread.run(Thread.java:834)

androuino commented 4 years ago

@stu1130 I have an update on the training attempt. I changed the arguments to this:

args = new String[] {"-e", "8", "-m", "1", "-b", "1"};

Training now runs, but the boundingBoxError for both training and validation doesn't seem to improve; it stays at 0.00E+00. These are my results over 8 epochs of training:

[INFO ] - Load MXNet Engine Version 1.7.0 in 0.227 ms.
[INFO ] - Epoch 1 finished.
[INFO ] - Train: classAccuracy: 0.33, boundingBoxError: 0.00E+00, SingleShotDetectionLoss: 1.68
[INFO ] - Validate: classAccuracy: 0.78, boundingBoxError: 0.00E+00, SingleShotDetectionLoss: 0.90
[INFO ] - Epoch 2 finished.
[INFO ] - Train: classAccuracy: 0.61, boundingBoxError: 0.00E+00, SingleShotDetectionLoss: 0.90
[INFO ] - Validate: classAccuracy: 0.87, boundingBoxError: 0.00E+00, SingleShotDetectionLoss: 0.81
[INFO ] - Epoch 3 finished.
[INFO ] - Train: classAccuracy: 0.86, boundingBoxError: 0.00E+00, SingleShotDetectionLoss: 0.50
[INFO ] - Validate: classAccuracy: 0.93, boundingBoxError: 0.00E+00, SingleShotDetectionLoss: 0.77
[INFO ] - Epoch 4 finished.
[INFO ] - Train: classAccuracy: 0.94, boundingBoxError: 0.00E+00, SingleShotDetectionLoss: 0.29
[INFO ] - Validate: classAccuracy: 0.96, boundingBoxError: 0.00E+00, SingleShotDetectionLoss: 0.72
[INFO ] - Epoch 5 finished.
[INFO ] - Train: classAccuracy: 0.99, boundingBoxError: 0.00E+00, SingleShotDetectionLoss: 0.17
[INFO ] - Validate: classAccuracy: 0.99, boundingBoxError: 0.00E+00, SingleShotDetectionLoss: 0.66
[INFO ] - Epoch 6 finished.
[INFO ] - Train: classAccuracy: 1.00, boundingBoxError: 0.00E+00, SingleShotDetectionLoss: 0.09
[INFO ] - Validate: classAccuracy: 1.00, boundingBoxError: 0.00E+00, SingleShotDetectionLoss: 0.58
[INFO ] - Epoch 7 finished.
[INFO ] - Train: classAccuracy: 1.00, boundingBoxError: 0.00E+00, SingleShotDetectionLoss: 0.06
[INFO ] - Validate: classAccuracy: 1.00, boundingBoxError: 0.00E+00, SingleShotDetectionLoss: 0.49
[INFO ] - Epoch 8 finished.
[INFO ] - Train: classAccuracy: 1.00, boundingBoxError: 0.00E+00, SingleShotDetectionLoss: 0.03
[INFO ] - Validate: classAccuracy: 1.00, boundingBoxError: 0.00E+00, SingleShotDetectionLoss: 0.38
[INFO ] - forward P50: 849.174 ms, P90: 878.440 ms
[INFO ] - training-metrics P50: 0.005 ms, P90: 0.034 ms
[INFO ] - backward P50: 6.235 ms, P90: 11.146 ms
[INFO ] - step P50: 25.011 ms, P90: 52.742 ms
[INFO ] - epoch P50: 6.627 s, P90: 8.692 s

Should I be concerned about the boundingBoxError? Is this normal, or is something not right? Thanks for the help.

Update: I tried to test the trained model, but it doesn't detect any objects.

androuino commented 4 years ago

Looks like it's working now with these argument settings: args = new String[] {"-e", "8", "-b", "1"};.

stu1130 commented 4 years ago

"-m 1" (a.k.a. "max-batches") means the model is trained on at most 1 batch per epoch, which is usually only done as a sanity test.
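
To make the effect of the batch cap concrete, here is a minimal sketch in plain Java (no DJL calls). The `batchesPerEpoch` helper and the dataset size of 100 are assumptions for illustration only; they are not taken from the DJL example code.

```java
public class MaxBatchesSketch {

    // Hypothetical helper: number of batches processed in one epoch when an
    // optional cap (like "-m") is applied. A cap <= 0 means "no cap" here.
    static long batchesPerEpoch(long datasetSize, int batchSize, long maxBatches) {
        long full = (datasetSize + batchSize - 1) / batchSize; // ceiling division
        return maxBatches > 0 ? Math.min(full, maxBatches) : full;
    }

    public static void main(String[] args) {
        long datasetSize = 100; // assumed dataset size, for illustration only
        int epochs = 8;         // "-e 8"
        int batchSize = 1;      // "-b 1"

        long capped = epochs * batchesPerEpoch(datasetSize, batchSize, 1);
        long uncapped = epochs * batchesPerEpoch(datasetSize, batchSize, 0);

        // With "-m 1" the model sees one batch per epoch, regardless of dataset size.
        System.out.println("images seen with -m 1:  " + capped * batchSize);
        System.out.println("images seen without -m: " + uncapped * batchSize);
    }
}
```

Under these assumptions, the capped run touches only 8 images in total across all epochs, which is consistent with a model that reports metrics but has not learned to detect anything.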

androuino commented 4 years ago

I see, thanks for the info @stu1130. However, how can I determine when it's a good time to stop the training? Which values should I look at during training? Thanks.
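
One common heuristic (not something confirmed in this thread) is patience-based early stopping: stop once the validation loss has not improved for a few consecutive epochs. The sketch below is plain Java with a hypothetical `stopEpoch` helper; the `patience` and `minDelta` values are assumptions, and the loss values are the validation SingleShotDetectionLoss numbers from the 8-epoch log earlier in this thread.

```java
public class EarlyStoppingSketch {

    // Returns the 1-based epoch at which training would stop, or -1 if the
    // validation loss kept improving and the patience was never exhausted.
    static int stopEpoch(double[] valLoss, int patience, double minDelta) {
        double best = Double.POSITIVE_INFINITY;
        int epochsSinceBest = 0;
        for (int i = 0; i < valLoss.length; i++) {
            if (valLoss[i] < best - minDelta) {
                // Improvement of more than minDelta: remember it and reset the counter.
                best = valLoss[i];
                epochsSinceBest = 0;
            } else {
                epochsSinceBest++;
                if (epochsSinceBest >= patience) {
                    return i + 1;
                }
            }
        }
        return -1;
    }

    public static void main(String[] args) {
        // Validation SingleShotDetectionLoss per epoch, copied from the log in this thread.
        double[] valLoss = {0.90, 0.81, 0.77, 0.72, 0.66, 0.58, 0.49, 0.38};
        int patience = 2;       // assumed value
        double minDelta = 0.01; // assumed value

        int stop = stopEpoch(valLoss, patience, minDelta);
        System.out.println(stop == -1
                ? "validation loss still improving; no stop signal yet"
                : "stop after epoch " + stop);
    }
}
```

In that log the validation loss drops every epoch, so this heuristic would not signal a stop within 8 epochs.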

androuino commented 4 years ago

@stu1130, could you also explain to me the Pikachu's index.file values?

"img_0.jpg": [4.0, 5.0, 512.0, 512.0, 0.0, 0.604744553565979, 0.40195202827453613, 0.6948338747024536, 0.5354305505752563],

I suppose that 512.0 is the image's height and width. What about the 4.0 and 5.0? What do these values represent? Thanks.
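
For reference, this entry appears to match the MXNet detection-record label layout, where the first value is the header width (4 values in total: the two width fields plus image width and height) and the second is the number of values per object (5: class id, xmin, ymin, xmax, ymax). The sketch below decodes an entry under that assumption; `decodeObjects` is a hypothetical helper, not part of DJL.

```java
import java.util.Arrays;

public class PikachuLabelSketch {

    // Decodes one label array, ASSUMING the layout
    // [headerWidth, objectWidth, <extra header...>, obj0..., obj1..., ...].
    static double[][] decodeObjects(double[] label) {
        int headerWidth = (int) label[0]; // 4.0 -> the header occupies 4 values
        int objectWidth = (int) label[1]; // 5.0 -> each object occupies 5 values
        int count = (label.length - headerWidth) / objectWidth;
        double[][] objects = new double[count][];
        for (int i = 0; i < count; i++) {
            int start = headerWidth + i * objectWidth;
            objects[i] = Arrays.copyOfRange(label, start, start + objectWidth);
        }
        return objects;
    }

    public static void main(String[] args) {
        // The "img_0.jpg" entry quoted above.
        double[] label = {4.0, 5.0, 512.0, 512.0, 0.0,
                0.604744553565979, 0.40195202827453613,
                0.6948338747024536, 0.5354305505752563};

        for (double[] obj : decodeObjects(label)) {
            // Assumed object layout: [classId, xmin, ymin, xmax, ymax], normalized to [0, 1].
            System.out.println("class=" + (int) obj[0]
                    + " box=" + Arrays.toString(Arrays.copyOfRange(obj, 1, obj.length)));
        }
    }
}
```

If that reading is right, an image with several annotated objects would simply append more 5-value groups after the header.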

frankfliu commented 4 years ago

@androuino The Pikachu example is a simplified version; it doesn't require image augmentations since it only has one class. However, training a proper SSD model requires many image augmentations when loading a Record from the dataset. You can find the Python code here: https://github.com/dmlc/gluon-cv/blob/master/gluoncv/data/transforms/presets/ssd.py#L98

The full Python training code can be found here: https://github.com/dmlc/gluon-cv/blob/master/scripts/detection/ssd/train_ssd.py

siddvenk commented 2 years ago

Closing due to inactivity. Please reopen if you still require help.