NTMC-Community / MatchZoo

Facilitating the design, comparison and sharing of deep text matching models.
Apache License 2.0

DSSM wordhashing 问题 #804

Closed wangshansong1 closed 5 years ago

wangshansong1 commented 5 years ago

I trained a DSSM model on a Chinese dataset I built myself. During word hashing the process consumed almost 100 GB of memory, so I killed it. Is this an inevitable consequence of this embedding scheme, or is there a setting I can change?
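For scale, here is a back-of-envelope estimate of what materializing dense word-hash vectors costs. The vocabulary size is taken from the traceback later in this thread; the sample count and dtype are hypothetical, purely for illustration:

```python
# Rough memory estimate for eagerly word-hashed text: each sample becomes a
# dense vector with one slot per letter-ngram in the vocabulary.
vocab_size = 38755        # letter-ngram vocabulary size (from the traceback below)
num_samples = 200_000     # hypothetical number of text entries
bytes_per_value = 8       # float64

total_bytes = vocab_size * num_samples * bytes_per_value
print(f"{total_bytes / 1024**3:.1f} GiB")  # dense hashing easily reaches tens of GiB
```

This is why deferring the hashing to batch time (as suggested in the next reply) keeps memory bounded: only one batch of dense vectors exists at a time.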

uduse commented 5 years ago

#481 may help with your memory issue.

wangshansong1 commented 5 years ago

Following the advice in #481, I modified the code and can now complete one epoch successfully, but when the callbacks run I get the following error:

Traceback (most recent call last):
  File "C:/Users/wansy132/PycharmProjects/MatchZoo/src/DSSM.py", line 57, in <module>
    use_multiprocessing=False)
  File "C:\Users\wansy132\software\Anaconda3\lib\site-packages\matchzoo\engine\base_model.py", line 276, in fit_generator
    verbose=verbose, **kwargs
  File "C:\Users\wansy132\software\Anaconda3\lib\site-packages\keras\legacy\interfaces.py", line 91, in wrapper
    return func(*args, **kwargs)
  File "C:\Users\wansy132\software\Anaconda3\lib\site-packages\keras\engine\training.py", line 1732, in fit_generator
    initial_epoch=initial_epoch)
  File "C:\Users\wansy132\software\Anaconda3\lib\site-packages\keras\engine\training_generator.py", line 260, in fit_generator
    callbacks.on_epoch_end(epoch, epoch_logs)
  File "C:\Users\wansy132\software\Anaconda3\lib\site-packages\keras\callbacks\callbacks.py", line 152, in on_epoch_end
    callback.on_epoch_end(epoch, logs)
  File "C:\Users\wansy132\software\Anaconda3\lib\site-packages\matchzoo\engine\callbacks.py", line 65, in on_epoch_end
    self._batch_size)
  File "C:\Users\wansy132\software\Anaconda3\lib\site-packages\matchzoo\engine\base_model.py", line 324, in evaluate
    y_pred = self.predict(x, batch_size)
  File "C:\Users\wansy132\software\Anaconda3\lib\site-packages\matchzoo\engine\base_model.py", line 397, in predict
    return self._backend.predict(x=x, batch_size=batch_size)
  File "C:\Users\wansy132\software\Anaconda3\lib\site-packages\keras\engine\training.py", line 1441, in predict
    x, _, _ = self._standardize_user_data(x)
  File "C:\Users\wansy132\software\Anaconda3\lib\site-packages\keras\engine\training.py", line 579, in _standardize_user_data
    exception_prefix='input')
  File "C:\Users\wansy132\software\Anaconda3\lib\site-packages\keras\engine\training_utils.py", line 145, in standardize_input_data
    str(data_shape))
ValueError: Error when checking input: expected text_left to have shape (38755,) but got array with shape (1,)

Here is my code:

import matchzoo
from matchzoo.preprocessors.units import Tokenize  # or your own Chinese tokenizer unit
# load_data is a custom loader defined elsewhere

if __name__ == '__main__':
    ranking_task = matchzoo.tasks.Ranking(loss=matchzoo.losses.RankCrossEntropyLoss(num_neg=4))
    ranking_task.metrics = [
        matchzoo.metrics.NormalizedDiscountedCumulativeGain(k=3),
        matchzoo.metrics.NormalizedDiscountedCumulativeGain(k=5),
        matchzoo.metrics.MeanAveragePrecision()
    ]

    train_raw = load_data(stage='train', task=ranking_task, file_path='../dataset/bh_train.csv')
    test_raw = load_data(stage='test', task=ranking_task, file_path='../dataset/bh_test.csv')

    preprocessor = matchzoo.preprocessors.DSSMPreprocessor(with_word_hashing=False)
    preprocessor._units = [
        Tokenize(),
        matchzoo.preprocessors.units.Lowercase(),
        matchzoo.preprocessors.units.PuncRemoval(),
        matchzoo.preprocessors.units.NgramLetter(ngram=2),
    ]

    train_pack_processed = preprocessor.fit_transform(train_raw)
    test_pack_processed = preprocessor.transform(test_raw)

    model = matchzoo.models.DSSM()
    model.params['input_shapes'] = preprocessor.context['input_shapes']
    model.params['task'] = ranking_task
    model.params['mlp_num_layers'] = 3
    model.params['mlp_num_units'] = 300
    model.params['mlp_num_fan_out'] = 128
    model.params['mlp_activation_func'] = 'relu'
    model.guess_and_fill_missing_params()
    model.build()
    model.compile()

    term_index = preprocessor.context['vocab_unit'].state['term_index']
    hashing_unit = matchzoo.preprocessors.units.WordHashing(term_index)
    pred_x, pred_y = test_pack_processed[:].unpack()
    evaluate = matchzoo.callbacks.EvaluateAllMetrics(model, x=pred_x, y=pred_y, batch_size=len(pred_x))
    train_generator = matchzoo.DataGenerator(
        train_pack_processed,
        num_dup=1,
        num_neg=4,
        batch_size=10,
        mode='pair',
        callbacks=[
            matchzoo.data_generator.callbacks.LambdaCallback(
                on_batch_data_pack=lambda dp: dp.apply_on_text(
                    hashing_unit.transform, inplace=True, verbose=0)
            )
        ]
    )

    history = model.fit_generator(train_generator, epochs=200, callbacks=[evaluate], workers=6,
                                  use_multiprocessing=False)

predict() expects a dict as its first argument and an ndarray as its second, and that is how I pass them; my dataset has the same format as WikiQA. After changing the hashing scheme for the training set, does the test set need a corresponding change?

uduse commented 5 years ago

@wangshansong1 Yes, your test set also needs to be processed with the same hashing_unit. If your test set is small, you can preprocess it fully up front and feed it to EvaluateAllMetrics; otherwise you will need to write your own version of EvaluateAllMetrics that hashes and evaluates in batches, then aggregates the results.
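A minimal sketch of the first option (preprocess the whole test set up front). Names follow the questioner's script above; this only works if the dense hashed test set fits in memory:

```python
# Apply the same hashing_unit to the processed test pack once, then unpack.
# apply_on_text with inplace=False returns a transformed copy of the DataPack.
hashed_test = test_pack_processed.apply_on_text(
    hashing_unit.transform, inplace=False, verbose=0)
pred_x, pred_y = hashed_test[:].unpack()
evaluate = matchzoo.callbacks.EvaluateAllMetrics(
    model, x=pred_x, y=pred_y, batch_size=128)
```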

wangshansong1 commented 5 years ago

OK, this problem is solved: I instantiated a new preprocessor1 = DSSMPreprocessor with with_word_hashing set to True, assigned the original preprocessor.context['vocab_unit'] to the new preprocessor1, and used it to word-hash the test set.

P.S. This approach is a bit clumsy and still uses a fair amount of memory, but since it is only the test set, the usage is acceptable.
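For reference, the fix described above can be sketched as follows. Note that copying the context is internal API (`_context` is a private attribute) and may differ between MatchZoo versions:

```python
# A second preprocessor that hashes at transform time, reusing the
# vocabulary fitted on the training data by the original preprocessor.
preprocessor1 = matchzoo.preprocessors.DSSMPreprocessor(with_word_hashing=True)
preprocessor1._context = preprocessor.context  # carries over 'vocab_unit' etc.
test_pack_hashed = preprocessor1.transform(test_raw)
pred_x, pred_y = test_pack_hashed[:].unpack()
```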