keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

Layer does not have a default tensorflow namespace; requires passing `initial_value` in lambda function #18461

Open grasskin opened 1 year ago

grasskin commented 1 year ago

I ran into this with the swin_transformer.py example, which creates a custom layer that uses NumPy ops and a TensorFlow variable to store state.

self.relative_position_index = tf.Variable(
    initial_value=tf.convert_to_tensor(relative_position_index), trainable=False
)

Although the variable is created inside the build() method, Keras does not seem to set up a TensorFlow namespace by default, so the variable cannot be created. The workaround is to pass initial_value as a lambda function:

self.relative_position_index = tf.Variable(
    initial_value=lambda: tf.convert_to_tensor(relative_position_index), trainable=False
)

Full error log:

Could not automatically infer the output shape / dtype of 'swin_transformer' (of type SwinTransformer). Either the `SwinTransformer.call()` method is incorrect, or you need to implement the `SwinTransformer.compute_output_spec()` method. Error encountered:

Argument `initial_value` (Tensor("window_attention/Const:0", shape=(4, 4), dtype=int64)) could not be lifted out of a `tf.function`. (Tried to create variable with name='None'). To avoid this error, when constructing `tf.Variable`s inside of `tf.function` you can create the `initial_value` tensor in a `tf.init_scope` or pass a callable `initial_value` (e.g., `tf.Variable(lambda : tf.truncated_normal([10, 40]))`). Please file a feature request if this restriction inconveniences you.

Arguments received by SwinTransformer.call():
  • args=('<KerasTensor shape=(None, 256, 64), dtype=float32, name=keras_tensor_6>',)
  • kwargs=<class 'inspect._empty'>

mehtamansi29 commented 3 weeks ago

Hi @grasskin -

I was able to run the swin_transformer.py example successfully in Keras 3.5.0 by creating the variable in the custom layer like this:

self.relative_position_index = keras.Variable(
    initializer=relative_position_index,
    shape=relative_position_index.shape,
    dtype="int",
    trainable=False,
)

A gist is attached for reference.