geffy / tffm

TensorFlow implementation of an arbitrary order Factorization Machine
MIT License

Custom loss function #36

Closed peterewills closed 6 years ago

peterewills commented 6 years ago

Added support for custom loss functions in TFFMClassifier. Implemented cross-entropy in the utils module. Also added a class_weight parameter, which automatically uses weighted cross-entropy. One can use class_weight = "balanced" to use the heuristic pos_weight = n_negative / n_positive, where n_positive is the number of positive samples in the training labels.
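The "balanced" heuristic described above can be sketched in plain Python (illustrative only; the actual `TFFMClassifier` implementation may differ):

```python
def balanced_pos_weight(y):
    """Return n_negative / n_positive for a binary label sequence,
    mirroring the class_weight="balanced" heuristic described above."""
    n_positive = sum(1 for label in y if label == 1)
    n_negative = len(y) - n_positive
    if n_positive == 0:
        raise ValueError("no positive samples in y")
    return n_negative / n_positive

# Example: 2 positives out of 10 samples -> pos_weight = 8 / 2 = 4.0
labels = [1, 0, 0, 0, 1, 0, 0, 0, 0, 0]
print(balanced_pos_weight(labels))  # 4.0
```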

Note that custom loss functions are disallowed for TFFMRegressor at the moment. Adding them would be trivial, but I don't know of a use case where you'd want something other than MSE, so I left it alone.

I will add a demo of this functionality to the example notebook sometime soon.

geffy commented 6 years ago

Hi @peterewills ! Thanks for contributing! I've made several comments about implementation. Feel free to discuss if you are not agree with them. It will be great if you fix them. If not -- I believe I can do it by myself at the end of the week.

peterewills commented 6 years ago

Made some comments. Let me know what you think, and I'll work on fixing these things end of this week or early next week.

peterewills commented 6 years ago

I was running into problems getting the "balanced" weight to be calculated over the whole training set rather than the mini-batch, so I just removed that option, forcing the user to set the weight directly. There is now only one additional loss function, utils.loss_xentropy, which can be weighted via the keyword argument pos_weight. I moved the assignment of the loss function to the initializer for TFFMClassifier, rather than TFFMCore. Also, an error is now raised if the user tries to provide a custom loss function to TFFMRegressor, making it explicit that only MSE can be used there.
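What a pos_weight-ed cross-entropy computes can be sketched in pure Python; this mirrors the form of TensorFlow's tf.nn.weighted_cross_entropy_with_logits, though the exact utils.loss_xentropy signature here is an assumption:

```python
import math

def weighted_xentropy(y, logit, pos_weight=1.0):
    """Weighted binary cross-entropy for one sample, y in {0, 1}.

    Same form as tf.nn.weighted_cross_entropy_with_logits:
    the loss on positive samples is up-weighted by pos_weight.
    """
    p = 1.0 / (1.0 + math.exp(-logit))  # sigmoid of the logit
    return -(pos_weight * y * math.log(p) + (1 - y) * math.log(1.0 - p))

# With pos_weight=1 this is plain log-loss; pos_weight=3 triples the
# penalty on positive samples.
print(weighted_xentropy(1, 0.0))                  # ~0.693 (log 2)
print(weighted_xentropy(1, 0.0, pos_weight=3.0))  # ~2.079 (3 * log 2)
```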

If you'd still like to have pos_weight as a parameter of classifier.fit(), then I'd be fine with it, but I think I'll leave that up to you.

What's keeping me from adding a "balanced" option is the complication of evaluating the positive weight. Would we have to put a placeholder in the graph, and then assign it a value when fit() is called?

geffy commented 6 years ago

I rewrote some parts and added a 'balanced' option. Could you please check it? Does it solve your original problem with sample weights?

peterewills commented 6 years ago

It seems like you've written this to be very general, so that arbitrary weights can be provided for each sample. This is more general than I've ever needed to use, but I suppose it's fine. I don't like the disallowing of arbitrary loss functions for the TFFMClassifier class; I thought that was the whole point. Also, I think we should allow the user to provide the positive weight at the time of classifier construction if desired, a la scikit-learn. I'll implement these small changes when I get a chance.

Also not sure why loss_xentropy got removed.

geffy commented 6 years ago

Cross-entropy loss for binary classification is exactly logloss. Since TFFM supports only binary classification, there is no need for a separate cross-entropy loss. That is why I removed it.
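A quick numeric check of this point: for a binary problem, the general cross-entropy between a one-hot target and the two-class distribution (1-p, p) reduces exactly to log-loss (function names here are illustrative, not tffm's):

```python
import math

def cross_entropy(y_onehot, p_dist):
    """General cross-entropy between a one-hot target and a distribution."""
    return -sum(y_c * math.log(p_c) for y_c, p_c in zip(y_onehot, p_dist))

def logloss(y, p):
    """Binary log-loss: y in {0, 1}, p = predicted P(y = 1)."""
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

# The two coincide for every (y, p), e.g. y = 1, p = 0.8:
y, p = 1, 0.8
print(cross_entropy([1 - y, y], [1 - p, p]))  # ~0.2231
print(logloss(y, p))                          # ~0.2231
```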

Sample-wise weights are more general than class-wise weights, in the sense that for logloss we can express class-wise weights via sample-wise ones. This is why I want to implement it the more general way, since it covers both cases.
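The reduction of class-wise to sample-wise weights can be sketched as follows (logloss helper and variable names are illustrative, not tffm's API):

```python
import math

def logloss(y, p):
    """Binary log-loss for one sample."""
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

def weighted_total(y_true, p_pred, sample_weights):
    """Sample-weighted total log-loss."""
    return sum(w * logloss(y, p)
               for y, p, w in zip(y_true, p_pred, sample_weights))

y_true = [1, 0, 1, 0]
p_pred = [0.9, 0.2, 0.6, 0.4]
pos_weight = 3.0

# A class-wise pos_weight expressed as per-sample weights:
sample_w = [pos_weight if y == 1 else 1.0 for y in y_true]
print(weighted_total(y_true, p_pred, sample_w))
```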

I agree that class_weights (or pos_weight) can be moved into the constructor. I'm ok with that.

About custom loss: the point was to create TFFMBase, which can take any loss function, with TFFMClassifier and TFFMRegressor just being TFFMBase with a bundled loss (logloss and MSE, respectively). For now, the right way to use a custom loss is to define your own class (similar to TFFMClassifier), define your loss, and define how to make predictions (should they be normalized, sigmoid-transformed, and so on).
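The design described above can be illustrated with toy classes; these are NOT tffm's actual API, just a sketch of the base-plus-bundled-loss pattern:

```python
import math

class ToyBase:
    """Base estimator accepting an arbitrary (y, prediction) -> float loss."""
    def __init__(self, loss_fn):
        self.loss_fn = loss_fn

    def loss(self, y, pred):
        return self.loss_fn(y, pred)

class ToyClassifier(ToyBase):
    """Base with log-loss bundled in, as TFFMClassifier bundles logloss."""
    def __init__(self):
        super().__init__(
            lambda y, p: -(y * math.log(p) + (1 - y) * math.log(1 - p)))

class ToyRegressor(ToyBase):
    """Base with MSE bundled in, as TFFMRegressor bundles MSE."""
    def __init__(self):
        super().__init__(lambda y, p: (y - p) ** 2)

# A custom loss just means instantiating the base directly:
custom = ToyBase(loss_fn=lambda y, p: abs(y - p))  # e.g. absolute error
print(custom.loss(1.0, 0.25))  # 0.75
```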

peterewills commented 6 years ago

Okay, that all makes sense. Was being thick about log-loss vs cross-entropy, I see it now. I'll add those keyword arguments to the constructor, and then add a demo of using weighted loss and custom loss into the examples notebook. Probably sometime next week. Thanks for all the input!

geffy commented 6 years ago

Thank you for your effort! Finally merged :)