ageron / handson-ml

⛔️ DEPRECATED – See https://github.com/ageron/handson-ml3 instead.

Chapter 3: cross_val_predict() #435

Open perezzini opened 5 years ago

perezzini commented 5 years ago

This chapter uses cross_val_predict() to make cross-validation predictions, and then uses them to compute metrics such as the confusion matrix, recall, precision, and more.
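For reference, the pattern in question looks roughly like this (a minimal sketch; the SGDClassifier and the synthetic data standing in for the chapter's dataset are assumptions):

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import SGDClassifier
from sklearn.metrics import confusion_matrix, precision_score, recall_score
from sklearn.model_selection import cross_val_predict

# Synthetic binary data standing in for the chapter's dataset (an assumption)
X_train, y_train = make_classification(n_samples=1000, random_state=42)

# Out-of-fold predictions for every training instance, then metrics computed on them
y_train_pred = cross_val_predict(SGDClassifier(random_state=42),
                                 X_train, y_train, cv=3)
print(confusion_matrix(y_train, y_train_pred))
print(precision_score(y_train, y_train_pred))
print(recall_score(y_train, y_train_pred))
```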

But the Scikit-Learn documentation (https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.cross_val_predict.html) says the following: "It is not appropriate to pass these predictions into an evaluation metric. Use cross_validate to measure generalization error."

So, what's the point here? Can we use cross_val_predict() results to compute metrics? Or am I missing something?

Thanks in advance!

ageron commented 5 years ago

Hi @perezzini ,

That's a good point, thanks, I'll add a note to clarify this in the chapter.

An example will make things clearer. Suppose there are 10 instances in the dataset and we use 2 folds of 5 instances each. Both cross_val_predict() and cross_val_score() will train two models: model 1 on fold 1 and model 2 on fold 2. Both functions will use model 1 to make predictions on fold 2, and model 2 to make predictions on fold 1. This gives two groups of out-of-fold predictions, P1 (for fold 1, made by model 2) and P2 (for fold 2, made by model 1), with 5 predictions each, for example:

P1=[9, 5, 4, 6, 2] and P2=[4, 5, 4, 3, 4].

Now suppose the labels are L=[9, 5, 4, 6, 2, 3, 6, 3, 2, 3] (fold 1 first, then fold 2). You can see that the predictions P1 are perfect, while the predictions P2 are each off by 1. Apparently, one of the two models was perfect, while the other was not.

Now cross_val_score() would compute the score independently for each group of predictions. Suppose we are using the MSE; this gives us MSE(P1)=0 and MSE(P2)=(1^2+1^2+1^2+1^2+1^2)/5=1. This is what cross_val_score() would report: [0, 1]. You can then compute the mean and get the final evaluation: 0.5.

Now let's look at cross_val_predict(). It would just return the concatenation of P1 and P2, losing the information about which prediction came from which model: [9, 5, 4, 6, 2, 4, 5, 4, 3, 4]. You can then compute the MSE, which is: (0^2+0^2+0^2+0^2+0^2+1^2+1^2+1^2+1^2+1^2)/10 = 0.5. As you can see, you get the same final evaluation in this case. The benefit of cross_val_predict() is that you have access to the predictions, so you can plot them, analyze them, or use them to train blending models (see chapter 7). However, it hides the fact that some models were perfect while others were much worse.
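To make the arithmetic concrete, here is a minimal NumPy sketch using the hypothetical numbers above, assuming fold 1 is the first 5 instances and fold 2 the last 5:

```python
import numpy as np

# Hypothetical numbers from the example above
L  = np.array([9, 5, 4, 6, 2, 3, 6, 3, 2, 3])  # labels: fold 1 first, then fold 2
P1 = np.array([9, 5, 4, 6, 2])                 # out-of-fold predictions for fold 1
P2 = np.array([4, 5, 4, 3, 4])                 # out-of-fold predictions for fold 2

# Per-fold MSE, as cross_val_score() would report it
per_fold_mse = [((P1 - L[:5]) ** 2).mean(), ((P2 - L[5:]) ** 2).mean()]
print(per_fold_mse)           # [0.0, 1.0]
print(np.mean(per_fold_mse))  # 0.5

# MSE over the concatenated predictions, as computed from cross_val_predict()'s output
P = np.concatenate([P1, P2])
print(((P - L) ** 2).mean())  # 0.5 -- same final result for a mean-based metric
```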

If the metric is a simple mean over instance errors (e.g., mean squared error, or mean absolute error, or mean cross-entropy, etc.), then the final scores should always be the same. However, it's not always that simple. For example, consider the precision metric for a classification task (chapter 3). Suppose you are training a binary classifier, and the predictions are: P1=[1, 1, 1, 0, 1] and P2=[0, 0, 1, 1, 0]. Now suppose the labels are L=[1, 0, 1, 0, 1, 0, 1, 0, 1, 0]. The precision is the number of true positives divided by the number of positive predictions. So the precision over P1 is 3/4=75%, while the precision over P2 is 1/2=50%. So cross_val_score() will return [0.75, 0.5]. If you compute the mean, you get: 62.5%. Now cross_val_predict() will just return the concatenated predictions P=[1, 1, 1, 0, 1, 0, 0, 1, 1, 0]. If you compute the precision, you get: 4/6=66.66%. That's a different result!
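Here is the same precision calculation reproduced with Scikit-Learn's precision_score() (the labels and the fold split are the hypothetical ones from the example above):

```python
from sklearn.metrics import precision_score

# Hypothetical labels and per-fold predictions from the example above
L  = [1, 0, 1, 0, 1, 0, 1, 0, 1, 0]
P1 = [1, 1, 1, 0, 1]   # out-of-fold predictions for fold 1 (first 5 instances)
P2 = [0, 0, 1, 1, 0]   # out-of-fold predictions for fold 2 (last 5 instances)

# Per-fold precision, then averaged -- the cross_val_score() style
print(precision_score(L[:5], P1))   # 0.75
print(precision_score(L[5:], P2))   # 0.5  -> mean = 0.625

# Precision over the concatenated predictions -- the cross_val_predict() style
print(precision_score(L, P1 + P2))  # 0.666...
```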

Should you trust cross_val_score() or cross_val_predict() then? Well, cross_val_score() is definitely preferable in this case, as it gives you more details about the variability of the metric depending on the trained model, while cross_val_predict() just fuses all the predictions from all the folds, even though they were made by different models.

Hope this helps.

perezzini commented 5 years ago

Hi @ageron.

First of all, thanks for your great explanation!

My main problem with Chapter 3 is the way it explains how to adjust a model's discrimination threshold: it adjusts the threshold on the training set so that the test set can be set aside for final evaluation using the selected threshold. To select the discrimination threshold on the training set we need predictions, so the book uses cross_val_predict() on several occasions: to compute the confusion matrix, precision/recall curves, the ROC curve, etc.

Is it still a good option to adjust the threshold on the training set using cross_val_predict()? What could go wrong? Or does it have to be adjusted during final evaluation using the test set (which I think is really wrong, but I have seen some people do it...)?

Thanks, and sorry if I am missing something (which is likely...)!

ageron commented 5 years ago

Hi @perezzini ,

Good question! If you have a validation set, you should use it to select the threshold. If you don't, then you can generally use cross_val_predict() to generate predictions on the training set and select the threshold based on them (as I did in Chapter 3), but I agree that it's not perfect, since the predictions are from a mix of distinct models. In practice, I think it generally works fine, though.
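For reference, a minimal sketch of that approach; the SGDClassifier, the synthetic data standing in for the real training set, and the 90% precision target are all assumptions for illustration:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import SGDClassifier
from sklearn.metrics import precision_recall_curve
from sklearn.model_selection import cross_val_predict

# Synthetic binary data standing in for the real training set (an assumption)
X_train, y_train = make_classification(n_samples=1000, random_state=42)

# Out-of-fold decision scores on the training set
y_scores = cross_val_predict(SGDClassifier(random_state=42), X_train, y_train,
                             cv=3, method="decision_function")
precisions, recalls, thresholds = precision_recall_curve(y_train, y_scores)

# One possible selection rule: the lowest threshold reaching 90% precision
threshold_90 = thresholds[np.argmax(precisions[:-1] >= 0.90)]
```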

If you want a cleaner option, you could split the dataset into K folds, and for each fold, train the model on the other folds, generate predictions for the held-out fold, and plot the precision/recall or ROC curve based on them. This will give you K curves, and you can choose more wisely. In some cases you may end up selecting a different (better) threshold than with the method I use.
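A rough sketch of that per-fold approach, again assuming an SGDClassifier and synthetic data; how you pick the final threshold from the K curves is left to you:

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import SGDClassifier
from sklearn.metrics import precision_recall_curve
from sklearn.model_selection import StratifiedKFold

# Synthetic data standing in for the real training set (an assumption)
X_train, y_train = make_classification(n_samples=1000, random_state=42)

# One precision/recall curve per fold: train on the other folds, score the held-out one
skf = StratifiedKFold(n_splits=3, shuffle=True, random_state=42)
fold_curves = []
for train_idx, valid_idx in skf.split(X_train, y_train):
    clf = SGDClassifier(random_state=42)
    clf.fit(X_train[train_idx], y_train[train_idx])
    scores = clf.decision_function(X_train[valid_idx])
    precisions, recalls, thresholds = precision_recall_curve(y_train[valid_idx], scores)
    fold_curves.append((precisions, recalls, thresholds))

# fold_curves now holds K curves whose variability you can compare
# before choosing a decision threshold.
```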

My feeling is that using the test set to select the threshold would not be a great idea. It's probably not a big deal (we're just learning a single number), but I would prefer to use cross_val_predict() or a validation set, to avoid touching the test set.

In any case, this threshold can always easily be adjusted once the model is in production. For example, if you notice that the model in production is letting a few too many false negatives through (i.e., the recall is a bit too low), then you might want to lower the threshold slightly. Conversely, if you get plenty of false positives (i.e., the precision is too low), you could raise the threshold a bit.

Hope this helps.

perezzini commented 5 years ago

Thanks for replying @ageron !

Given that I have a REALLY small dataset (of text documents, in this case), I'd consider using cross_val_predict() for adjusting the discrimination threshold during cross validation.

Thanks again, and congratulations on the book! I'd like to stay in touch for future questions!

Cheers.