Closed: aizest closed this issue 5 years ago
I guess the bug's here:
train_generator = mz.DataGenerator(train_processed_dp, num_dup=1, num_neg=4, batch_size=64, shuffle=True)
The problem is that you are using DataGenerator instead of PairDataGenerator (which is the correct thing to do, since PairDataGenerator is deprecated), but forgot to set the mode to "pair".
To fix it:
train_generator = mz.DataGenerator(train_processed_dp, mode='pair', num_dup=1, num_neg=4, batch_size=64, shuffle=True)
Maybe we should raise a warning or something when a user passes num_dup or num_neg without setting mode='pair'.
(Maybe you followed the tutorial, and when using PairDataGenerator on your own, you saw a warning and replaced it with DataGenerator?)
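A minimal sketch of the kind of warning suggested above. The helper check_generator_args is hypothetical and not part of MatchZoo, and the default values it compares against (mode="point", num_dup=1, num_neg=1) are assumptions about DataGenerator's defaults.

```python
import warnings


def check_generator_args(mode="point", num_dup=1, num_neg=1):
    """Hypothetical helper (not a MatchZoo API).

    Warns when pair-only arguments are changed from their assumed
    defaults while mode is not 'pair'. Returns True when the
    combination looks consistent, False otherwise.
    """
    pair_args_changed = num_dup != 1 or num_neg != 1
    if mode != "pair" and pair_args_changed:
        warnings.warn(
            "num_dup and num_neg are only meaningful when mode='pair'; "
            f"got mode={mode!r}.",
            UserWarning,
        )
        return False
    return True
```

Calling check_generator_args(mode="point", num_dup=1, num_neg=4), which mirrors the arguments in the original snippet, would emit the warning; the same call with mode="pair" would not.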
Thanks! The suggestion makes sense. It's very helpful, but is not enough for this problem.
I finally solved it by adding "new_relation.dropna(inplace=True)" at the end of the "_reorganize_pair_wise" function in "data_generator.py", just before the return. For some reason, the operations inside _reorganize_pair_wise may produce some NA rows, which lead to the issue above. So I think it would be safer to call "dropna" before returning the reorganized relations.
Thanks again for the quick response!
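For anyone who wants to check whether they are hitting the same thing, here is a small pandas sketch of the workaround on a toy table. The DataFrame below is a made-up stand-in for the reorganized relation (its column names and values are assumptions); it does not reproduce what _reorganize_pair_wise actually does, it only shows how to spot NA rows and what dropna(inplace=True) removes.

```python
import numpy as np
import pandas as pd

# Toy stand-in for the reorganized relation table (assumed columns).
new_relation = pd.DataFrame({
    "id_left": ["q1", "q1", None, "q2"],
    "id_right": ["d1", "d2", None, "d3"],
    "label": [1.0, 0.0, np.nan, 1.0],
})

# Inspect which rows contain at least one NA value before dropping them.
na_rows = new_relation[new_relation.isna().any(axis=1)]
print(f"rows with NA values: {len(na_rows)}")

# The workaround from the comment above: drop every row that has an NA.
new_relation.dropna(inplace=True)
print(f"rows kept: {len(new_relation)}")
```

Printing na_rows before dropping them may also help narrow down where the NA values come from.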
I have no idea why it produces NA values, since I don't have your data and can't reproduce it.
More debugging is needed to find the reason, but adding "new_relation.dropna(inplace=True)" did solve the problem. I hope this helps other users who face similar issues. Thanks again for paying attention to this ticket.
I think this ticket can be closed now.
Describe the Question
When I use DSSMPreprocessor and RankHingeLoss to train the DSSM model on my data, I always get "KeyError" in the model.fit_generator() function.
Describe your attempts
My MZ version is 2.1.0.
Here are my code snippet and errors, please advise. Thanks a lot.
---------Code-----------
MZ always gives me the following errors:
The model runs well if I use the BasicPreprocessor and MSE loss (the default loss). But I want to try hinge loss because it's the loss function used by most of the MZ models. Your help will be highly appreciated!