Open ZefengHan opened 5 years ago
关于测试文件的编写还是存在疑问,根据您给的博客链接我只会把checkpoint加载进来,后面的传参还是不太会,該传进那些参数不是很清楚,根据博客链接里面,没看见测试集如何加载进来的,希望您能够给予解答。 import tensorflow as tf import numpy as np
graph = tf.Graph() with graph.as_default():
session_conf = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False) session_conf.gpu_options.allow_growth=True session_conf.gpu_options.per_process_gpu_memory_fraction = 0.9 # 配置gpu占用率 sess = tf.Session(config=session_conf) with sess.as_default(): checkpoint_file = tf.train.latest_checkpoint("C:/Users/韩泽峰/Desktop/textClassifier-master/model/transformer3classfier/") saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file)) saver.restore(sess, checkpoint_file)
inputX = graph.get_operation_by_name("inputX") dropoutKeepProb = graph.get_operation_by_name("dropoutKeepProb") embeddedPosition = graph.get_operation_by_name("embeddedPosition") # 获得输出的结果 pred_all = graph.get_operation_by_name("inputY")
关于transformer的测试文件编写还是存在困惑,望作者老师能给解答一下。
你好,已经在博客里面回复你了
谢谢您啦!
博客地址能给学习一下嘛 我翻完了没发现 ,刚学NLP 不会写测试文件读完模型之后传参也是挺模糊的
关于测试文件的编写还是存在疑问,根据您给的博客链接我只会把checkpoint加载进来,后面的传参还是不太会,該传进那些参数不是很清楚,根据博客链接里面,没看见测试集如何加载进来的,希望您能够给予解答。 import tensorflow as tf import numpy as np
graph = tf.Graph() with graph.as_default():
获得需要喂给模型的参数,输出的结果依赖的输入值