automl / auto-sklearn

Automated Machine Learning with scikit-learn
https://automl.github.io/auto-sklearn
BSD 3-Clause "New" or "Revised" License

Softmax in predict_proba() #1483

Open amaletzk opened 2 years ago

amaletzk commented 2 years ago

Hi,

I'm new to auto-sklearn and quite impressed by its performance and ease of use - thanks for all your effort!

However, recently I experienced strange behavior that might be related to a bug. In a multilabel problem, one of the fitted pipelines returned class probabilities that always summed to 1 - which is the expected behavior in binary and multiclass problems, but not necessarily in multilabel problems. After a closer look at the sources, it seems this behavior is caused by the softmax() function being applied to the output of the underlying sklearn estimator in the predict_proba() method of some classifiers. These include LDA, QDA, LibLinear_SVC, PassiveAggressive and SGD.
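To make the effect concrete, here is a minimal sketch (plain numpy/scipy, not auto-sklearn code; the numbers are made up) of what a row-wise softmax does to independent per-label probabilities:

```python
import numpy as np
from scipy.special import softmax

# Made-up multilabel output: each column is an independent per-label
# probability, so a row does not have to sum to 1.
proba = np.array([
    [0.9, 0.8, 0.1],   # sample probably has labels 0 and 1
    [0.1, 0.2, 0.05],  # sample probably has no label at all
])

print(proba.sum(axis=1))                   # [1.8  0.35]
print(softmax(proba, axis=1).sum(axis=1))  # [1. 1.] -- forced to sum to 1
print(softmax(proba, axis=1))              # the per-label marginals are lost
```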

To my understanding, softmax() should not be called in multilabel problems. It might also be problematic in multiclass problems if the underlying sklearn estimator already returns a probability distribution, which OneVsRestClassifier (used in LDA, QDA, LibLinear_SVC, PassiveAggressive) does.
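This is easy to check with scikit-learn directly, using scipy's softmax as a stand-in for the one applied internally: OneVsRestClassifier returns per-label marginals in the multilabel case and an already-normalised distribution in the multiclass case, so an extra softmax changes the values in both.

```python
import numpy as np
from scipy.special import softmax
from sklearn.datasets import make_multilabel_classification
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier

# Multilabel case: OneVsRestClassifier returns marginal per-label probabilities.
X, Y = make_multilabel_classification(n_samples=200, n_classes=3, random_state=0)
ovr = OneVsRestClassifier(LogisticRegression(max_iter=1000)).fit(X, Y)

proba = ovr.predict_proba(X[:5])
print(proba.sum(axis=1))                   # generally != 1
print(softmax(proba, axis=1).sum(axis=1))  # exactly 1: multilabel semantics lost

# Multiclass case: softmax is not a no-op on an already-normalised distribution.
p = np.array([[1.0, 0.0, 0.0]])
print(softmax(p, axis=1))                  # ~[[0.58, 0.21, 0.21]], not [[1, 0, 0]]
```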

I'm using auto-sklearn 0.14.6 and Python 3.9 on Ubuntu 20.04.4.

eddiebergman commented 2 years ago

Hi @amaletzk,

Thanks for reporting this, and sorry for the delayed response; we currently have some other additions coming soon that have been taking quite a bit of time.

@aseemk98 This would be a helpful issue to fix and seems like something manageable. I assume that for some of the classifiers we carelessly just apply a softmax to the output of predict_proba, which, as @amaletzk rightly points out, does not make sense for multi-label classification, where each class's probability is independent of the others.

Fixing this likely first requires a test over all classifiers on a multi-label problem; we do those tests here.

What this test should actually check is a bit difficult to define. We want to make sure that we don't accidentally reintroduce this softmax for multi-label while changing things. The obvious test is to check, for each instance's probability vector, that the entries do not sum to 1 in the multi-label case. This may still legitimately happen, though, for example if the classifier truly believes the 1st class is .4 likely and the 2nd class is .6 likely. If you have a better idea, or @amaletzk if you have an idea for how to test this so we don't accidentally do it again, then please feel free to suggest it. You may want to copy the structure of this test but create a new one that tests predict_proba explicitly.
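For illustration, a rough sketch of the shape such a test could take; the component list below is just a placeholder with a plain sklearn estimator, whereas the real test would iterate over auto-sklearn's registered classification components:

```python
import numpy as np
import pytest
from sklearn.datasets import make_multilabel_classification
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier

# Placeholder list: in the real test this would cover auto-sklearn's
# classification components; a plain sklearn estimator stands in here.
MULTILABEL_CLASSIFIERS = [
    lambda: OneVsRestClassifier(LogisticRegression(max_iter=1000)),
]


@pytest.mark.parametrize("make_classifier", MULTILABEL_CLASSIFIERS)
def test_multilabel_predict_proba_is_not_softmaxed(make_classifier):
    X, Y = make_multilabel_classification(n_samples=200, n_classes=4,
                                          random_state=1)
    proba = make_classifier().fit(X, Y).predict_proba(X)

    # These properties must always hold for multilabel output.
    assert proba.shape == Y.shape
    assert np.all((proba >= 0.0) & (proba <= 1.0))

    # If *every* row sums to exactly 1, a row-wise softmax (or similar
    # normalisation) has almost certainly been applied.  A single row may
    # legitimately sum to 1, so only the all-rows case is flagged.
    assert not np.allclose(proba.sum(axis=1), 1.0)
```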

Anywho, one approach to fixing this might be to first check whether the softmax is even required for some of the models, i.e. I'm not sure why predict_proba of LDA has a softmax around it. For any that use softmax(decision_function(X)), there's probably a reason, and your best bet is to look at the sklearn docs and try to figure out the best approach from there.
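As a sketch of one possible direction for the decision_function case (an assumption on my part, not an agreed fix): in the multilabel setting the margins could be squashed independently with a logistic sigmoid rather than normalised row-wise with a softmax.

```python
import numpy as np
from scipy.special import expit, softmax

# Made-up decision_function output for 2 samples and 3 labels,
# as an SGD-style linear model might produce.
margins = np.array([
    [ 2.0,  1.5, -3.0],
    [-0.5, -2.0, -1.0],
])

# What the affected components currently do: one distribution per row.
print(softmax(margins, axis=1))   # each row sums to 1

# A possible multilabel treatment: squash each margin independently.
print(expit(margins))             # independent per-label probabilities
```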

tl;dr: start with a test to confirm this behaviour and work backwards from there. I guess there will be points where decisions have to be made about how best to handle things, but we can discuss that in a PR :)

amaletzk commented 2 years ago

Hi @eddiebergman,

thanks for your response. I agree that softmax() should not be applied at all in LDA, QDA, LibLinear_SVC and PassiveAggressive, since the underlying OneVsRestClassifier already returns a probability distribution for multiclass problems and also handles multilabel problems correctly, as far as I can tell. SGD, which uses decision_function(), might have to be treated differently, though, just as you pointed out.

Concerning the tests, I'm afraid I'm of little help here ... the only thing that comes to my mind is checking for a specific, pre-computed output for every single instance.
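If that route is taken, the check itself could be as small as the sketch below; the component passed in and the .npy fixture are hypothetical, and the expected values would have to be generated once from a run that is known to be correct:

```python
import numpy as np
from sklearn.datasets import make_multilabel_classification


def check_against_precomputed(component, expected_proba_path):
    """Compare predict_proba against a stored output that is known to be good.

    `component` stands in for one of the affected classifiers (anything with
    the usual fit/predict_proba interface); `expected_proba_path` points to a
    hypothetical .npy fixture produced once from a trusted run.
    """
    X, Y = make_multilabel_classification(n_samples=50, n_classes=3,
                                          random_state=42)
    proba = component.fit(X, Y).predict_proba(X)
    np.testing.assert_allclose(proba, np.load(expected_proba_path), rtol=1e-6)
```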

aseemk98 commented 2 years ago

Hi @eddiebergman,

Yes, I'll take a look into this and try to figure it out. Thanks for referring this to me. I'll add a PR soon :)

aseemk98 commented 1 year ago

Hi @eddiebergman, sorry for the year-long break on this. Does this issue still persist?