eaplatanios / tensorflow_scala

TensorFlow API for the Scala Programming Language
http://platanios.org/tensorflow_scala/
Apache License 2.0

Gradients: can't handle non-existing gradients with gateGradients = true #44

Closed: carlo-veezoo closed this issue 6 years ago

carlo-veezoo commented 6 years ago

Hi there!

I found a bug when calling tf.Gradients.gradients(..., gateGradients = true) on a graph that contains a tf.concatenate operation. I ran into it while trying to use the AdaGrad optimizer.

A simple example to reproduce the issue:

import org.platanios.tensorflow.api._

object GatedGradientBugReproducer {
  def main(args: Array[String]): Unit = {
    val a = tf.variable("a", FLOAT32, Shape(1))
    val concat = tf.concatenate(Seq(a, Tensor(1.0f)))
    val loss = concat.sum()

    tf.Gradients.gradients(Seq(loss), Seq(a), gateGradients = true) // Throws NullPointerException
  }
}

I dug into the code a bit: the exception is thrown when TensorFlow Scala tries to create a ControlFlow.tuple from the gradient inputs, and one of those inputs is null (i.e., a gradient that does not exist).
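The failure mode can be illustrated with a small, self-contained Scala model. This is a hypothetical sketch, not TensorFlow Scala's actual implementation or its eventual fix: `Output`, `tuple`, and `gate` are illustrative stand-ins, and gradients are modelled as `Option` so that a non-existent gradient (presumably the one for the concatenation axis in the example above) is explicit rather than null. The idea is to group only the gradients that exist and pass the missing ones through unchanged; Python TensorFlow's `tf.tuple` similarly documents that some of its entries may be None.

```scala
// Hypothetical model of gradient gating; NOT the TensorFlow Scala API.
object GateGradientsSketch {
  // Stand-in for a symbolic tensor in the graph.
  final case class Output(name: String)

  // Stand-in for a control-flow tuple: each returned output is meant to depend
  // on all inputs. Like the behaviour in the bug report, it rejects null inputs.
  def tuple(inputs: Seq[Output]): Seq[Output] = {
    require(inputs.forall(_ != null), "tuple inputs must not be null")
    inputs.map(o => Output(s"gated(${o.name})"))
  }

  // Gates the gradients of one op's inputs, skipping non-existent gradients.
  def gate(grads: Seq[Option[Output]]): Seq[Option[Output]] = {
    val existing = grads.flatten
    if (existing.size < 2) grads // fewer than two gradients: nothing to synchronise
    else {
      val gated = tuple(existing).iterator
      // Replace each existing gradient, in order, with its gated version;
      // missing gradients (None) are passed through untouched.
      grads.map(_.map(_ => gated.next()))
    }
  }

  def main(args: Array[String]): Unit = {
    // Two real gradients plus one missing gradient (e.g. for an axis input).
    val grads = Seq(Some(Output("dA")), Some(Output("dB")), None)
    println(gate(grads))
    // prints: List(Some(Output(gated(dA))), Some(Output(gated(dB))), None)
  }
}
```

Passing the null (or `None`) entries around the tuple, instead of into it, is what keeps the gating step from failing when an op simply has no gradient for some of its inputs.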

Thanks a lot for having a look!

Best, Carlo

eaplatanios commented 6 years ago

@csaladin94 Thanks for pointing this out! I think this has already been fixed in the master branch although I haven't released new artifacts on Sonatype for it. I will release new artifacts tomorrow and then your code should probably run fine. :)

carlo-veezoo commented 6 years ago

@eaplatanios Thanks for the fast answer, glad to hear that! I'll get back to you as soon as I have tried my code with the update.

carlo-veezoo commented 6 years ago

@eaplatanios I just compiled the project myself from the master branch, and I can confirm that the bug is fixed.

eaplatanios commented 6 years ago

@csaladin94 Great, thanks for checking it! :)