oracle / tribuo

Tribuo - A Java machine learning library
https://tribuo.org
Apache License 2.0

Calculating the AUC and ROC #370

Open Mohammed-Ryiad-Eiadeh opened 4 months ago

Mohammed-Ryiad-Eiadeh commented 4 months ago

Dear Tribuo developers,

I am trying to get the TPR and the FPR in order to calculate the AUC and plot the ROC curve, but the results are sometimes unreasonable: I get high accuracy yet low AUC. Here is my code:

    // use KNN classifier
    // noinspection DuplicatedCode
    var KnnTrainer = new KNNTrainer<>(3,
            new L1Distance(),
            Runtime.getRuntime().availableProcessors(),
            new VotingCombiner(),
            KNNModel.Backend.THREADPOOL,
            NeighboursQueryFactoryType.BRUTE_FORCE);

    // display the model provenance
    var modelProvenance = KnnTrainer.getProvenance();
    System.out.println("The model provenance is \n" + ProvenanceUtil.formattedProvenanceString(modelProvenance));

    // use crossvalidation
    // noinspection DuplicatedCode
    var crossValidation = new CrossValidation<>(KnnTrainer, dataSet, new LabelEvaluator(), 10, Trainer.DEFAULT_SEED);

    // get outputs
    // noinspection DuplicatedCode
    var avgAcc = 0d;
    var sensitivity = 0d;
    var specificity = 0d;
    var macroAveragedF1 = 0d;
    var precision = 0d;
    var recall = 0d;
    var avgTP = new double[crossValidation.getK()];
    var avgFP = new double[crossValidation.getK()];
    var counter = 0;
    var sTrain = System.currentTimeMillis();
    for (var result: crossValidation.evaluate()) {
        avgAcc += result.getA().accuracy();
        sensitivity += result.getA().tp() / (result.getA().tp() + result.getA().fn());
        specificity += result.getA().tn() / (result.getA().tn() + result.getA().fp());
        macroAveragedF1 += result.getA().macroAveragedF1();
        precision += result.getA().tp() / (result.getA().tp() + result.getA().fp());
        recall += result.getA().tp() / (result.getA().tp() + result.getA().fn());
        avgTP[counter] = result.getA().tp() / (result.getA().tp() + result.getA().fn());
        avgFP[counter] = 1 - (result.getA().tn() / (result.getA().tn() + result.getA().fp()));
        counter++;
    }

    // noinspection DuplicatedCode
    var eTrain = System.currentTimeMillis();

    /*System.out.printf("The FS duration time is : %s\nThe number of selected features is : %d\nThe feature names are : %s\n",
            Util.formatDuration(sDate, eDate), SFS.featureNames().size(), SFS.featureNames());*/

    for (var stuff : List.of("The Training_Testing duration time is : " + Util.formatDuration(sTrain, eTrain),
            "The average accuracy is : " + (avgAcc / crossValidation.getK()),
            "The average sensitivity is : " + (sensitivity / crossValidation.getK()),
            "The average macroAveragedF1 is : " + (macroAveragedF1 / crossValidation.getK()),
            "The average precision is : " + (precision / crossValidation.getK()),
            "The average recall is : " + (recall / crossValidation.getK()))) {
        System.out.println(stuff);
    }

    AucCalculator aucCalculator = new AucCalculator(avgTP, avgFP);
    System.out.println("The AUC is : " + aucCalculator.getAUC());

    // Display the ROC curve chart and save it
    System.out.println(Arrays.toString(avgTP));
    System.out.println(Arrays.toString(avgFP));
Craigacp commented 4 months ago

I'm not sure how you're plotting the ROC curve when you need a threshold to sweep through to change the point at which a label is predicted. Tribuo already supports AUCROC for classifiers which produce probabilities, but KNNTrainer doesn't.

Mohammed-Ryiad-Eiadeh commented 4 months ago

Well, I just calculate the FPR and TPR after each fold and use them to plot my ROC curve, and I pass them to AUCCalculator to get the AUC value via the trapezoidal rule. Please tell me if this is not correct so I can change it.

Craigacp commented 4 months ago

That won't give you an appropriate ROC curve as it's not on the same data and doesn't represent how changing the classification threshold would change the false positive & true positive rate.
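To make the distinction concrete, here is a minimal, self-contained sketch (plain Java, not Tribuo API; all names are illustrative) of how an ROC curve and its AUC come from sweeping a decision threshold over per-example predicted probabilities, rather than taking one (FPR, TPR) point per fold:

```java
import java.util.Arrays;
import java.util.Comparator;

public class RocSweep {
    // Computes AUC by sweeping the decision threshold over the scores:
    // each distinct score value yields one (FPR, TPR) point on the ROC curve.
    public static double auc(double[] scores, boolean[] positive) {
        Integer[] idx = new Integer[scores.length];
        for (int i = 0; i < idx.length; i++) idx[i] = i;
        // Sort indices by descending score, so lowering the threshold
        // admits examples one at a time.
        Arrays.sort(idx, Comparator.comparingDouble((Integer i) -> -scores[i]));
        int totalPos = 0;
        for (boolean b : positive) if (b) totalPos++;
        int totalNeg = positive.length - totalPos;
        double area = 0.0;
        double tp = 0, fp = 0, prevFpr = 0, prevTpr = 0;
        for (int k = 0; k < idx.length; k++) {
            if (positive[idx[k]]) tp++; else fp++;
            // Emit a point only when the score changes (ties share one threshold).
            if (k == idx.length - 1 || scores[idx[k]] != scores[idx[k + 1]]) {
                double fpr = fp / totalNeg, tpr = tp / totalPos;
                area += (fpr - prevFpr) * (tpr + prevTpr) / 2; // trapezoidal rule
                prevFpr = fpr;
                prevTpr = tpr;
            }
        }
        return area;
    }

    public static void main(String[] args) {
        double[] scores = {0.9, 0.8, 0.7, 0.6, 0.55, 0.4, 0.3, 0.2};
        boolean[] labels = {true, true, false, true, false, true, false, false};
        System.out.println(auc(scores, labels)); // prints 0.8125
    }
}
```

The key point is that every (FPR, TPR) pair comes from the same held-out predictions at a different threshold; averaging one operating point per fold does not trace this curve.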

Mohammed-Ryiad-Eiadeh commented 4 months ago

Thanks for that. Can you give some suggestions here?

Craigacp commented 4 months ago

You'll need to use a model which supports generating probabilities, and then you can use the methods on LabelEvaluation to compute the AUC, or LabelEvaluationUtil to compute the ROC curve itself - https://tribuo.org/learn/4.3/javadoc/org/tribuo/classification/evaluation/LabelEvaluationUtil.html.

Mohammed-Ryiad-Eiadeh commented 4 months ago

Dear Adam,

I need your help here. After getting the FPR, TPR, and thresholds like this:

FPR: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.007352941176470588, 0.014705882352941176, 0.022058823529411766, 0.029411764705882353, 0.03676470588235294, 0.04411764705882353, 0.051470588235294115, 0.058823529411764705, 0.0661764705882353, 0.07352941176470588, 0.08088235294117647, 0.08823529411764706, 0.09558823529411764, 0.10294117647058823, 0.11764705882352941, 1.0]

TPR: [0.0, 0.9456521739130435, 0.9565217391304348, 0.967391304347826, 0.9782608695652174, 0.9891304347826086, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]

Threshold: [Infinity, 0.9999999999999585, 0.9999999999900508, 0.9999999999532103, 0.9999999993470261, 0.999999986279678, 0.999999082458228, 0.4672622616731544, 0.010474966475835753, 7.848048691768931E-4, 4.464634619108306E-4, 1.8357524563945583E-4, 7.697946270832445E-5, 1.5905677137563365E-5, 1.258714255136621E-7, 4.6428762209717544E-8, 1.3855807195706487E-8, 8.900923141832403E-9, 7.567814072544735E-9, 7.443858692758792E-9, 2.8687081675940852E-9, 1.3147063911388807E-12, 1.7464080806775956E-82]

when plotting FPR against TPR, how do I get the number of correctly classified points corresponding to the positive label for the ROC?
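For reference, turning parallel FPR/TPR arrays like the ones above into an AUC value with the trapezoidal rule is only a few lines (plain Java sketch; the sample points below are an illustrative subset of the curve printed above):

```java
public class TrapezoidAuc {
    // Trapezoidal rule over an ROC curve given as parallel FPR/TPR arrays,
    // assumed sorted by non-decreasing FPR.
    public static double auc(double[] fpr, double[] tpr) {
        double area = 0.0;
        for (int i = 1; i < fpr.length; i++) {
            area += (fpr[i] - fpr[i - 1]) * (tpr[i] + tpr[i - 1]) / 2;
        }
        return area;
    }

    public static void main(String[] args) {
        // Illustrative subset of the curve above: TPR reaches 1.0 while FPR is
        // still 0.0, so the area under this curve is (approximately) 1.0.
        double[] fpr = {0.0, 0.0, 0.11764705882352941, 1.0};
        double[] tpr = {0.0, 1.0, 1.0, 1.0};
        System.out.println(auc(fpr, tpr));
    }
}
```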

Craigacp commented 4 months ago

That information isn't stored in the ROC class, the number of correctly classified points is stored in your LabelEvaluation.