GPflow / check_shapes

Library for annotating and checking tensor shapes.
Apache License 2.0

TensorFlow cannot compile shape-checked method with explicit input signature. #1

Status: Open. jesnie opened this issue 2 years ago

jesnie commented 2 years ago
```python
import tensorflow as tf
from gpflow.experimental.check_shapes import check_shapes

class A:

    def f(self, x):
        return x + 2

    @check_shapes(
        "x: [batch...]",
        "return: [batch...]",
    )
    def g(self, x):
        return x + 2

a = A()
specs = [tf.TensorSpec(shape=None, dtype=tf.int32)]
f = tf.function(a.f)
f2 = tf.function(a.f, input_signature=specs)
g = tf.function(a.g)
g2 = tf.function(a.g, input_signature=specs)
x = tf.constant(7)

f(x)  # Good
f2(x)  # Good
g(x)  # Good
g2(x)  # Breaks...
```

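The thread does not pin down why only `g2` fails. One plausible factor, offered here as an unconfirmed assumption, is introspection: a decorator that returns a generic `(*args, **kwargs)` wrapper keeps the original parameter list only behind `__wrapped__`, and Python's own introspection functions disagree about whether to follow it. This sketch uses a hypothetical `shape_checked` decorator, not the real `check_shapes` implementation, and needs no TensorFlow:

```python
import functools
import inspect


def shape_checked(func):
    """Hypothetical stand-in for a shape-checking decorator (NOT check_shapes itself)."""

    @functools.wraps(func)  # copies __name__/__doc__ and sets __wrapped__ = func
    def wrapper(*args, **kwargs):
        # A real checker would validate argument and return shapes here.
        return func(*args, **kwargs)

    return wrapper


class A:
    @shape_checked
    def g(self, x):
        return x + 2


a = A()

# inspect.signature follows __wrapped__ and drops the bound `self`.
print(inspect.signature(a.g))  # (x)

# inspect.getfullargspec ignores __wrapped__, so it only sees the generic wrapper.
spec = inspect.getfullargspec(a.g)
print(spec.args, spec.varargs, spec.varkw)  # [] args kwargs
```

If a tool matched an `input_signature` against what `getfullargspec` reports, it would find no named parameter for the single `TensorSpec` to bind to. Whether `tf.function` in a given TensorFlow version actually behaves this way is not established in this thread.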
Corwinpro commented 2 years ago

Hi @jesnie,

I am hitting a similar problem. Is it related to the one reported in this issue?

Combining an optional argument with a tf.function wrapper fails:

```python
import tensorflow as tf
from gpflow.experimental.check_shapes import check_shapes

class A:
    @check_shapes(
        "x: [batch...]",
        "return: [batch...]",
    )
    def foo(self, x, opt=None):
        return x + 2

    @check_shapes(
        "x: [batch...]",
        "return: [batch...]",
    )
    def bar(self, x, opt=None):
        return tf.function(self.foo)(x, opt)

    @check_shapes(
        "x: [batch...]",
        "return: [batch...]",
    )
    def foo_2(self, x):
        return x + 2

    @check_shapes(
        "x: [batch...]",
        "return: [batch...]",
    )
    def bar_2(self, x):
        return tf.function(self.foo_2)(x)

a = A()

a.foo_2(2.0)  # OK
a.bar_2(2.0)  # OK

a.foo(2.0)  # OK
a.bar(2.0)  # Not OK
```

I noticed that passing the arguments by keyword fixes it:

```python
    @check_shapes(
        "x: [batch...]",
        "return: [batch...]",
    )
    def bar(self, x, opt=None):
        return tf.function(self.foo)(x=x, opt=opt)  # pass arguments by keyword
```
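The keyword workaround can be generalized so that callers don't have to spell out every argument by hand. A minimal sketch, assuming a hypothetical helper `keyword_arguments` (not part of check_shapes or TensorFlow) and assuming `inspect.signature` can see the method's original parameters through the shape-checking wrapper (it follows `__wrapped__`, which `functools.wraps` sets):

```python
import inspect


def keyword_arguments(func, *args, **kwargs):
    """Hypothetical helper: bind a call's arguments to `func`'s parameter names."""
    bound = inspect.signature(func).bind(*args, **kwargs)
    bound.apply_defaults()  # fill in optional parameters such as opt=None
    return dict(bound.arguments)


class A:
    # Plain method standing in for the shape-checked `foo` from the report.
    def foo(self, x, opt=None):
        return x + 2


a = A()
kwargs = keyword_arguments(a.foo, 2.0)
print(kwargs)  # {'x': 2.0, 'opt': None}
print(a.foo(**kwargs))  # 4.0
```

Inside `bar` this would read `return tf.function(self.foo)(**keyword_arguments(self.foo, x, opt))`. It is untested against TensorFlow 2.4.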
jesnie commented 2 years ago

Which version of TensorFlow are you using? I know some earlier versions struggle with optional parameters: https://github.com/GPflow/GPflow/blob/fda83683483429de5eda996ba2f98c0400b987cf/tests/gpflow/experimental/check_shapes/test_integration.py#L180

Corwinpro commented 2 years ago

It's 2.4. Thanks for pointing this out. Are there any known fixes for that?

jesnie commented 2 years ago

I haven't made the effort to look into it. 🤷