gaussic / text-classification-cnn-rnn

CNN-RNN中文文本分类,基于TensorFlow
MIT License
4.17k stars 1.47k forks source link

同一进程里启动两个模型的实例 #23

Closed commissarster closed 4 years ago

commissarster commented 6 years ago

我使用当前算法分别训练了两个模型: A. 内容分类的分类模型。 B. 垃圾文章的分类模型。

我的应用场景是需要判断一篇文章分类以及是否为垃圾文章。

当前的办法是,同时加载两个模型,先后对文章进行识别,但是我发现在在当前的算法在同一进程中依次加载A和B模型时,B模型的加载会报 ValueError: Variable embedding already exists, disallowed. 的错误。我可以理解这个错误是由于是由于A模型已经初始化变量了,且这个变量是全局性的。所以,无法进行B模型的加载。

想了解 一下这种场景要如何使用当前算法?

我想到的解决方案是在算法实际化时,增加一个variable_scope,即在cnn_model.py的34行增加一个变量空间,这样,所有变量都是在这个空间下。如A模型,的变量空间是a_ns。B模型的空间是b_ns。但这种方案有一个问题: 原来的模型需要重新训练,且训练时就要把空间名称确定。

不知道你们是否也遇到过这种情况?

hao1032 commented 6 years ago

@commissarster 我的代码你可能无法直接使用,我这边是成功使用了2个model

import os import tensorflow as tf import tensorflow.contrib.keras as kr

from tf_model import TCNNConfig, TextCNN from data.loader import read_category, read_vocab

class CnnModel: def init(self, name): self.config = TCNNConfig(name) self.categories, self.cat_to_id = read_category(self.config) self.words, self.word_to_id = read_vocab(self.config) self.config.vocab_size = len(self.words) self.vocab_dir = 'data\{}_vocab.txt'.format(name) self.save_dir = '{}\checkpoints\textcnn\{}'.format(os.getcwd(), name) self.save_path = os.path.join(self.save_dir, 'best_validation') # 最佳验证结果保存路径

    graph = tf.Graph()
    with graph.as_default():
        self.session = tf.Session()
        with self.session.as_default():
            self.model = TextCNN(self.config)
            self.session.run(tf.global_variables_initializer())
            saver = tf.train.Saver()
            saver.restore(sess=self.session, save_path=self.save_path)  # 读取保存的模型

def predict(self, content):
    print(content)
    data = [self.word_to_id[x] for x in content if x in self.word_to_id]

    feed_dict = {
        self.model.input_x: kr.preprocessing.sequence.pad_sequences([data], self.config.seq_length),
        self.model.keep_prob: 1.0
    }

    y_pred_cls = self.session.run(self.model.y_pred_softmax, feed_dict=feed_dict)
    for index, item in enumerate(y_pred_cls[0]):
        print("{}, {:.5%}".format(self.categories[index], item))
    return '\n'

class Predict: def init(self): self.emo_model = CnnModel('emotion') self.cls_model = CnnModel('classification')

def run(self):
    pass
    self.emo_model.predict('标题党')
    self.cls_model.predict('标题党')

p = Predict() p.run()

输出 标题党 正面, 0.00001% 负面, 99.99999% 标题党 内容重复过时, 0.00001% 虚构造假, 0.00000% 标题不符, 99.99999% 内容低质, 0.00000%

HuitMahoon commented 5 years ago

用了@commissarster 的方法,没起作用; @hao1032 的方法解决了问题。新建一张图,在图里实例化一下TextCNN。

    graph = tf.Graph()
    with graph.as_default():
        self.session = tf.Session()
        with self.session.as_default():
            self.model = TextCNN(self.config, weiboType)
            self.session.run(tf.global_variables_initializer())
            saver = tf.train.Saver()
            saver.restore(sess = self.session,save_path=os.path.join('./checkpoints/textcnn', 'best_validation'))  # 最佳验证结果保存路径
Ai-is-light commented 5 years ago

@HuitMahoon would you mind sharing more or whole code about your method

HuitMahoon commented 5 years ago

@Ai-is-light Just replace a few original source code in init() with what i have used. It worked for me.