accord-net / framework

Machine learning, computer vision, statistics and general scientific computing for .NET
http://accord-framework.net
GNU Lesser General Public License v2.1

GeneralConfusionMatrix's precision and recall numbers seem to be flipped #1187

Open yoonhwang opened 6 years ago

yoonhwang commented 6 years ago

Issue description

            var cv = CrossValidation.Create<NaiveBayes<BinomialDistribution>, NaiveBayesLearning<BinomialDistribution>, double[], int>(
                k: 3, // number of folds
                learner: (p) => new NaiveBayesLearning<BinomialDistribution>(), // Naive Bayes Classifier
                loss: (actual, expected, p) => new ZeroOneLoss(expected).Loss(actual),
                fit: (teacher, x, y, w) => teacher.Learn(x, y, w),
                x: input,
                y: output
            );

            // After the cross-validation object has been created,
            // we can call its .Learn method with the input and 
            // output data that will be partitioned into the folds:
            var result = cv.Learn(input, output);

            // We can grab some information about the problem:
            int numberOfSamples = result.NumberOfSamples;
            int numberOfInputs = result.NumberOfInputs;
            int numberOfOutputs = result.NumberOfOutputs;

            double trainingError = result.Training.Mean;
            double validationError = result.Validation.Mean;

            var expectedOut = result.Models[0].Model.Decide(input);
            var actualOut = targetVariables.ValuesAll.ToArray();
            var correct = 0;
            var correctClass1 = 0;
            for(int i = 0; i < expectedOut.Length; i++)
            {
                if(expectedOut[i] == actualOut[i])
                {
                    correct += 1;
                    if(expectedOut[i] > 0)
                    {
                        correctClass1 += 1;
                    }
                }
            }

            Console.WriteLine("{0}", ((float)correct / (float)actualOut.Length));
            Console.WriteLine("{0}", ((float)correctClass1 / (float)targetVariables.NumSum()));

            // If desired, compute an aggregate confusion matrix for the validation sets:
            GeneralConfusionMatrix gcm = result.ToConfusionMatrix(input, output);
            Console.WriteLine("");
            Console.Write("\t\tActual 0\t\tActual 1\n");
            for (int i = 0; i < gcm.Matrix.GetLength(0); i++)
            {
                Console.Write("Pred {0} :\t", i);
                for (int j = 0; j < gcm.Matrix.GetLength(1); j++)
                {
                    Console.Write(gcm.Matrix[i, j] + "\t\t\t");
                }
                Console.WriteLine();
            }
            double accuracy = gcm.Accuracy;
            double[] precision = gcm.Precision;
            double[] recall = gcm.Recall;

            Console.WriteLine("---- Using GeneralConfusionMatrix's Precision/Recall ----");
            Console.WriteLine("# samples: {0}, # inputs: {1}, # outputs: {2}", numberOfSamples, numberOfInputs, numberOfOutputs);
            //Console.WriteLine("training error: {0}", trainingError);
            //Console.WriteLine("validation error: {0}", validationError);
            Console.WriteLine("Precision: {0}, {1}", precision[0], precision[1]);
            Console.WriteLine("Recall: {0}, {1}", recall[0], recall[1]);
            //Console.WriteLine("Accuracy: {0}", accuracy);

            Console.WriteLine("---- Manually Calculating Precision & Recall Numbers");
            Console.WriteLine("Row total: {0}, {1}", gcm.RowTotals[0], gcm.RowTotals[1]);
            Console.WriteLine("Col total: {0}, {1}", gcm.ColumnTotals[0], gcm.ColumnTotals[1]);
            // True-Positive / (True-Positive + False-Positive)
            Console.WriteLine("Precision: {0}", ((float)gcm.Matrix[1, 1]) / (gcm.Matrix[1, 0] + gcm.Matrix[1, 1]));
            // True-Positive / (True-Positive + False-Negative)
            Console.WriteLine("Recall: {0}", ((float)gcm.Matrix[1, 1]) / (gcm.Matrix[1, 1] + gcm.Matrix[0, 1]));

olegtarasov commented 6 years ago

I can confirm this issue. We have to compute these metrics by hand because of this bug.

davidrapan commented 4 years ago

Yes, in the code Precision is computed as TP / ColumnSum (i.e. TP + FN) instead of TP / RowSum (i.e. TP + FP), and vice versa for Recall. (I personally solved the problem by rebuilding the library with the fix.)
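
For reference, the per-class definitions being compared in this thread can be written out as below. This is only a sketch under one assumption: it uses the layout from the printout in the original post, where entry $M_{ij}$ counts samples predicted as class $i$ whose actual class is $j$ (rows are predictions, columns are ground truth). Whether GeneralConfusionMatrix itself stores its matrix in that orientation is not confirmed anywhere in this thread; under the transposed layout the row and column sums simply swap roles.

```latex
% Assumed layout: M_{ij} = number of samples predicted as class i with actual class j.
% For class k, TP_k = M_{kk}; the rest of row k are false positives,
% and the rest of column k are false negatives.
\[
\mathrm{Precision}_k
  = \frac{TP_k}{TP_k + FP_k}
  = \frac{M_{kk}}{\sum_{j} M_{kj}}
  \quad \text{(diagonal entry over its row total)}
\]
\[
\mathrm{Recall}_k
  = \frac{TP_k}{TP_k + FN_k}
  = \frac{M_{kk}}{\sum_{i} M_{ik}}
  \quad \text{(diagonal entry over its column total)}
\]
```

With this layout, the manual calculation in the original post corresponds to $k = 1$: precision divides `Matrix[1, 1]` by the sum of row 1, and recall divides it by the sum of column 1.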