NTMC-Community / MatchZoo

Facilitating the design, comparison and sharing of deep text matching models.
Apache License 2.0

Problems in using DSSMPreprocessor and RankHingeLoss #763

Closed aizest closed 5 years ago

aizest commented 5 years ago

Describe the Question

When I use DSSMPreprocessor and RankHingeLoss to train the DSSM model on my data, I always get "KeyError" in the model.fit_generator() function.

Describe your attempts

My MatchZoo (MZ) version is 2.1.0.

Here is my code snippet and the resulting error; please advise. Thanks a lot.

---------Code-----------

with open(TRAIN_DATA_PACK_PATH, 'rb') as infile:
    train_raw_dp = pickle.load(infile)

with open(VAL_DATA_PACK_PATH, 'rb') as infile:
    val_raw_dp = pickle.load(infile)

preprocessor = mz.preprocessors.DSSMPreprocessor()

train_processed_dp = preprocessor.fit_transform(train_raw_dp)
val_processed_dp = preprocessor.transform(val_raw_dp)

model = mz.models.DSSM()

task = mz.tasks.Ranking(loss=mz.losses.RankHingeLoss(num_neg=4))
task.metrics = [
    mz.metrics.AveragePrecision(),
    mz.metrics.NormalizedDiscountedCumulativeGain(k=3),
    mz.metrics.NormalizedDiscountedCumulativeGain(k=5),
    mz.metrics.NormalizedDiscountedCumulativeGain(k=10),
    mz.metrics.MeanAveragePrecision()
]

model.params['input_shapes'] = preprocessor.context['input_shapes']
model.params['task'] = task
model.params['mlp_num_layers'] = 3
model.params['mlp_num_units'] = 300
model.params['mlp_num_fan_out'] = 128
model.params['mlp_activation_func'] = 'relu'
model.guess_and_fill_missing_params()

model.build()
model.compile()

train_x, train_y = train_processed_dp.unpack()
val_x, val_y = val_processed_dp.unpack()

train_generator = mz.DataGenerator(train_processed_dp, num_dup=1, num_neg=4, batch_size=64, shuffle=True)

evaluate = mz.callbacks.EvaluateAllMetrics(model, x=val_x, y=val_y, batch_size=32)

history = model.fit_generator(train_generator, epochs=50, callbacks=[evaluate], workers=5, use_multiprocessing=False)

MZ always gives me the following errors:

KeyError                                  Traceback (most recent call last)
<command-1933085710792270> in <module>()
      4 evaluate = mz.callbacks.EvaluateAllMetrics(model, x=val_x, y=val_y, batch_size=32)
      5 
----> 6 history = model.fit_generator(train_generator, epochs=50, callbacks=[evaluate], workers=5, use_multiprocessing=False)

/python/lib/python3.6/site-packages/matchzoo/engine/base_model.py in fit_generator(self, generator, epochs, verbose, **kwargs)
    274             generator=generator,
    275             epochs=epochs,
--> 276             verbose=verbose, **kwargs
    277         )
    278 

/python/lib/python3.6/site-packages/keras/legacy/interfaces.py in wrapper(*args, **kwargs)
     89                 warnings.warn('Update your `' + object_name + '` call to the ' +
     90                               'Keras 2 API: ' + signature, stacklevel=2)
---> 91             return func(*args, **kwargs)
     92         wrapper._original_function = func
     93         return wrapper

/python/lib/python3.6/site-packages/keras/engine/training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
   1416             use_multiprocessing=use_multiprocessing,
   1417             shuffle=shuffle,
-> 1418             initial_epoch=initial_epoch)
   1419 
   1420     @interfaces.legacy_generator_methods_support

/python/lib/python3.6/site-packages/keras/engine/training_generator.py in fit_generator(model, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
    179             batch_index = 0
    180             while steps_done < steps_per_epoch:
--> 181                 generator_output = next(output_generator)
    182 
    183                 if not hasattr(generator_output, '__len__'):

/python/lib/python3.6/site-packages/keras/utils/data_utils.py in get(self)
    599         except Exception as e:
    600             self.stop()
--> 601             six.reraise(*sys.exc_info())
    602 
    603 

/python/lib/python3.6/site-packages/six.py in reraise(tp, value, tb)
    691             if value.__traceback__ is not tb:
    692                 raise value.with_traceback(tb)
--> 693             raise value
    694         finally:
    695             value = None

/python/lib/python3.6/site-packages/keras/utils/data_utils.py in get(self)
    593         try:
    594             while self.is_running():
--> 595                 inputs = self.queue.get(block=True).get()
    596                 self.queue.task_done()
    597                 if inputs is not None:

/python/lib/python3.6/multiprocessing/pool.py in get(self, timeout)
    642             return self._value
    643         else:
--> 644             raise self._value
    645 
    646     def _set(self, i, obj):

/python/lib/python3.6/multiprocessing/pool.py in worker(inqueue, outqueue, initializer, initargs, maxtasks, wrap_exception)
    117         job, i, func, args, kwds = task
    118         try:
--> 119             result = (True, func(*args, **kwds))
    120         except Exception as e:
    121             if wrap_exception and func is not _helper_reraises_exception:

/python/lib/python3.6/site-packages/keras/utils/data_utils.py in get_index(uid, i)
    399         The value at index `i`.
    400     """
--> 401     return _SHARED_SEQUENCES[uid][i]
    402 
    403 

/python/lib/python3.6/site-packages/matchzoo/data_generator/data_generator.py in __getitem__(self, item)
    130         else:
    131             indices = self._batch_indices[item]
--> 132         batch_data_pack = self._data_pack[indices]
    133         self._handle_callbacks_on_batch_data_pack(batch_data_pack)
    134         x, y = batch_data_pack.unpack()

/python/lib/python3.6/site-packages/matchzoo/data_pack/data_pack.py in __getitem__(self, index)
    165         """
    166         index = _convert_to_list_index(index, len(self))
--> 167         relation = self._relation.loc[index].reset_index(drop=True)
    168         left = self._left.loc[relation['id_left'].unique()]
    169         right = self._right.loc[relation['id_right'].unique()]

/python/lib/python3.6/site-packages/pandas/core/indexing.py in __getitem__(self, key)
   1476 
   1477             maybe_callable = com._apply_if_callable(key, self.obj)
-> 1478             return self._getitem_axis(maybe_callable, axis=axis)
   1479 
   1480     def _is_scalar_access(self, key):

/python/lib/python3.6/site-packages/pandas/core/indexing.py in _getitem_axis(self, key, axis)
   1899                     raise ValueError('Cannot index with multidimensional key')
   1900 
-> 1901                 return self._getitem_iterable(key, axis=axis)
   1902 
   1903             # nested tuple slicing

/python/lib/python3.6/site-packages/pandas/core/indexing.py in _getitem_iterable(self, key, axis)
   1141             if labels.is_unique and Index(keyarr).is_unique:
   1142                 indexer = ax.get_indexer_for(key)
-> 1143                 self._validate_read_indexer(key, indexer, axis)
   1144 
   1145                 d = {axis: [ax.reindex(keyarr)[0], indexer]}

/python/lib/python3.6/site-packages/pandas/core/indexing.py in _validate_read_indexer(self, key, indexer, axis)
   1204                 raise KeyError(
   1205                     u"None of [{key}] are in the [{axis}]".format(
-> 1206                         key=key, axis=self.obj._get_axis_name(axis)))
   1207 
   1208             # we skip the warning on Categorical/Interval

KeyError: 'None of [[37596, 21542, 15189, 32265, 5790, 99, 991, 35011, 20699, 14141, 12783, 28141, 35764, 3330, 2918, 16128, 27929, 32765, 14856, 22701, 14621, 37754, 782, 39755, 8899, 20247, 25519, 6095, 32722, 18522, 12709, 37101, 5060, 6124, 15453, 34180, 18410, 19402, 14495, 16828, 17782, 38966, 38632, 5121, 14618, 4447, 23929, 6401, 11525, 21241, 7821, 847, 25460, 17666, 36176, 8590, 34073, 9025, 8154, 30236, 21504, 31984, 33811, 24425]] are in the [index]'

The model runs fine if I use BasicPreprocessor and the default MSE loss. But I want to try the hinge loss because it is the loss function used by most of the MZ models. Your help will be highly appreciated!

uduse commented 5 years ago

I guess the bug is here: train_generator = mz.DataGenerator(train_processed_dp, num_dup=1, num_neg=4, batch_size=64, shuffle=True)

The problem is that you are using DataGenerator instead of PairDataGenerator (which is the right choice, since PairDataGenerator is deprecated), but forgot to set the mode to "pair".

To fix it: train_generator = mz.DataGenerator(train_processed_dp, mode='pair', num_dup=1, num_neg=4, batch_size=64, shuffle=True)

Maybe we should raise a warning or something when a user passes num_dup or num_neg without setting mode='pair'.
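A minimal sketch of such a guard, for illustration only. The function name and the exact warning text are hypothetical, not part of the actual MatchZoo API; the parameter names just mirror the DataGenerator call above.

```python
import warnings


def check_generator_params(mode='point', num_dup=1, num_neg=1):
    """Warn when pair-only arguments are passed without mode='pair'.

    Hypothetical helper sketching the suggested warning; not real
    MatchZoo code.
    """
    if mode != 'pair' and (num_dup != 1 or num_neg != 1):
        warnings.warn(
            "`num_dup` and `num_neg` only take effect with mode='pair'; "
            "got mode={!r}, so they will be silently ignored.".format(mode)
        )
```

A check like this at the top of DataGenerator.__init__ would have turned the silent misconfiguration above into a visible warning.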

(Maybe you followed the tutorial, saw the deprecation warning when using PairDataGenerator on your own, and replaced it with DataGenerator?)
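For intuition on why mode='pair' matters: a pairwise hinge loss with num_neg=4 consumes each batch as groups of five consecutive samples, one positive followed by its four negatives. Pair mode produces exactly that layout; point mode does not. A toy sketch of the grouping (illustrative only, not MatchZoo code; the function name is made up):

```python
def arrange_pairwise(positives, negatives_by_query, num_neg=4):
    """Lay out one positive followed by its num_neg negatives per query,
    the batch layout a pairwise hinge loss such as RankHingeLoss expects.
    """
    batch = []
    for query_id, pos in positives.items():
        batch.append(pos)
        batch.extend(negatives_by_query[query_id][:num_neg])
    return batch


# One query with one positive and four negatives yields a group of five.
batch = arrange_pairwise(
    {'q1': ('q1', 'd_pos')},
    {'q1': [('q1', 'd_neg{}'.format(i)) for i in range(4)]},
)
```

With the default point mode, batches are drawn without this grouping, so a loss that assumes the 1-positive-plus-num_neg layout reads garbage.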

aizest commented 5 years ago

Thanks! The suggestion makes sense and is very helpful, but it is not enough on its own to solve this problem.

I finally solved it by adding new_relation.dropna(inplace=True) at the end of the _reorganize_pair_wise function in data_generator.py (just before the return). For some reason, the operations inside _reorganize_pair_wise can produce some NA rows, which lead to the KeyError above. So I think it would be safer to call dropna before returning the reorganized relation.
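For illustration, here is a toy pandas sketch of that workaround. The actual _reorganize_pair_wise body is not reproduced here, and the specific way the NaN rows arise (reindexing against positions that were never filled) is an assumption used only to set up the example:

```python
import pandas as pd

# Toy relation table in the shape a DataPack uses:
# one row per (query, document) pair.
relation = pd.DataFrame({
    'id_left': ['q1', 'q1', 'q2'],
    'id_right': ['d1', 'd2', 'd3'],
    'label': [1.0, 0.0, 1.0],
})

# Simulate a reorganization that leaves gaps: reindexing to five rows
# when only three exist makes rows 3 and 4 come back all-NaN.
new_relation = relation.reindex(range(5))

# The workaround from the comment above: drop incomplete rows before
# returning, so later positional indexing never hits missing entries.
new_relation.dropna(inplace=True)
```

Without the dropna, downstream code that looks up id_left/id_right from those rows fails with exactly the kind of "None of [...] are in the [index]" KeyError shown in the traceback.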

Thanks again for the quick response!

uduse commented 5 years ago

I have no idea why it produces NA values, since I don't have your data and can't reproduce it.

aizest commented 5 years ago

More debugging is needed to find the root cause, but adding new_relation.dropna(inplace=True) did solve the problem. Hope this helps other users who hit similar issues. Thanks again for paying attention to this ticket.

aizest commented 5 years ago

I think this ticket can be closed now