Closed commissarster closed 4 years ago
@commissarster 我的代码你可能无法直接使用,我这边是成功使用了2个model
import os import tensorflow as tf import tensorflow.contrib.keras as kr
from tf_model import TCNNConfig, TextCNN from data.loader import read_category, read_vocab
class CnnModel: def init(self, name): self.config = TCNNConfig(name) self.categories, self.cat_to_id = read_category(self.config) self.words, self.word_to_id = read_vocab(self.config) self.config.vocab_size = len(self.words) self.vocab_dir = 'data\{}_vocab.txt'.format(name) self.save_dir = '{}\checkpoints\textcnn\{}'.format(os.getcwd(), name) self.save_path = os.path.join(self.save_dir, 'best_validation') # 最佳验证结果保存路径
graph = tf.Graph()
with graph.as_default():
self.session = tf.Session()
with self.session.as_default():
self.model = TextCNN(self.config)
self.session.run(tf.global_variables_initializer())
saver = tf.train.Saver()
saver.restore(sess=self.session, save_path=self.save_path) # 读取保存的模型
def predict(self, content):
print(content)
data = [self.word_to_id[x] for x in content if x in self.word_to_id]
feed_dict = {
self.model.input_x: kr.preprocessing.sequence.pad_sequences([data], self.config.seq_length),
self.model.keep_prob: 1.0
}
y_pred_cls = self.session.run(self.model.y_pred_softmax, feed_dict=feed_dict)
for index, item in enumerate(y_pred_cls[0]):
print("{}, {:.5%}".format(self.categories[index], item))
return '\n'
class Predict: def init(self): self.emo_model = CnnModel('emotion') self.cls_model = CnnModel('classification')
def run(self):
pass
self.emo_model.predict('标题党')
self.cls_model.predict('标题党')
p = Predict() p.run()
输出 标题党 正面, 0.00001% 负面, 99.99999% 标题党 内容重复过时, 0.00001% 虚构造假, 0.00000% 标题不符, 99.99999% 内容低质, 0.00000%
用了@commissarster 的方法,没起作用; @hao1032 的方法解决了问题。新建一张图,在图里实例化一下TextCNN。
graph = tf.Graph()
with graph.as_default():
self.session = tf.Session()
with self.session.as_default():
self.model = TextCNN(self.config, weiboType)
self.session.run(tf.global_variables_initializer())
saver = tf.train.Saver()
saver.restore(sess = self.session,save_path=os.path.join('./checkpoints/textcnn', 'best_validation')) # 最佳验证结果保存路径
@HuitMahoon would you mind sharing more or whole code about your method
@Ai-is-light Just replace a few original source code in init() with what i have used. It worked for me.
我使用当前算法分别训练了两个模型: A. 内容分类的分类模型。 B. 垃圾文章的分类模型。
我的应用场景是需要判断一篇文章分类以及是否为垃圾文章。
当前的办法是,同时加载两个模型,先后对文章进行识别,但是我发现在在当前的算法在同一进程中依次加载A和B模型时,B模型的加载会报 ValueError: Variable embedding already exists, disallowed. 的错误。我可以理解这个错误是由于是由于A模型已经初始化变量了,且这个变量是全局性的。所以,无法进行B模型的加载。
想了解 一下这种场景要如何使用当前算法?
我想到的解决方案是在算法实际化时,增加一个variable_scope,即在cnn_model.py的34行增加一个变量空间,这样,所有变量都是在这个空间下。如A模型,的变量空间是a_ns。B模型的空间是b_ns。但这种方案有一个问题: 原来的模型需要重新训练,且训练时就要把空间名称确定。
不知道你们是否也遇到过这种情况?