tensorflow / hub

A library for transfer learning by reusing parts of TensorFlow models.
https://tensorflow.org/hub
Apache License 2.0

USEv5 throws "InvalidArgumentError: Value for attr 'T' of string..." with trainable Keras layer #437

Closed: eduardofv closed this issue 4 years ago.

eduardofv commented 5 years ago

I'm getting an InvalidArgumentError when using a trainable KerasLayer with USE-large-v5 or USE-v4. It works fine with other modules, or when the layer is not trainable. I'm using TF 2.0 and hub 0.7.0 in Colab, but I have also seen it in local environments and when not using tf.keras.

See the Colab notebook for the full example.

import tensorflow as tf
import tensorflow_hub as hub

MODEL = "https://tfhub.dev/google/universal-sentence-encoder-large/5"
# Also fails with USE v4:
#MODEL = "https://tfhub.dev/google/universal-sentence-encoder/4"
DIM = 512
IS_TRAINABLE = True  # Works OK if False

# Works fine with NNLM:
#MODEL = "https://tfhub.dev/google/nnlm-en-dim50/2"
#DIM = 50

train_data = ["Hello world", "I'll fix it", "Function instantiation has undefined input shape"]
train_labels = [0, 0, 1]

hub_layer = hub.KerasLayer(MODEL, output_shape=[DIM], input_shape=[], trainable=IS_TRAINABLE, dtype="string")

model = tf.keras.Sequential()
model.add(hub_layer)
# Sigmoid so the output is a probability, as binary_crossentropy expects
model.add(tf.keras.layers.Dense(1, activation="sigmoid"))

model.compile(optimizer="adam", loss="binary_crossentropy")
model.summary()
model.fit(train_data, train_labels, epochs=1)

Output:

Train on 3 samples
3/3 [==============================] - 79s 26s/sample
---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-3-21f3757b499e> in <module>()
----> 1 model.fit(train_data, train_labels, epochs=1)

/usr/local/lib/python3.6/dist-packages/six.py in raise_from(value, from_value)

InvalidArgumentError:  Value for attr 'T' of string is not in the list of allowed values: float, double, int32, uint8, int16, int8, complex64, int64, qint8, quint8, qint32, bfloat16, uint16, complex128, half, uint32, uint64, variant
    ; NodeDef: {{node Func/_3650}}; Op<name=AddN; signature=inputs:N*T -> sum:T; attr=N:int,min=1; attr=T:type,allowed=[DT_FLOAT, DT_DOUBLE, DT_INT32, DT_UINT8, DT_INT16, ..., DT_COMPLEX128, DT_HALF, DT_UINT32, DT_UINT64, DT_VARIANT]; is_commutative=true; is_aggregate=true>
     [[Func/_3650]]
     [[PartitionedCall/gradients/StatefulPartitionedCall_grad/PartitionedCall/gradients/StatefulPartitionedCall_grad/SymbolicGradient]] [Op:__inference_distributed_function_202975]

Function call stack:
distributed_function
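The message says the AddN op, which TensorFlow uses to sum (aggregate) several gradient contributions into one, was handed a string-typed value. The failing node sits under `gradients/.../SymbolicGradient`, so the error comes from the backward pass, which is only built when the layer is trainable. That matches the report that `trainable=False` works. Presumably the string value is tied to the layer's string input, but that is our reading and is not confirmed in the thread. The pure-Python sketch below only mimics the attr check, using the allowed-type list copied from the message; the names `ADDN_ALLOWED_DTYPES`, `check_addn_dtype` and `can_sum_gradients` are hypothetical and this is not TensorFlow's implementation.

```python
# Allowed values for attr 'T' of AddN, copied in order from the error message.
ADDN_ALLOWED_DTYPES = (
    "float", "double", "int32", "uint8", "int16", "int8", "complex64",
    "int64", "qint8", "quint8", "qint32", "bfloat16", "uint16",
    "complex128", "half", "uint32", "uint64", "variant",
)


def check_addn_dtype(dtype_name: str) -> None:
    """Hypothetical stand-in for the attr check that raised the error.

    Raises ValueError when dtype_name is not an allowed AddN type.
    """
    if dtype_name not in ADDN_ALLOWED_DTYPES:
        allowed = ", ".join(ADDN_ALLOWED_DTYPES)
        raise ValueError(
            f"Value for attr 'T' of {dtype_name} is not in the list of "
            f"allowed values: {allowed}"
        )


def can_sum_gradients(dtype_names) -> bool:
    """Return True if every dtype in the list could be summed by AddN."""
    return all(name in ADDN_ALLOWED_DTYPES for name in dtype_names)


if __name__ == "__main__":
    print(can_sum_gradients(["float", "float"]))   # numeric gradients: fine
    print(can_sum_gradients(["float", "string"]))  # a string slips in: rejected
```

Note that `string` is simply absent from the list, so no amount of model-side tuning makes AddN accept it; the fix has to come from how the gradient graph is built.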
andriikrupka commented 4 years ago

Try tensorflow==2.1.0rc0. It seems this is already fixed in the newest release.

I've checked with the latest TF and everything works fine. I also reproduced your issue with tf==2.0.0.

See the Colab notebook.
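Since the thread reports the failure on 2.0.0 and no failure from 2.1.0rc0 onward, one option is to gate `trainable=True` on the installed version. The sketch below is a hypothetical helper, assuming the fix first shipped in 2.1.0rc0 as the commenter reports. It compares version strings such as `"2.0.0"` and `"2.1.0rc0"`, ranking alpha/beta/rc pre-releases below the final release of the same number. Only the `a`, `b` and `rc` suffixes are handled; `parse_tf_version` and `trainable_use_supported` are illustrative names, not TensorFlow or TF Hub APIs.

```python
import re

# Pre-release tags sort below the final release: a < b < rc < final.
PRE_RANK = {"a": 0, "b": 1, "rc": 2}
FINAL_RANK = 3


def parse_tf_version(version: str) -> tuple:
    """Turn '2.0.0' or '2.1.0rc0' into a tuple that compares correctly."""
    m = re.match(r"^(\d+)\.(\d+)\.(\d+)(?:(a|b|rc)(\d+))?", version)
    if m is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    major, minor, patch = (int(m.group(i)) for i in (1, 2, 3))
    if m.group(4):
        return (major, minor, patch, PRE_RANK[m.group(4)], int(m.group(5)))
    return (major, minor, patch, FINAL_RANK, 0)


# Assumption taken from the thread: 2.1.0rc0 is the first release without the bug.
FIRST_FIXED = parse_tf_version("2.1.0rc0")


def trainable_use_supported(version: str) -> bool:
    """True if fine-tuning USE through a trainable KerasLayer should work."""
    return parse_tf_version(version) >= FIRST_FIXED


if __name__ == "__main__":
    for v in ("2.0.0", "2.1.0rc0", "2.1.0"):
        print(v, trainable_use_supported(v))
```

In the repro, this could feed `IS_TRAINABLE = trainable_use_supported(tf.__version__)` before building `hub.KerasLayer(..., trainable=IS_TRAINABLE)`, falling back to a frozen encoder on TF 2.0. If the `packaging` library is available, `packaging.version.Version` handles a wider range of version formats than this hand-rolled parser.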

eduardofv commented 4 years ago

Great! Thanks!