eaplatanios / tensorflow_scala

TensorFlow API for the Scala Programming Language
http://platanios.org/tensorflow_scala/
Apache License 2.0

Gradients.gradients can't handle unused null gradients #147

Closed. Spiess closed this issue 5 years ago.

Spiess commented 5 years ago

I am using a custom operation with three outputs. When I use Gradients.gradients to obtain the gradient with respect to the first of these outputs, the program crashes with a NullPointerException because the initial gradient for one of the other two outputs is null.

My custom GradientFn never uses the gradient for this output. However, since the change from GradientRegistry to GradientFn, Op.scala line 1501 converts every output gradient with toOutput and casts it to its respective type, even when the gradient is null. Calling toOutput on a null gradient is what throws the NullPointerException.
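A minimal sketch of the failure mode, assuming simplified stand-ins: OutputLike, Output and NullGradientDemo below are hypothetical and are not the real tensorflow_scala classes. It shows that the exception comes from calling toOutput on a null element, while casting null by itself simply yields null.

```scala
// Hypothetical stand-ins, NOT the real tensorflow_scala types: just enough
// structure to reproduce the shape of the failing line.
trait OutputLike { def toOutput: Output }
final case class Output(name: String) extends OutputLike {
  def toOutput: Output = this
}

object NullGradientDemo {
  // Mirrors the original line: convert each gradient with toOutput, then cast.
  // A null element has no toOutput to call, so the JVM throws a
  // NullPointerException before the cast is ever reached.
  def convertUnsafe(grads: Seq[OutputLike]): Seq[Output] =
    grads.map(_.toOutput.asInstanceOf[Output])

  // Returns true if converting these gradients throws a NullPointerException.
  def throwsNpe(grads: Seq[OutputLike]): Boolean =
    try { convertUnsafe(grads); false }
    catch { case _: NullPointerException => true }

  def main(args: Array[String]): Unit = {
    // Only the first output's gradient is defined; the other two are null.
    val grads: Seq[OutputLike] = Seq(Output("grad_0"), null, null)
    println(throwsNpe(grads))                  // true
    // Casting null on its own is harmless: the result is just null.
    println(null.asInstanceOf[Output] == null) // true
  }
}
```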

What I am trying to do worked with TensorFlow Scala 0.2.4, and it works as intended if I build TensorFlow Scala 0.4.2-SNAPSHOT after changing line 1501 of org/platanios/tensorflow/api/ops/Op.scala from

```scala
outputs.map(_.toOutput.asInstanceOf[Output[T]])
```

to

```scala
outputs.map(o => if (o == null) null else o.toOutput.asInstanceOf[Output[T]])
```

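A sketch of the null-safe conversion, again using hypothetical stand-ins (OutputLike, Output, NullSafeGradients) rather than the real tensorflow_scala types. convertPatched mirrors the patched line; convertWithOption is an equivalent formulation using Option(...).orNull. Both keep null entries in place instead of dropping them.

```scala
// Hypothetical stand-ins, NOT the real tensorflow_scala types.
trait OutputLike { def toOutput: Output }
final case class Output(name: String) extends OutputLike {
  def toOutput: Output = this
}

object NullSafeGradients {
  // Same logic as the patched line: skip conversion for null gradients and
  // keep them as null so positions still line up with the op's outputs.
  def convertPatched(grads: Seq[OutputLike]): Seq[Output] =
    grads.map(o => if (o == null) null else o.toOutput.asInstanceOf[Output])

  // Equivalent formulation: Option(null) is None, and orNull maps None back
  // to null, so only non-null gradients are converted.
  def convertWithOption(grads: Seq[OutputLike]): Seq[Output] =
    grads.map(o => Option(o).map(_.toOutput).orNull)

  def main(args: Array[String]): Unit = {
    val grads: Seq[OutputLike] = Seq(Output("grad_0"), null, null)
    println(convertPatched(grads))                             // List(Output(grad_0), null, null)
    println(convertWithOption(grads) == convertPatched(grads)) // true
  }
}
```

Keeping the nulls rather than filtering them out presumably matters because a GradientFn matches each gradient to an op output by position; filtering would shift every later gradient to the wrong output.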
eaplatanios commented 5 years ago

@Spiess Thanks a lot for catching this! I'll push a fix and include it in the next release. :)