eaplatanios / tensorflow_scala

TensorFlow API for the Scala Programming Language
http://platanios.org/tensorflow_scala/
Apache License 2.0

Make the tensor representation of scalar shapes consistent with upstream #18

Closed: sbrunk closed this issue 6 years ago

sbrunk commented 6 years ago

Trying the random distribution ops with default arguments gave me the following error:

scala> tf.randomNormal()
java.lang.IllegalArgumentException: Shape must be rank 1 but is rank 0 for 'RandomNormal_4/RandomNormal' (op: 'RandomStandardNormal') with input shapes: [].
  at org.platanios.tensorflow.jni.Op$.finish(Native Method)
  at org.platanios.tensorflow.api.ops.Op$Builder.$anonfun$build$1(Op.scala:1024)
  at org.platanios.tensorflow.api.package$.using(package.scala:58)
  at org.platanios.tensorflow.api.ops.Op$Builder.build(Op.scala:1001)
  at org.platanios.tensorflow.api.ops.Random.$anonfun$randomNormal$1(Random.scala:109)
  at scala.util.DynamicVariable.withValue(DynamicVariable.scala:58)
  at org.platanios.tensorflow.api.ops.Op$.createWithNameScope(Op.scala:709)
  at org.platanios.tensorflow.api.ops.Random.randomNormal(Random.scala:101)
  at org.platanios.tensorflow.api.ops.Random.randomNormal$(Random.scala:96)
  at org.platanios.tensorflow.api.package$API$.randomNormal(package.scala:80)
  ... 39 elided

This also happens with other ops like tf.fill.
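To make the failure mode concrete, here is a minimal plain-Scala sketch of the kind of check the native op appears to apply to its shape argument, based only on the error message above. `ShapeRankCheck`, `DenseTensor` and `validateShapeInput` are hypothetical names for illustration; they are not tensorflow_scala or TensorFlow APIs. A rank-0 tensor is rejected, while a zero-length 1-D tensor passes.

```scala
object ShapeRankCheck {
  // Hypothetical dense tensor: its own shape plus its elements in flat order.
  // Illustration only, not tensorflow_scala's Tensor type.
  final case class DenseTensor(shape: Seq[Int], values: Seq[Long]) {
    // The number of elements must match the shape; a rank-0 tensor
    // (empty shape, product 1) therefore holds exactly one element.
    require(values.size == shape.product, "element count does not match shape")
    def rank: Int = shape.size
  }

  // Mirrors the error message above: a shape argument must be a 1-D tensor.
  def validateShapeInput(t: DenseTensor): Either[String, Seq[Int]] =
    if (t.rank != 1) Left(s"Shape must be rank 1 but is rank ${t.rank}")
    else Right(t.values.map(_.toInt))

  def main(args: Array[String]): Unit = {
    // Rank-0 encoding of a scalar shape; the single element's value is a
    // placeholder and does not affect the rank check.
    val rank0 = DenseTensor(shape = Seq.empty, values = Seq(0L))
    // Zero-length 1-D encoding of a scalar shape, as in the Python API.
    val rank1Empty = DenseTensor(shape = Seq(0), values = Seq.empty)
    println(validateShapeInput(rank0))
    println(validateShapeInput(rank1Empty))
  }
}
```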

Apparently, in the Python API, the tensor representation of a scalar shape is a zero-length 1-D tensor:

>>> import tensorflow as tf   
>>> tf.convert_to_tensor(tf.TensorShape(()))
<tf.Tensor 'shape_as_tensor_8:0' shape=(0,) dtype=int32>
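A minimal sketch of that encoding, assuming a hypothetical `StaticShape` type (not the tensorflow_scala `Shape` class): a shape of rank n always becomes a 1-D tensor of length n, so the scalar shape `()` becomes a tensor of shape `(0,)`, matching the Python output above.

```scala
object ShapeEncoding {
  // Hypothetical dense tensor: its own shape plus its elements in flat order.
  final case class DenseTensor(shape: Seq[Int], values: Seq[Long]) {
    require(values.size == shape.product, "element count does not match shape")
    def rank: Int = shape.size
  }

  // Hypothetical static shape type, for illustration only.
  final case class StaticShape(dims: Seq[Int]) {
    def rank: Int = dims.size

    // Upstream-consistent encoding: always a 1-D tensor whose length equals
    // the rank of this shape, so a scalar shape maps to a zero-length vector.
    def toTensor: DenseTensor = DenseTensor(shape = Seq(rank), values = dims.map(_.toLong))
  }

  def main(args: Array[String]): Unit = {
    println(StaticShape(Seq.empty).toTensor)
    println(StaticShape(Seq(2, 3)).toTensor)
  }
}
```

With this encoding the shape tensor is rank 1 no matter what rank the described shape has, so it passes a "must be rank 1" check even for scalars.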

This is also what the native API expects, so I think it makes sense to do the same, as long as it doesn't cause other issues.