caicloud / tensorflow-tutorial

Example TensorFlow codes and Caicloud TensorFlow as a Service dev environment.
2.93k stars 2.08k forks source link

第6章迁移学习疑问:怎样获取瓶颈层张量的名称? #18

Closed j081mm closed 7 years ago

j081mm commented 7 years ago

我想用自己的模型做迁移学习,如何能获得瓶颈层张量的名称? 我在我的原始代码训练模型过程中,加入 Tensor_name=tf.Tensor.name() print"Tensor_name:",Tensor_name 报错: Tensor_name=tf.Tensor.name() TypeError: 'property' object is not callable

原文代码:

Inception-v3模型中代表瓶颈层结果的张量名称。

在谷歌提出的Inception-v3模型中,这个张量名称就是'pool_3/_reshape:0'。

在训练模型时,可以通过tensor.name来获取张量的名称。

BOTTLENECK_TENSOR_NAME = 'pool_3/_reshape:0'

图像输入张量所对应的名称。

JPEG_DATA_TENSOR_NAME = 'DecodeJpeg/contents:0'`

perhapszzy commented 7 years ago

比如a是一个张量,那么a.name就可以得到他的名字

j081mm commented 7 years ago

多谢作者的指点。还有一个问题,我想将我的训练网络模型保存为.pb文件,迁移到新的图片库中使用。第5章(115页)中,介绍了保存一个运算的方法(add),如果我要保存整个网络,应该怎么进行设置呢?

with tf.Session() as sess:
       sess.run(init_op)
       graph_def=tf.get_default_graph().as_graph_def()
       output_graph_def=graph_util.convert_variables_to_constants(sess,graph_def,[?])
       with tf.gfile.GFile("/../../.../network.pb","wb") as f:
               f.write(output_graph_def.SerializeToString())
perhapszzy commented 7 years ago

你为啥要保存整个网络?你应该只是关心部分变量吧?如果你真的关心所有变量,你可以tf.all_variables来获取全部变量列表:https://www.tensorflow.org/versions/r0.10/api_docs/python/state_ops/variable_helper_functions#all_variables

j081mm commented 7 years ago

老师,书本161页,将Inception-v3训练好的网络用作对5种花的类别进行分类。我想换成自己已经在一个数据集上训练好的模型,迁移到对5种花的类别进行分类。需要改变的地方应该是:

BOTTLENECK_TENSOR_NAME='pool_3/_reshape:0'
JPEG_DATA_TENSOR_NAME='DecodeJpeg/contents:0'
MODEL_DIR=‘/path/to/model’
MODEL_FILE='classify_image_graph_def.pb'

BOTTLENECK_TENSOR_NAME 您已经在上面说过,JPEG_DATA_TENSOR_NAME因为我还是针对5种花的类别进行分类所以张量名称应该不变,MODEL_DIR应该是我训练网络模型和参数存放文件夹位置,MODEL_FILE 是我的训练网络模型和参数文件(.pb)。 目前MODEL_FILE这个文件还不会生成。不知道是不是我理解有错,如果只保留训练好的参数,再作迁移。是不是又得重新构架类似的网络,参数才会有意义?

perhapszzy commented 7 years ago

MODEL_DIR 是预先训练好的模型,在这个目录下要求有文件预先训练好的文件classify_image_graph_def.pb。这个不是新的模型文件,而是已经训练好的模型。

j081mm commented 7 years ago

对,我就是想预先训练一个classify_image_graph_def.pb文件。 但是不知道在源代码上用哪些函数,能保存下来classify_image_graph_def.pb

perhapszzy commented 7 years ago

这个文件是google训练导出的,书上有介绍如何下载这个文件