tensorflow / java

Java bindings for TensorFlow
Apache License 2.0

fix #526 #527

nfeybesse closed this 4 months ago

Craigacp commented 4 months ago

Can you add a test that triggers this bug to the cross-entropy tests (https://github.com/tensorflow/java/blob/master/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalCrossentropyTest.java)? I think this used to work, so I worry it broke in a TF upgrade and our tests didn't catch it.

nfeybesse commented 4 months ago

No, the problem is older; it is probably the dynamic batch size that triggers it. I will try to write a test case.

nfeybesse commented 4 months ago

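```java
import org.junit.jupiter.api.Test;
import org.tensorflow.Graph;
import org.tensorflow.Operand;
import org.tensorflow.Session;
import org.tensorflow.framework.losses.CategoricalCrossentropy;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.buffer.DataBuffers;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.types.TFloat32;

@Test
public void testCategoricalCrossEntopyWithDynamicBatchSize() {
  try (Graph graph = new Graph()) {
    Ops tf = Ops.create(graph);
    // Predictions (logits): the batch dimension is unknown (-1) when the graph is built
    Operand<TFloat32> yPred = tf.placeholder(TFloat32.class, Placeholder.shape(Shape.of(-1, 3)));
    // Labels: a 3x3 one-hot (identity) matrix with a fully static shape
    Operand<TFloat32> yTrue = tf.reshape(
        tf.constant(new float[] {1f, 0f, 0f, 0f, 1f, 0f, 0f, 0f, 1f}), tf.array(3, 3));
    CategoricalCrossentropy instance = new CategoricalCrossentropy(true); // fromLogits = true
    Operand<TFloat32> loss = instance.call(tf, yTrue, yPred); // Throws TFInvalidArgumentException without the fix
    try (Session session = new Session(graph);
        // Close the fed tensor as well as the fetched one
        TFloat32 input = TFloat32.tensorOf(Shape.of(3, 3),
            DataBuffers.of(new float[] {1f, 0f, 0f, 0f, 1f, 0f, 0f, 0f, 1f}));
        TFloat32 result = (TFloat32) session.runner().feed(yPred, input).fetch(loss).run().get(0)) {
      if (Math.abs(0.5514477f - result.getFloat()) > 0.01) {
        throw new IllegalStateException("Invalid result: " + result.getFloat());
      }
    }
  }
}
```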
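The expected value 0.5514477 in the test can be checked independently of TensorFlow. The following is a minimal plain-Java sketch, not part of tensorflow-java: the class `ReferenceCrossEntropy` and its method `meanFromLogits` are hypothetical names. It assumes one-hot labels and raw logits (the `fromLogits = true` case) and averages over however many rows it receives, so the same reference works for any batch size.

```java
/**
 * Hypothetical reference helper (not part of tensorflow-java): computes the mean
 * categorical cross-entropy from logits in plain Java, so a test can compare
 * TensorFlow's result against an independent value for any batch size.
 */
public final class ReferenceCrossEntropy {

  private ReferenceCrossEntropy() {}

  /** Mean over the batch of -sum_j labels[i][j] * log(softmax(logits[i])[j]). */
  public static double meanFromLogits(float[][] labels, float[][] logits) {
    if (labels.length != logits.length || labels.length == 0) {
      throw new IllegalArgumentException("labels and logits must have the same non-zero batch size");
    }
    double total = 0.0;
    for (int i = 0; i < logits.length; i++) {
      // Numerically stable log-sum-exp: subtract the row maximum before exponentiating
      double max = Double.NEGATIVE_INFINITY;
      for (float v : logits[i]) {
        max = Math.max(max, v);
      }
      double sumExp = 0.0;
      for (float v : logits[i]) {
        sumExp += Math.exp(v - max);
      }
      double logSumExp = max + Math.log(sumExp);
      // log(softmax(logits)[j]) = logits[j] - logSumExp
      double rowLoss = 0.0;
      for (int j = 0; j < logits[i].length; j++) {
        rowLoss -= labels[i][j] * (logits[i][j] - logSumExp);
      }
      total += rowLoss;
    }
    // The batch size is taken from the input at run time, never fixed in advance
    return total / logits.length;
  }

  public static void main(String[] args) {
    // Same data as the test: identity labels and identity logits
    float[][] identity = {{1f, 0f, 0f}, {0f, 1f, 0f}, {0f, 0f, 1f}};
    System.out.println(meanFromLogits(identity, identity));
  }
}
```

For the identity batch, every row has the same loss, -log(e / (e + 2)) = log(1 + 2/e) ≈ 0.55144, so the batch mean agrees with the test's expected value well within its 0.01 tolerance, whatever the batch size.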

nfeybesse commented 4 months ago

I'm a bit puzzled, because I get the impression that no existing test feeds a model with batches of dynamic size. I know from experience that it largely works, but you have to track down a few small bugs. How would you like me to integrate my test so that it fits in?

Craigacp commented 4 months ago

Add it next to the other tests for that loss. If there are more issues then let's fix them.

More of the framework was in flight a couple of years ago, but we didn't get all of it merged, so I assume that some of those things were tested in the original codebase before it was broken up into smaller PRs.

Craigacp commented 4 months ago

Thanks