tensorflow / java

Java bindings for TensorFlow
Apache License 2.0

The loss CategoricalCrossentropy is currently unusable in framework #526

Closed nfeybesse closed 7 months ago

nfeybesse commented 7 months ago

Using the class org.tensorflow.framework.losses.CategoricalCrossentropy systematically fails with this exception:

Exception in thread "main" org.tensorflow.exceptions.TFInvalidArgumentException: Shape must be rank 1 but is rank 0 for '{{node SoftmaxCrossEntropyWithLogits/Slice}} = Slice[Index=DT_INT64, T=DT_INT64](SoftmaxCrossEntropyWithLogits/Shape, SoftmaxCrossEntropyWithLogits/Sub, SoftmaxCrossEntropyWithLogits/Const)' with input shapes: [2], [], [].
    at org.tensorflow.internal.c_api.AbstractTF_Status.throwExceptionIfNotOK(AbstractTF_Status.java:87)
    at org.tensorflow.GraphOperationBuilder.finish(GraphOperationBuilder.java:461)
    at org.tensorflow.GraphOperationBuilder.build(GraphOperationBuilder.java:100)
    at org.tensorflow.GraphOperationBuilder.build(GraphOperationBuilder.java:71)
    at org.tensorflow.op.core.Slice.create(Slice.java:90)
    at org.tensorflow.framework.op.nn.SoftmaxCrossEntropyWithLogits.flattenOuterDims(SoftmaxCrossEntropyWithLogits.java:184)
    at org.tensorflow.framework.op.nn.SoftmaxCrossEntropyWithLogits.softmaxCrossEntropyWithLogits(SoftmaxCrossEntropyWithLogits.java:116)
    at org.tensorflow.framework.op.NnOps.softmaxCrossEntropyWithLogits(NnOps.java:145)
    at org.tensorflow.framework.losses.Losses.categoricalCrossentropy(Losses.java:253)
    at org.tensorflow.framework.losses.CategoricalCrossentropy.call(CategoricalCrossentropy.java:256)
    at org.tensorflow.framework.losses.impl.AbstractLoss.call(AbstractLoss.java:69)
    at org.genericsystem.keras.losses.SparseCategoricalCrossentropyLoss.internal(SparseCategoricalCrossentropyLoss.java:61)
    at org.genericsystem.keras.losses.NativeLossLayer.internalCall(NativeLossLayer.java:23)
    at org.genericsystem.keras.layers.Layer.call(Layer.java:31)
    at org.genericsystem.keras.layers.Layer.call(Layer.java:23)
    at org.genericsystem.keras.model.ModelContext.initalize(ModelContext.java:59)
    at org.genericsystem.keras.model.TrainingContext.initalize(TrainingContext.java:28)
    at org.genericsystem.keras.model.Model.trainStep(Model.java:260)
    at org.genericsystem.keras.example.gan.AbstractACGanApp$3.trainStep(AbstractACGanApp.java:67)
    at org.genericsystem.keras.model.Model.train(Model.java:453)
    at org.genericsystem.keras.model.Model.fit(Model.java:433)
    at org.genericsystem.keras.example.gan.AbstractGanApp.lambda$3(AbstractGanApp.java:117)
    at org.genericsystem.keras.GSKeras.safeExecuteSession(GSKeras.java:164)
    at org.genericsystem.keras.example.gan.AbstractGanApp.fit(AbstractGanApp.java:109)
    at org.genericsystem.keras.example.gan.acgan.cifar10.ACGanApp.main(ACGanApp.java:26)

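This is caused by the method flattenOuterDims in org.tensorflow.framework.op.nn.SoftmaxCrossEntropyWithLogits, which builds the rankminusone operand as a scalar (rank 0) instead of a one-element array (rank 1). The Slice op requires its begin and size inputs to be rank 1, which matches the input shapes [2], [], [] reported in the error.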
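To make the rank mismatch concrete, here is a minimal sketch in plain Java. It does not use the TensorFlow Java API: the Tensor class and the scalar, vector and slice helpers are hypothetical stand-ins, written only to mimic the rule that a slice's begin and size operands must be rank 1. It is an illustration of the failure mode, not the code in SoftmaxCrossEntropyWithLogits and not the fix that was applied.

```java
import java.util.Arrays;

public class SliceRankSketch {

    /** Hypothetical stand-in for a tensor: flat values plus a shape (empty shape = scalar). */
    static final class Tensor {
        final long[] values;
        final int[] shape;

        Tensor(long[] values, int... shape) {
            this.values = values;
            this.shape = shape;
        }

        int rank() {
            return shape.length;
        }
    }

    /** Builds a rank-0 value, shape []. */
    static Tensor scalar(long value) {
        return new Tensor(new long[] {value});
    }

    /** Builds a rank-1 value, shape [n]. */
    static Tensor vector(long... values) {
        return new Tensor(values.clone(), values.length);
    }

    /** Slices a rank-1 input; like the real op, it rejects begin/size operands that are not rank 1. */
    static Tensor slice(Tensor input, Tensor begin, Tensor size) {
        if (begin.rank() != 1) {
            throw new IllegalArgumentException("Shape must be rank 1 but is rank " + begin.rank());
        }
        if (size.rank() != 1) {
            throw new IllegalArgumentException("Shape must be rank 1 but is rank " + size.rank());
        }
        int from = (int) begin.values[0];
        int count = (int) size.values[0];
        return vector(Arrays.copyOfRange(input.values, from, from + count));
    }

    public static void main(String[] args) {
        // Shape of an assumed [32, 10] logits tensor: itself a rank-1 value with 2 elements.
        Tensor logitsShape = vector(32L, 10L);
        long rankMinusOne = logitsShape.values.length - 1;

        try {
            // Scalar begin and size: the [2], [], [] situation from the report.
            slice(logitsShape, scalar(rankMinusOne), scalar(1L));
        } catch (IllegalArgumentException e) {
            System.out.println("scalar operands rejected: " + e.getMessage());
        }

        // Wrapping the same values in one-element vectors satisfies the rank check.
        Tensor lastDim = slice(logitsShape, vector(rankMinusOne), vector(1L));
        System.out.println("last dimension: " + lastDim.values[0]);
    }
}
```

In this model the values passed to slice are the same in both calls; only their rank differs, which is the distinction the error message is pointing at.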