tensorflow / graphics

TensorFlow Graphics: Differentiable Graphics Layers for TensorFlow
Apache License 2.0

`tensorflow_graphics.shape.check_static(...)` throwing error in TF2 #625

Status: Open. m4ttr4ymond opened this issue 3 years ago.

m4ttr4ymond commented 3 years ago

I was writing a data augmentation layer for a PointNet implementation and ran into what appears to be a bug in `tensorflow_graphics.util.shape.check_static(...)`, which `from_euler` calls (see the traceback below).

Offending layer:

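import tensorflow as tf
from tensorflow.keras.layers import Layer
from tensorflow_graphics.geometry.transformation.rotation_matrix_3d import from_euler

class RandomRot(Layer):
  """Randomly rotates a batch of 3D point clouds during training."""

  def __init__(self, **kwargs):
    super(RandomRot, self).__init__(**kwargs)

  def build(self, input_shape):
    # Shape of the random Euler-angle vector: one angle per coordinate axis.
    self.s = tf.constant([input_shape[-1],])

  def call(self, inputs, training=None):
    if not training:
      return inputs

    # Random Euler angles in [0, 6.28), i.e. roughly [0, 2*pi).
    r = tf.random.uniform(
      shape=self.s,
      minval=0,
      maxval=6.28,
    )

    # from_euler(r) is a single [3, 3] matrix applied to every point.
    return tf.linalg.matmul(inputs, from_euler(r))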

Error message:

AttributeError: in user code:

    <ipython-input-135-d11754641da6>:81 call  *
        self.x = self.r(self.x,training)
    <ipython-input-130-07bfe7ac5ab9>:25 call  *
        return tf.linalg.matmul(inputs,from_euler(r))
    /usr/local/lib/python3.6/dist-packages/tensorflow_graphics/geometry/transformation/rotation_matrix_3d.py:201 from_euler  *
        shape.check_static(
    /usr/local/lib/python3.6/dist-packages/tensorflow_graphics/util/shape.py:206 check_static  *
        if _get_dim(tensor, axis) != value:
    /usr/local/lib/python3.6/dist-packages/tensorflow_graphics/util/shape.py:135 _get_dim  *
        return tensor.shape[axis].value

    AttributeError: 'int' object has no attribute 'value'

It appears that `check_static` expects each element of `.shape` to be a TF1-style `Dimension` object with a `.value` attribute, but in TF2 `tensor.shape[axis]` returns a plain `int` (or `None`). If I comment out the `check_static` call in `from_euler`, the function works fine. Strangely enough, it seems to work for tensors in eager execution and only throws the error when using `Dataset` objects with graph compilation.
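The traceback shows `_get_dim` reading `tensor.shape[axis].value`, an attribute that only the TF1-style `tf.compat.v1.Dimension` has. TensorFlow provides `tf.compat.dimension_value`, which accepts either a `Dimension` or a plain `int`/`None`, so a `_get_dim` written that way should work under both shape behaviors. The sketch below shows this inside a `tf.function`; `get_dim` and `inspect_shape` are hypothetical names, and this is a suggested pattern, not a confirmed upstream fix.

```python
import tensorflow as tf


def get_dim(tensor, axis):
    """Returns the static size of `axis` as an int, or None if unknown."""
    # tf.compat.dimension_value accepts both a TF1-style tf.compat.v1.Dimension
    # (which has `.value`) and the plain int/None that TF2 returns.
    return tf.compat.dimension_value(tensor.shape[axis])


seen = {}


@tf.function(input_signature=[tf.TensorSpec(shape=[None, 3], dtype=tf.float32)])
def inspect_shape(points):
    # This Python body runs once, at trace time, on a symbolic tensor.
    seen["raw"] = points.shape[-1]      # plain int in TF2, so `.value` would fail
    seen["last"] = get_dim(points, -1)  # 3
    seen["batch"] = get_dim(points, 0)  # None: the batch size is unknown
    return points


inspect_shape(tf.zeros([2, 3]))
legacy = tf.compat.dimension_value(tf.compat.v1.Dimension(3))  # TF1-style input
print(type(seen["raw"]).__name__, seen["last"], seen["batch"], legacy)
# prints: int 3 None 3
```

Because `dimension_value` passes plain ints through unchanged, the same helper keeps working if code still produces `Dimension` objects somewhere.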

NicolayP commented 2 years ago

Any update on this? I get the same error. Here is a minimal example that triggers it:

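import tensorflow as tf
import tensorflow_graphics as tfg

quat = tf.constant([[0., 0., 0., 1.]], dtype=tf.float64)

# Eager call.
euler = tfg.geometry.transformation.euler.from_quaternion(quat)
print(euler)

# Same call inside tf.function: this is where the error is raised.
@tf.function
def rot(quat):
    euler = tfg.geometry.transformation.euler.from_quaternion(quat)
    return euler
print(rot(quat))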

This triggers the same error that @m4ttr4ymond reported.
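If patching the installed library is not an option, one workaround is to avoid `from_euler` in the graph-mode path and build the rotation matrix from plain TensorFlow ops, so `check_static` is never called. The sketch below is not part of TensorFlow Graphics: `euler_to_rotation_matrix` and `random_rotate` are hypothetical names, and the `Rz @ Ry @ Rx` composition order is an assumption intended to match `from_euler`'s convention; compare the two on a few angles before substituting one for the other.

```python
import math

import tensorflow as tf


def euler_to_rotation_matrix(angles):
    """Sketch of R = Rz(z) @ Ry(y) @ Rx(x) for angles [..., 3] = (x, y, z) in radians.

    Returns a [..., 3, 3] tensor. Only plain TF ops are used, so there is no
    static-shape check that could hit the `.value` error.
    """
    x, y, z = tf.unstack(angles, num=3, axis=-1)
    zero, one = tf.zeros_like(x), tf.ones_like(x)

    def mat(rows):
        # rows is a 3x3 nested list of [...]-shaped tensors -> [..., 3, 3].
        return tf.stack([tf.stack(row, axis=-1) for row in rows], axis=-2)

    rx = mat([[one, zero, zero],
              [zero, tf.cos(x), -tf.sin(x)],
              [zero, tf.sin(x), tf.cos(x)]])
    ry = mat([[tf.cos(y), zero, tf.sin(y)],
              [zero, one, zero],
              [-tf.sin(y), zero, tf.cos(y)]])
    rz = mat([[tf.cos(z), -tf.sin(z), zero],
              [tf.sin(z), tf.cos(z), zero],
              [zero, zero, one]])
    return tf.linalg.matmul(rz, tf.linalg.matmul(ry, rx))


@tf.function(input_signature=[tf.TensorSpec([None, None, 3], tf.float32)])
def random_rotate(points):
    """Applies one random rotation to every point, mirroring RandomRot.call."""
    angles = tf.random.uniform([3], minval=0.0, maxval=2.0 * math.pi)
    return tf.linalg.matmul(points, euler_to_rotation_matrix(angles))


rotated = random_rotate(tf.random.normal([2, 5, 3]))  # traced in graph mode
print(rotated.shape)  # (2, 5, 3)
```

Under the same assumption, the `from_euler(r)` call in `RandomRot.call` could be swapped for `euler_to_rotation_matrix(r)`; using `2.0 * math.pi` instead of `6.28` only makes the angle range exact.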