eaplatanios / tensorflow_scala

TensorFlow API for the Scala Programming Language
http://platanios.org/tensorflow_scala/
Apache License 2.0

Estimator: infer() method returns empty iterator #119

Closed mandar2812 closed 5 years ago

mandar2812 commented 6 years ago

The Estimator.infer() method returns a tensor, as expected, when called with a tensor:

val preds: Tensor = estimator.infer(() => features)

But when called with a Dataset:

val preds: Tensor = estimator.infer(() => tf.data.TensorSlicesDataset(features))

it returns an empty iterator of type Iterator[(Tensor, ModelInferenceOutput)] instead of the expected non-empty one.

eaplatanios commented 6 years ago

@mandar2812 That should not be happening. What is the dimensionality of features and what does each dimension correspond to?

mandar2812 commented 6 years ago

@eaplatanios, for example:


val dataSet = CIFARLoader.load(Paths.get(tempdir.toString()), CIFARLoader.CIFAR_10)

val (model, estimator) = tf.createWith(graph = Graph()) {
...
}

val features = dataSet.testImages

// Works: returns a tensor of predictions
estimator.infer(() => features)

// Does not work: returns an empty iterator
estimator.infer(() => tf.data.TensorSlicesDataset(dataSet.testImages).batch(128))

This is what I get ...

DynaML>estimator.infer(() => tf.data.TensorSlicesDataset(dataSet.testImages).batch(128)) 
res1: Iterator[(Tensor, core.client.Fetchable.outputFetchable.ResultType)] = empty iterator

DynaML>res1.next 
res4: (Tensor, core.client.Fetchable.outputFetchable.ResultType) = (null, null)

DynaML>res1.hasNext 
res5: Boolean = false

DynaML>estimator.infer(() => tf.data.TensorSlicesDataset(dataSet.testImages)) 
cmd2.sc:1: could not find implicit value for parameter ev: org.platanios.tensorflow.api.learn.estimators.Estimator.SupportedInferInput[org.platanios.tensorflow.api.ops.io.data.TensorSlicesDataset[org.platanios.tensorflow.api.Tensor,org.platanios.tensorflow.api.tensorDataHelper.OutputType,org.platanios.tensorflow.api.tensorDataHelper.DataTypes,org.platanios.tensorflow.api.tensorDataHelper.Shapes],InferOutput,org.platanios.tensorflow.api.Tensor,org.platanios.tensorflow.api.Output,org.platanios.tensorflow.api.DataType,org.platanios.tensorflow.api.Shape,ModelInferenceOutput]
val res2 = estimator.infer(() => tf.data.TensorSlicesDataset(dataSet.testImages))
                          ^
Compilation Failed

DynaML>estimator.infer(() => dataSet.testImages) 
res2: core.client.Fetchable.outputFetchable.ResultType = FLOAT32[10000, 10]

DynaML>res2.summarize() 
res3: String = """FLOAT32[10000, 10]
[[-1.5778543, 0.25482517, 3.183153, ..., 2.175755, -0.9906178, 0.36206508],
 [3.5208342, 5.989723, 2.9353628, ..., 1.4910047, 6.147928, 6.4471908],
 [1.2141892, 3.2512522, 2.784367, ..., 1.7188921, 3.6973124, 4.417762],
 ...,
 [-2.1187825, 1.2255864, 1.4052767, ..., 1.559977, -1.3915741, 1.664874],
 [-1.4054754, 1.9207716, 1.6349213, ..., 1.7796209, -0.2868932, 1.5499349],
 [-1.9703177, 0.58942145, 2.7025428, ..., 3.4534814, -1.3394758, 0.57645905]]"""

Estimator Source

Looking at InMemoryEstimator.scala, from line 237 onwards:

try {
  ev.convertFetched(new Iterator[(IT, ModelInferenceOutput)] {
    override def hasNext: Boolean = !session.shouldStop
    override def next(): (IT, ModelInferenceOutput) = {
      try {
        // TODO: !!! There might be an issue with the stop criteria here.
        session.removeHooks(currentTrainHooks ++ evaluateHooks)
        val output = session.run(fetches = (inferenceOps.input, inferenceOps.output))
        session.addHooks(currentTrainHooks ++ evaluateHooks)
        output
      } catch {
        case _: OutOfRangeException =>
          session.setShouldStop(true)
          // TODO: !!! Do something to avoid this null pair.
          (null.asInstanceOf[IT], null.asInstanceOf[ModelInferenceOutput])
        case t: Throwable =>
          stopHook.updateCriteria(stopCriteria)
          session.closeWithoutHookEnd()
          throw t
      }
    }
  })
}

It seems the OutOfRangeException branch below is the one being executed, which would explain the (null, null) pair:

case _: OutOfRangeException =>
  session.setShouldStop(true)
  // TODO: !!! Do something to avoid this null pair.
  (null.asInstanceOf[IT], null.asInstanceOf[ModelInferenceOutput])

I am getting similar results with FileBasedEstimator.scala as well. All of the above was run with tf-scala 0.2.4.
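One way to avoid the null pair flagged by the TODO above might be to look ahead by one element, so that hasNext only reports true once an element has actually been fetched. The following is a minimal sketch in plain Scala, not the library's actual fix. `EndOfData`, `LookAheadIterator`, and `LookAheadDemo` are hypothetical names introduced for illustration, with `EndOfData` standing in for `OutOfRangeException`.

```scala
// Hypothetical stand-in for the exception raised when a dataset is exhausted.
final class EndOfData extends RuntimeException("end of data")

// Wraps a `fetch` function that throws EndOfData once no elements remain.
// `hasNext` fetches and buffers the next element on demand, so it is accurate
// and `next()` never has to return a placeholder such as (null, null).
final class LookAheadIterator[A](fetch: () => A) extends Iterator[A] {
  private var buffered: Option[A] = None
  private var finished = false

  // Tries to buffer one element unless one is already buffered
  // or the source has already signalled exhaustion.
  private def fill(): Unit =
    if (buffered.isEmpty && !finished) {
      try {
        buffered = Some(fetch())
      } catch {
        case _: EndOfData => finished = true
      }
    }

  override def hasNext: Boolean = {
    fill()
    buffered.isDefined
  }

  override def next(): A = {
    fill()
    buffered match {
      case Some(value) =>
        buffered = None
        value
      case None =>
        throw new NoSuchElementException("next on exhausted iterator")
    }
  }
}

object LookAheadDemo {
  // Simulates a source that yields three batches and then signals exhaustion.
  def batches(): Iterator[Int] = {
    val source = Iterator(1, 2, 3)
    new LookAheadIterator[Int](() =>
      if (source.hasNext) source.next() else throw new EndOfData)
  }
}
```

With this shape, draining `LookAheadDemo.batches()` yields only the three real elements. Calling `next()` after exhaustion throws `NoSuchElementException`, the usual `Iterator` convention, instead of returning nulls.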

eaplatanios commented 6 years ago

@mandar2812 Is there any way you could test this on 0.3.0-SNAPSHOT? I plan to release 0.3.0 this week and it would be good to know if this still happens. Also, sorry for the late response; I just got back to work last week.

mandar2812 commented 6 years ago

@eaplatanios No problem! It will take me some effort to refactor my code for the 0.3.0 release of TF-Scala, which I assume has the Tensor[DataType] style of API. I will try to create a separate branch for the refactoring, but I can't guarantee that I will be able to test this before the end of this week.

eaplatanios commented 5 years ago

@mandar2812 FYI, 0.3.0 is out, and 0.4.0-SNAPSHOT is also published, which is a much bigger change. You can see the changes in PR #131. :)

eaplatanios commented 5 years ago

@mandar2812 I'm sorry this took so long, but I finally got the time to look into this and found the issue. It is now fixed on the master branch.