wmkouw / libTLDA

Library of transfer learners and domain-adaptive classifiers.
MIT License

Logistic Discriminator Cross Val Predict Outputs Classes (Not Probabilities) #19

Open kharrigian opened 4 years ago

kharrigian commented 4 years ago

The cross_val_predict function used for domain discrimination in the instance weighting class returns class labels (not probabilities) by default, because its method argument defaults to "predict". The intended behavior appears to require passing method="predict_proba" to cross_val_predict. With this argument, the output is a 2D array of shape (n_samples, n_classes) with one column per class in sorted label order, so the desired probabilities have to be extracted by column indexing.
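A minimal sketch of the difference, assuming toy Gaussian data where "source" samples get label 0 and "target" samples get label 1 (the data and that labelling convention are assumptions for illustration, not taken from libTLDA). The default call yields hard labels. With method="predict_proba" you get one probability column per class, ordered like the classifier's classes_ attribute (here [0., 1.]).

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_predict

rng = np.random.default_rng(0)
# Hypothetical toy data: 100 source samples (label 0), 100 target samples (label 1)
X = rng.normal(0.0, 1.0, size=(100, 2))
Z = rng.normal(1.0, 1.0, size=(100, 2))
XZ = np.concatenate((X, Z), axis=0)
y = np.concatenate((np.zeros(100), np.ones(100)))

lr = LogisticRegression()

# Default method="predict": out-of-fold hard class labels (values from {0., 1.})
labels = cross_val_predict(lr, XZ, y, cv=5)

# method="predict_proba": out-of-fold probabilities, one column per class
proba = cross_val_predict(lr, XZ, y, cv=5, method="predict_proba")
p_target = proba[:, 1]  # column 1 = probability of label 1

print(np.unique(labels))
print(proba.shape)  # (200, 2)
```

Because the estimator is a classifier and cv is an integer, scikit-learn uses stratified folds, so every training fold contains both domain labels and predict_proba always returns both columns.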

I'm not sure if this was always the case with sklearn or if it was updated after the original domain discrimination implementation. Either way, this additional argument should rectify the issue.

rchew commented 3 years ago

@wmkouw Are you accepting pull requests? If so, I think this could be a quick change from:

https://github.com/wmkouw/libTLDA/blob/0c66ec2327d191b88fca0803a7c74b0bb05afd42/libtlda/iw.py#L256

To:

preds = cross_val_predict(lr, XZ, y[:, 0], cv=5, method="predict_proba")[:, 1]
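For context, a hedged sketch of how the fixed line could feed importance weighting. It assumes source samples are labelled 0 and target samples 1, and that only the source rows are kept. The helper name logistic_discrimination_weights is hypothetical. The odds-ratio conversion is the generic logistic-discrimination density-ratio estimate, not necessarily what iw.py does with preds afterwards.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_predict


def logistic_discrimination_weights(X, Z, cv=5):
    """Hypothetical helper: importance weights for source samples X w.r.t. target samples Z."""
    N, M = X.shape[0], Z.shape[0]
    # Assumed domain-label convention: 0 = source, 1 = target (as a column vector)
    y = np.concatenate((np.zeros((N, 1)), np.ones((M, 1))), axis=0)
    XZ = np.concatenate((X, Z), axis=0)
    lr = LogisticRegression()
    # The proposed fix: out-of-fold probabilities, column 1 = P(target | x)
    preds = cross_val_predict(lr, XZ, y[:, 0], cv=cv, method="predict_proba")[:, 1]
    p = preds[:N]  # keep the source rows only
    # P(target|x) / P(source|x), rescaled by N / M, estimates p_target(x) / p_source(x)
    eps = 1e-12  # guard against division by zero if p is numerically 1
    return (p / np.maximum(1.0 - p, eps)) * (N / M)


# Usage on hypothetical toy data
rng = np.random.default_rng(0)
X = rng.normal(0.0, 1.0, size=(100, 2))
Z = rng.normal(1.0, 1.0, size=(80, 2))
w = logistic_discrimination_weights(X, Z)
print(w.shape)  # (100,)
```

The rescaling by N / M corrects for unequal domain sizes. Without it, the odds ratio would also absorb the ratio of the class priors M / (N + M) and N / (N + M).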

wmkouw commented 3 years ago

Hi @rchew, yeah sure. I had big plans for this toolbox, but I have shifted my research focus over the last two to three years and don't think I have time for further development.