jesnie opened this issue 2 years ago
Hi @jesnie,
I am hitting a similar problem. Is it related to the one reported in this issue?
Calling a method that has an optional argument through a `tf.function` wrapper fails:
```python
import tensorflow as tf
from gpflow.experimental.check_shapes import check_shapes


class A:
    @check_shapes(
        "x: [batch...]",
        "return: [batch...]",
    )
    def foo(self, x, opt=None):
        return x + 2

    @check_shapes(
        "x: [batch...]",
        "return: [batch...]",
    )
    def bar(self, x, opt=None):
        # Positional call into the tf.function-wrapped method with an optional argument.
        return tf.function(self.foo)(x, opt)

    @check_shapes(
        "x: [batch...]",
        "return: [batch...]",
    )
    def foo_2(self, x):
        return x + 2

    @check_shapes(
        "x: [batch...]",
        "return: [batch...]",
    )
    def bar_2(self, x):
        # Same pattern, but without an optional argument.
        return tf.function(self.foo_2)(x)


a = A()
a.foo_2(2.0)  # OK
a.bar_2(2.0)  # OK
a.foo(2.0)  # OK
a.bar(2.0)  # Not OK
```
I noticed that changing `bar` to this fixes it:
```python
@check_shapes(
    "x: [batch...]",
    "return: [batch...]",
)
def bar(self, x, opt=None):
    return tf.function(self.foo)(x=x, opt=opt)  # pass the arguments as keywords
```
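The following is a minimal sketch of applying the keyword-argument workaround generically, assuming the wrapped callable exposes its original parameters to `inspect.signature`. It maps positional arguments to parameter names with `inspect.signature(...).bind` and then calls the function with keywords only. Both `call_with_keywords` and `keywords_only` are hypothetical helpers, not GPflow or TensorFlow APIs. `keywords_only` is a plain-Python stand-in for a wrapper that mishandles positional arguments, so the sketch runs without TensorFlow and does not reproduce the actual TF 2.4 failure.

```python
import functools
import inspect
from typing import Any, Callable


def call_with_keywords(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    """Call fn with every argument passed by keyword.

    Hypothetical helper: positional arguments are matched to parameter names
    via fn's signature, so call_with_keywords(f, x, opt) behaves like
    f(x=x, opt=opt). Assumes fn has no positional-only, *args or **kwargs
    parameters.
    """
    bound = inspect.signature(fn).bind(*args, **kwargs)
    # Only explicitly supplied arguments are forwarded; defaults stay defaults.
    return fn(**bound.arguments)


def keywords_only(f: Callable[..., Any]) -> Callable[..., Any]:
    """Hypothetical stand-in for a wrapper that breaks on positional arguments."""

    @functools.wraps(f)  # sets __wrapped__, so inspect.signature reports f's parameters
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        if args:
            raise TypeError("positional arguments are not supported")
        return f(**kwargs)

    return wrapper


def foo(x, opt=None):
    return x + 2


strict_foo = keywords_only(foo)
# strict_foo(2.0) would raise TypeError; binding to names first avoids that.
print(call_with_keywords(strict_foo, 2.0))  # 4.0
print(call_with_keywords(strict_foo, 2.0, None))  # 4.0
```

This sketch has not been checked against the real decorators: `inspect.signature` may not see through `tf.function` and `check_shapes` to `(x, opt=None)`. If it does not, the explicit `tf.function(self.foo)(x=x, opt=opt)` call remains the safer choice.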
Which version of TensorFlow are you using? I know some earlier versions struggle with optional parameters: https://github.com/GPflow/GPflow/blob/fda83683483429de5eda996ba2f98c0400b987cf/tests/gpflow/experimental/check_shapes/test_integration.py#L180
2.4 it is. Thanks for pointing this out. Are there any known fixes for that?
I haven't made the effort to look into it. :shrug: