tensorflow / tensor2tensor

Library of deep learning models and datasets designed to make deep learning more accessible and accelerate ML research.
Apache License 2.0

Building model body multiple times when calling model_fn() multiple times #1778

Open lemmonation opened 4 years ago

lemmonation commented 4 years ago

Description

I am building a meta-learning framework on top of tensor2tensor, which requires calling model_fn() multiple times. I find that the framework builds the model body on every call, even when the reuse=True flag is set. The pseudocode is as follows:

def model_fn_raw(self, feature):  # the original model_fn
    ...
    tf.print("scope:", tf.get_variable_scope(), "name:", tf.get_variable_scope().name)
    log_info("Building model body")
    body_out = self.body(feature)
    ...

def model_fn(self, feature):  # wraps the original model_fn with the meta-learning part
    # Step 1: call model_fn_raw() the first time to compute the loss.
    with tf.variable_scope(tf.get_variable_scope()):
        _, loss = model_fn_raw(self, feature)
    # Step 2: meta-learning part: compute updated params and assign them to all variables.
    updated_para = updated_para_once(loss)
    # tf.assign takes one variable at a time; this assumes updated_para is a
    # list ordered like tf.trainable_variables().
    assign_para_ops = [
        tf.assign(var, new_value)
        for var, new_value in zip(tf.trainable_variables(), updated_para)
    ]
    with tf.control_dependencies(assign_para_ops):
        # Step 3: call model_fn_raw() a second time to compute the loss with the new params,
        # this time with the reuse flag set to True.
        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            logits, loss = model_fn_raw(self, feature)
    restore_origin_params()
    return logits, loss

With this setup, I expect the model body to be built only once when executing model_fn(), because I set reuse=True for the second call. However, the model body is still built twice when I run the code, with the following logs printed:

INFO:tensorflow:Transforming feature 'targets' with symbol_modality_10152_256.targets_bottom
INFO:tensorflow:Building model body
:::MLPv0.5.0 transformer ...
...(Other logs of model components)
INFO:tensorflow:Transforming body output with symbol_modality_10152_256.top
INFO:tensorflow:Transforming feature 'inputs' with symbol_modality_10152_256.bottom
# (This should be the end, but the same logs are printed again.)
INFO:tensorflow:Transforming feature 'targets' with symbol_modality_10152_256.targets_bottom
INFO:tensorflow:Building model body
:::MLPv0.5.0 transformer ...
...
INFO:tensorflow:Transforming body output with symbol_modality_10152_256.top
INFO:tensorflow:Transforming feature 'inputs' with symbol_modality_10152_256.bottom
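
A note on the logs above: in TensorFlow 1.x, reuse=True only changes how tf.get_variable resolves names (it returns the existing variable instead of creating a new one). It does not stop the Python code in model_fn_raw() from running again, so each call adds new ops to the graph and prints the build logs again. The sketch below is a toy pure-Python model of that behavior, not TensorFlow code; ToyVariableStore and build_body are hypothetical names introduced only for illustration.

```python
# Toy stand-in for a name-keyed variable store. This is NOT TensorFlow code;
# it only mimics the "share variables by name" idea behind reuse=True.
class ToyVariableStore:
    def __init__(self):
        self.variables = {}   # full variable name -> value
        self.build_log = []   # one entry per body construction

    def get_variable(self, name, initial_value, reuse):
        if reuse:
            # reuse=True: the variable must already exist; return it.
            if name not in self.variables:
                raise ValueError("Variable %s does not exist" % name)
            return self.variables[name]
        # reuse=False: the variable must not exist yet; create it.
        if name in self.variables:
            raise ValueError("Variable %s already exists" % name)
        self.variables[name] = initial_value
        return initial_value


def build_body(store, reuse):
    # This Python code runs on every call, regardless of `reuse`,
    # so the log line is emitted each time.
    store.build_log.append("Building model body")
    return store.get_variable("transformer/body/w", [0.0], reuse=reuse)


store = ToyVariableStore()
first = build_body(store, reuse=False)   # creates the variable
second = build_body(store, reuse=True)   # looks up the same variable

print(len(store.build_log))    # number of times the body was built
print(len(store.variables))    # number of distinct variables
print(first is second)         # whether both calls share one variable
```

In this toy model the body is built twice while only one variable exists, which matches the pattern in the logs: repeated "Building model body" lines do not by themselves indicate that the weights were duplicated.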

I also printed the variable scope inside model_fn_raw(). The scope name is the same in the first and second calls, but the object address is not. For example:

# scope address and name printed in the first call
scope: <tensorflow.python.ops.variable_scope.VariableScope object at 0x7fcb13928ba8>
name: transformer/body
# in the second call
scope: <tensorflow.python.ops.variable_scope.VariableScope object at 0x7fcb137747f0> 
name: transformer/body

The name is consistent but the address changes, which means that the variable scope object is not the same one in the first and second calls.
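
A differing address is expected here: as far as I know, each entry into tf.variable_scope yields a fresh VariableScope Python object, and variable sharing is keyed by the variable's full name rather than by the identity of the scope object. The following is a minimal pure-Python sketch of that idea under those assumptions. ToyScope, toy_variable_scope, and toy_get_variable are hypothetical names and do not reproduce TensorFlow's implementation.

```python
import contextlib


class ToyScope:
    """Hypothetical scope object: just a name and a reuse flag."""

    def __init__(self, name, reuse):
        self.name = name
        self.reuse = reuse


@contextlib.contextmanager
def toy_variable_scope(name, reuse=False):
    # A new scope object is created on every entry, so its address differs
    # between calls even when the name is identical.
    yield ToyScope(name, reuse)


variables = {}  # full variable name -> value


def toy_get_variable(scope, var_name, initial_value):
    # Lookup is keyed by the full name, not by the scope object's identity.
    full_name = scope.name + "/" + var_name
    if scope.reuse:
        return variables[full_name]
    variables[full_name] = initial_value
    return initial_value


with toy_variable_scope("transformer/body") as s1:
    v1 = toy_get_variable(s1, "w", [1.0])

with toy_variable_scope("transformer/body", reuse=True) as s2:
    v2 = toy_get_variable(s2, "w", [1.0])

scopes_are_same_object = s1 is s2   # different objects, different addresses
names_match = s1.name == s2.name    # identical scope names
variables_shared = v1 is v2         # the same underlying variable
print(scopes_are_same_object, names_match, variables_shared)
```

Under this model, two different scope objects with the same name still resolve to the same variable, so comparing scope addresses does not show whether the parameters are shared. Comparing the variable names (for example, the length of tf.trainable_variables() before and after the second call) would be a more direct check.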

I am opening this issue in the hope that someone can help me with this problem. How can I construct the model body only once and reuse it across multiple calls? Do the assign ops in step 2 cause the repeated construction?

Environment information

OS: CentOS 7.5

$ pip freeze | grep tensor
tensor2tensor==1.10.0
mesh-tensorflow==0.1.7
tensorboard==1.12.2
tensorflow-gpu==1.12.0

$ python -V
Python 3.6.2

ArtemisZGL commented 4 years ago

Hello, do you have any idea how to use MAML in tensor2tensor?