deeplearning4j / deeplearning4j-examples

Deeplearning4j Examples (DL4J, DL4J Spark, DataVec)
http://deeplearning4j.konduit.ai

How to predict with CNNTextClassifier.java? #587

Open yangjun023 opened 6 years ago

yangjun023 commented 6 years ago

1. The model uses a ReshapePreProcessor:

    // batchSize is fixed in both shapes, so this vertex only accepts exactly batchSize documents
    graphConfBuilder.addVertex(reshapedForConvName,
            new PreprocessorVertex(new ReshapePreProcessor(
                    new int[] {batchSize * docLength, embeddingsDim},
                    new int[] {batchSize, 1, docLength, embeddingsDim})),
            embeddedName);
2. When predicting with only one sample, I get the error "Mis matched lengths". How can I fix this?
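Both shapes in the ReshapePreProcessor above contain batchSize, so the graph can only reshape exactly batchSize × docLength rows. The plain-Java sketch below illustrates the shape arithmetic with the minibatch size inferred from the rows that actually arrive, so one document maps to [1, 1, docLength, embeddingsDim]. It is not a DL4J API: `ReshapeShapes` and `convInputShape` are hypothetical names, and embeddingsDim = 100 is an assumption. Whether a given DL4J release lets a reshape infer the batch dimension like this depends on the version, so check the ReshapePreProcessor / ReshapeVertex Javadoc for the release in use.

```java
import java.util.Arrays;

// Plain-Java sketch (not a DL4J API): infer the minibatch size from the rows that
// actually reach the reshape, instead of baking a fixed batchSize into the shapes.
final class ReshapeShapes {

    /**
     * Maps the embedding output [rows, embeddingsDim], where rows = miniBatch * docLength,
     * to the CNN input shape [miniBatch, 1, docLength, embeddingsDim].
     */
    static long[] convInputShape(long rows, long embeddingsDim, long docLength) {
        if (rows % docLength != 0) {
            throw new IllegalArgumentException(
                    "rows (" + rows + ") is not a multiple of docLength (" + docLength + ")");
        }
        long miniBatch = rows / docLength; // 1 for a single document, 128 for a full training batch
        return new long[] {miniBatch, 1, docLength, embeddingsDim};
    }

    public static void main(String[] args) {
        long docLength = 50;      // value reported in the thread
        long embeddingsDim = 100; // assumption, not stated in the thread
        // A single document contributes 1 * docLength rows
        System.out.println(Arrays.toString(convInputShape(docLength, embeddingsDim, docLength)));
        // A full training batch contributes 128 * docLength rows
        System.out.println(Arrays.toString(convInputShape(128 * docLength, embeddingsDim, docLength)));
    }
}
```

The same code path then serves both training (128 documents) and prediction (1 document), because nothing in the shape depends on a constant batch size.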
raver119 commented 6 years ago

Can you please show the full error message, and what you're feeding in?

yangjun023 commented 6 years ago

1. The log is as follows:

Exception in thread "main" java.lang.IllegalStateException: Mis matched lengths: [640000] != [5000]
    at org.nd4j.linalg.util.LinAlgExceptions.assertSameLength(LinAlgExceptions.java:40)
    at org.nd4j.linalg.api.ops.BaseTransformOp.<init>(BaseTransformOp.java:47)
    at org.nd4j.linalg.api.ops.impl.transforms.Set.<init>(Set.java:24)
    at org.nd4j.linalg.api.ndarray.BaseNDArray.assign(BaseNDArray.java:1267)
    at org.nd4j.linalg.api.ndarray.BaseNDArray.reshape(BaseNDArray.java:3753)
    at org.deeplearning4j.nn.graph.vertex.impl.ReshapeVertex.doForward(ReshapeVertex.java:77)
    at org.deeplearning4j.nn.graph.ComputationGraph.feedForward(ComputationGraph.java:1464)
    at org.deeplearning4j.nn.graph.ComputationGraph.silentOutput(ComputationGraph.java:1555)
    at org.deeplearning4j.nn.graph.ComputationGraph.output(ComputationGraph.java:1544)
    at org.deeplearning4j.nn.graph.ComputationGraph.outputSingle(ComputationGraph.java:1579)

The parameters are: batchSize = 128, docLength = 50.
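With these parameters, the two lengths in the error are both consistent with an embedding size of 640000 / (128 × 50) = 5000 / 50 = 100 (inferred from the numbers, not stated in the thread). The reshape target was built for a full batch of 128 documents, while a single document carries only 50 × 100 values:

$$
\underbrace{128}_{\text{batchSize}} \times \underbrace{50}_{\text{docLength}} \times \underbrace{100}_{\text{embeddingsDim}} = 640000
\qquad\text{vs.}\qquad
\underbrace{1}_{\text{one sample}} \times 50 \times 100 = 5000
$$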

2. I haven't found any prediction code in CNNTextClassifier.java. Do you know how to write it?
qianyiwei commented 6 years ago

Use outputSingle; it returns the probabilities of the labels.
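As long as the reshape is fixed at batchSize, a single document still has to reach outputSingle as a full batch. A minimal plain-Java sketch of that workaround, assuming rows are processed independently at inference time (true for convolution, pooling and dense layers), pads the one sample into row 0 of a zero-filled batch, runs the model, keeps row 0 of the output, and takes the arg-max as the predicted label. The `Function<float[][], float[][]>` model is a hypothetical stand-in for `ComputationGraph.outputSingle`, and `PaddedPredict` and its methods are illustrative names, not DL4J API.

```java
import java.util.Arrays;
import java.util.function.Function;

// Plain-Java sketch: predict one sample with a model that only accepts exactly
// batchSize rows. The model is a hypothetical stand-in for ComputationGraph.outputSingle.
final class PaddedPredict {

    /** Copies the single sample into row 0 of a zero-filled batch of batchSize rows. */
    static float[][] padToBatch(float[] sample, int batchSize) {
        float[][] batch = new float[batchSize][sample.length]; // rows 1..batchSize-1 stay zero
        batch[0] = Arrays.copyOf(sample, sample.length);
        return batch;
    }

    /** Index of the largest probability, i.e. the predicted label (first one wins on ties). */
    static int argMax(float[] probs) {
        int best = 0;
        for (int i = 1; i < probs.length; i++) {
            if (probs[i] > probs[best]) best = i;
        }
        return best;
    }

    /** Pads, runs the model on the full batch, and keeps only row 0 (the real sample). */
    static float[] predictOne(Function<float[][], float[][]> model, float[] sample, int batchSize) {
        float[][] out = model.apply(padToBatch(sample, batchSize));
        return out[0]; // the padding rows' outputs are discarded
    }

    public static void main(String[] args) {
        int batchSize = 128;
        // Hypothetical model: rejects partial batches, returns [p, 1 - p] for each row
        Function<float[][], float[][]> fakeModel = batch -> {
            if (batch.length != batchSize) throw new IllegalStateException("expected a full batch");
            float[][] out = new float[batch.length][];
            for (int i = 0; i < batch.length; i++) {
                float p = batch[i][0] > 0 ? 0.9f : 0.1f;
                out[i] = new float[] {p, 1 - p};
            }
            return out;
        };
        float[] probs = predictOne(fakeModel, new float[] {1f, 2f, 3f}, batchSize);
        System.out.println(Arrays.toString(probs) + " -> label " + argMax(probs));
    }
}
```

In ND4J terms the same steps would roughly be a zero batch from `Nd4j.zeros(...)`, the document copied into the slot that belongs to the first example (where that slot is depends on how CNNTextClassifier.java lays out its input, which this thread doesn't show), and `outputSingle(batch).getRow(0)` for the probabilities. An alternative that avoids the wasted computation is to build the same graph configuration with batchSize = 1 and copy the trained weights over with `setParams(trainedNet.params())`; this should work because the reshape vertex has no parameters, so the parameter count does not depend on the batch size.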