NTMC-Community / MatchZoo

Facilitating the design, comparison and sharing of deep text matching models.
Apache License 2.0

Large datasets freeze and crash during DSSM preprocessing #481

Closed · MichaelStarkey closed this issue 5 years ago

MichaelStarkey commented 5 years ago

I am trying to replicate the WikiQA DSSM tutorial on a much larger dataset (100k queries, 1M answers). The preprocessing seems to freeze and then crash during the text_right stage. The issue occurs on both 2.0 and 2.0-dev. Is this a memory issue?

I am using a Google Cloud instance with 32 GB RAM and 2 vCPUs. Current code below:

import keras
import pandas as pd
import numpy as np
import matchzoo as mz

def read_data(path, include_label=True):
    # Lazily scan a tab-separated file of (qid, did, query, doc[, label]) rows.
    def scan_file():
        with open(path) as in_file:
            for l in in_file:
                yield l.strip().split('\t')
    if include_label:
        return [(qid, did, q, d, float(label)) for qid, did, q, d, label in scan_file()]
    return [(qid, did, q, d) for qid, did, q, d in scan_file()]

msm_dev = read_data('../data/msm_dev.mz')
train = msm_dev[:897558]
predict = [(qid, did, q, d) for qid, did, q, d, _ in msm_dev[897558:]]
train_pack = mz.pack(train)
predict_pack = mz.pack(predict)

preprocessor = mz.engine.load_preprocessor('../outputs/msm_dssm_prep')
train_pack_processed = preprocessor.transform(train_pack)

Output:

Processing `text_left` with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit => StopRemovalUnit => NgramLetterUnit => WordHashingUnit: 100%|██████████| 90008/90008 [00:31<00:00, 2861.30it/s]
Processing `text_right` with chain_transform of TokenizeUnit => LowercaseUnit => PuncRemovalUnit => StopRemovalUnit => NgramLetterUnit => WordHashingUnit:   5%|▍         | 41838/897558 [00:51<18:06, 787.74it/s]
uduse commented 5 years ago

Yes it is. WordHashingUnit eats up a lot of memory. On 2.0-dev there's a UnitDynamicDataGenerator that can work around this problem: skip the word hashing step in the preprocessing, initialize a UnitDynamicDataGenerator with the WordHashingUnit, and feed the data pack to the generator. This does forgo the ability to use model.fit, though, since you have to use model.fit_generator instead. Since this problem seems quite common, more support for this will be added soon.
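A minimal sketch of that flow (hedged: the UnitDynamicDataGenerator keyword arguments follow the attempt later in this thread, and the term-index lookup and WordHashingUnit constructor are assumptions about the 2.0-dev API, not confirmed usage):

import matchzoo as mz

# Fit and transform everything *except* word hashing, then defer hashing to
# the generator so the full hashed matrix never sits in memory at once.
term_index = preprocessor.context['vocab_unit'].state['term_index']  # assumed location of the fitted vocab
word_hashing_unit = mz.processor_units.WordHashingUnit(term_index)   # an instance, not the class

train_generator = mz.data_generator.UnitDynamicDataGenerator(
    data_pack=train_pack_processed,  # pack processed up to, but not including, word hashing
    unit=word_hashing_unit)
model.fit_generator(train_generator, epochs=10)  # fit_generator instead of fit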

MichaelStarkey commented 5 years ago

I'm unsure how to proceed with the dynamic generator. Am I still required to use the DSSM preprocessor to go from the training data to tri-letters? It seems the WordHashingUnit is written into the DSSMPreprocessor.fit method. I attempted the following code, but it did not work, so I am missing something.

word_hashing_unit = mz.processor_units.WordHashingUnit
ngram_unit = mz.processor_units.NgramLetterUnit
dynamic_train_hashing_generator = mz.data_generator.UnitDynamicDataGenerator(data_pack=train_pack, unit=word_hashing_unit)
dynamic_train_ngram_generator = mz.data_generator.UnitDynamicDataGenerator(data_pack=train_pack, unit=ngram_unit)

ranking_task = mz.tasks.Ranking()
ranking_task.metrics = [
    'mae', 'map', 'precision',
    mz.metrics.Precision(k=3),
    mz.metrics.DiscountedCumulativeGain(k=1),
    mz.metrics.DiscountedCumulativeGain(k=3),
    mz.metrics.DiscountedCumulativeGain(k=5),
    mz.metrics.NormalizedDiscountedCumulativeGain(k=1),
    mz.metrics.NormalizedDiscountedCumulativeGain(k=3),
    mz.metrics.NormalizedDiscountedCumulativeGain(k=5)
]
model = mz.models.DSSMModel()
input_shapes = preprocessor.context['input_shapes']
model.params['input_shapes'] = input_shapes
model.params['task'] = ranking_task
model.guess_and_fill_missing_params()
model.build()
model.compile()
model.fit_generator(dynamic_train_hashing_generator, epochs=10)

With output:

StopIteration: transform() missing 1 required positional argument: 'tri_letters'

Additionally, I had to amend the matchzoo/data_generator/__init__.py file to access the UnitDynamicDataGenerator class.

uduse commented 5 years ago

@MichaelStarkey You have to create unit instances instead of using them as classes. Anyway, since MatchZoo will need this feature sooner or later, I've just implemented it for your early access. Check out (pun intended) the latest 2.0-dev branch and see sandbox.ipynb for usage.
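For illustration, the instantiation fix alone would look something like this (a sketch; the WordHashingUnit constructor argument is an assumption, built from the fitted vocabulary's term index as shown later in this thread). Passing the class itself means unit.transform is invoked as an unbound method, so the text fills self and the real argument, tri_letters, is reported missing:

# Instances, not classes:
ngram_unit = mz.processor_units.NgramLetterUnit()
word_hashing_unit = mz.processor_units.WordHashingUnit(
    preprocessor.context['vocab_unit'].state['term_index'])  # assumed source of the term index

dynamic_train_hashing_generator = mz.data_generator.UnitDynamicDataGenerator(
    data_pack=train_pack, unit=word_hashing_unit)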

MichaelStarkey commented 5 years ago

That seems to have done it, thank you!

uduse commented 5 years ago

The DSSM and CDSSM preprocessors now have a with_word_hashing argument.

prpr = DSSMPreprocessor(..., with_word_hashing=False)
prpr.fit(...)
term_index = prpr.context['vocab_unit'].state['term_index']
hashing_unit = WordHashingUnit(term_index)  # build the unit from the fitted vocab

# in 2.0
gen = DynamicDataGenerator(hashing_unit.transform, ...)

# in 2.1
gen = DataGenerator(..., callbacks=[
    mz.data_generator.callbacks.LambdaCallback(
        on_batch_data_pack=lambda x: x.apply_on_text(
            hashing_unit.transform, inplace=True))
])
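Putting the 2.1 variant together, a hedged end-to-end sketch (module paths such as mz.preprocessors.units.WordHashing, mz.DataGenerator, and the batch_size argument are assumptions based on the snippet above and the 2.1 tutorials, not confirmed against this exact release):

import matchzoo as mz

# Fit without hashing so preprocessing stays cheap; hash per batch instead.
preprocessor = mz.preprocessors.DSSMPreprocessor(with_word_hashing=False)
train_processed = preprocessor.fit_transform(train_pack)  # train_pack as built earlier in the thread

term_index = preprocessor.context['vocab_unit'].state['term_index']
hashing_unit = mz.preprocessors.units.WordHashing(term_index)  # assumed 2.1 class name

gen = mz.DataGenerator(
    train_processed,
    batch_size=64,
    callbacks=[
        mz.data_generator.callbacks.LambdaCallback(
            on_batch_data_pack=lambda dp: dp.apply_on_text(
                hashing_unit.transform, inplace=True))
    ])
model.fit_generator(gen, epochs=10)  # model built as in the DSSM tutorial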