tensorflow / hub

A library for transfer learning by reusing parts of TensorFlow models.
https://tensorflow.org/hub
Apache License 2.0

Re-training a Hub module as a Siamese network #134

Closed: catalla closed this issue 6 years ago

catalla commented 6 years ago


When re-training a Hub module (Universal Sentence Encoder) as a Siamese network, are gradients guaranteed to propagate correctly even though the module is applied twice, as shown below?

We cannot specify 'reuse=True', so it is unclear whether applying the module a second time will overwrite the state left by the first application.

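import tensorflow as tf  # TF 1.x graph mode (hub.Module is the TF1 Hub API)
import tensorflow_hub as hub
query_1 = tf.placeholder(tf.string, shape=[None])
query_2 = tf.placeholder(tf.string, shape=[None])
m = hub.Module("https://tfhub.dev/google/universal-sentence-encoder/2", trainable=True)
embedding1 = m(query_1)  # both calls use the variables created in the constructor
embedding2 = m(query_2)

loss = tf.reduce_mean(tf.square(embedding1 - embedding2))  # Some scalar loss function here
train_op = tf.train.GradientDescentOptimizer(learning_rate=0.01).minimize(loss)  # arbitrary learning rate
with tf.Session() as sess:
    sess.run([tf.global_variables_initializer(), tf.tables_initializer()])
    sess.run(train_op, feed_dict={query_1: ["a query"], query_2: ["another query"]})  # example inputs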
andresusanopinto commented 6 years ago

Hi @catalla,

Yes, the gradients will propagate through both calls into the same variables: a module's variables are created in its constructor and are shared across every call().

From the "Applying a module" section of the user guide:

"""If this involves variables with trained weights, these are shared between all applications."""
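To make the sharing argument concrete without depending on TensorFlow, here is a framework-free sketch in plain Python. The names `SharedScale`, `siamese_loss_and_grad` and `numerical_grad` are hypothetical and not part of TensorFlow Hub. A one-weight "module" creates its weight once in the constructor, both Siamese branches read that same weight, and so the gradient with respect to it is the sum of the two branches' chain-rule contributions. A finite-difference check confirms the sum.

```python
class SharedScale:
    """Hypothetical stand-in for a module: its single weight is created once,
    in the constructor, and every call reads that same weight."""

    def __init__(self, w):
        self.w = w  # created once, like the variables of a hub.Module

    def __call__(self, x):
        return self.w * x  # a call only reads the weight; it never resets it


def siamese_loss_and_grad(module, x1, x2):
    """Loss (f(x1) - f(x2))**2 and its gradient w.r.t. the shared weight."""
    e1 = module(x1)
    e2 = module(x2)
    diff = e1 - e2
    loss = diff ** 2
    # Chain rule per branch: d(loss)/d(e1) * d(e1)/dw and d(loss)/d(e2) * d(e2)/dw.
    grad_from_branch1 = 2 * diff * x1
    grad_from_branch2 = -2 * diff * x2
    # Both branches read the same weight, so their contributions add up.
    return loss, grad_from_branch1 + grad_from_branch2


def numerical_grad(module, x1, x2, eps=1e-6):
    """Central finite-difference estimate of d(loss)/dw, used as a check."""
    w0 = module.w
    module.w = w0 + eps
    loss_plus, _ = siamese_loss_and_grad(module, x1, x2)
    module.w = w0 - eps
    loss_minus, _ = siamese_loss_and_grad(module, x1, x2)
    module.w = w0  # restore the original weight
    return (loss_plus - loss_minus) / (2 * eps)


m = SharedScale(w=0.5)
loss, grad = siamese_loss_and_grad(m, 3.0, 1.0)  # loss == 1.0, grad == 4.0
print(loss, grad, numerical_grad(m, 3.0, 1.0))
```

With w = 0.5, x1 = 3 and x2 = 1 the loss is w^2 (x1 - x2)^2 = 4w^2, whose derivative 8w = 4 equals the summed branch gradients (6 from the first branch, -2 from the second). Neither call erases what the other contributed.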

Let us know if you have ideas on how to improve that documentation to make it clearer.

catalla commented 6 years ago

"""If this involves variables with trained weights, these are shared between all applications."""

This makes the behavior clear when applying a module with trained weights. However, it would help to mention that the same applies to gradients when re-training a Hub module with trainable weights.

Thanks for the response!

Killthebug commented 5 years ago

I'm trying to replicate the method proposed by @catalla, but when I try to train the model I get the error: ValueError: No gradients provided for any variable, check your graph for ops that do not support gradients... I've set trainable=True. Is there anything I'm missing or should account for? (I've always developed in PyTorch and am just experimenting with TensorFlow, so please pardon any novice mistakes on my end.)