salesforce / TransmogrifAI

TransmogrifAI (pronounced trăns-mŏgˈrə-fī) is an AutoML library for building modular, reusable, strongly typed machine learning workflows on Apache Spark with minimal hand-tuning
https://transmogrif.ai
BSD 3-Clause "New" or "Revised" License

Add thresholded confusion matrix elements to binary classification metrics #492

Closed: Jauntbox closed this 4 years ago

Jauntbox commented 4 years ago

Related issues: N/A

Describe the proposed solution: Spark's threshold metrics only compute precision, recall, and false positive rate (FPR). It would be helpful to also have the confusion matrix components (TP, FP, TN, FN). Fortunately, the three quantities already calculated (precision = TP / (TP + FP), recall = TP / (TP + FN), FPR = FP / (FP + TN)), together with the equation TP + FP + TN + FN = Count, provide a system of 4 independent equations in 4 unknowns, so the confusion matrix components can be calculated with closed-form formulae.
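A minimal Scala sketch of that closed-form solution, assuming the three rates and the count are plain Doubles for a single threshold. `ConfusionCounts` and `ConfusionFromRates.solve` are hypothetical names used for illustration only; they are not TransmogrifAI or Spark APIs. Writing P = TP + FN for the number of actual positives, the system reduces to one linear equation in P, and the other components follow from it. The sketch returns None when that equation's coefficient vanishes, which happens whenever TP = 0 and, in this parameterization, also whenever FP = 0 (precision 1 and FPR 0).

```scala
// Hypothetical names; a sketch of the closed-form solution, not TransmogrifAI's code.
final case class ConfusionCounts(tp: Double, fp: Double, tn: Double, fn: Double)

object ConfusionFromRates {
  /**
   * Solves, for one threshold:
   *   precision = TP / (TP + FP)
   *   recall    = TP / (TP + FN)
   *   fpr       = FP / (FP + TN)
   *   count     = TP + FP + TN + FN
   * With P = TP + FN (actual positives): TP = recall * P and FP = fpr * (count - P).
   * Clearing denominators in the precision equation, precision * (TP + FP) = TP, gives
   *   recall * (1 - precision) * P = fpr * precision * (count - P)
   * so P = fpr * precision * count / (recall * (1 - precision) + fpr * precision).
   * Returns None when that denominator is zero (the system is underdetermined).
   */
  def solve(precision: Double, recall: Double, fpr: Double, count: Double): Option[ConfusionCounts] = {
    val denom = recall * (1.0 - precision) + fpr * precision
    if (denom == 0.0) None
    else {
      val positives = fpr * precision * count / denom // P = TP + FN
      val negatives = count - positives               // FP + TN
      Some(ConfusionCounts(
        tp = recall * positives,
        fp = fpr * negatives,
        tn = (1.0 - fpr) * negatives,
        fn = (1.0 - recall) * positives
      ))
    }
  }
}
```

For example, precision 0.75, recall 0.6, FPR 0.2 and a count of 100 give back TP = 30, FP = 10, TN = 40, FN = 20, up to floating-point rounding.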

It turns out these equations become underdetermined whenever TP = 0: both the precision and recall equations then reduce to 0 = 0, leaving only the FPR and count equations for the three remaining unknowns (FP, TN, FN). The calculation actually needed here (confusion matrices by threshold) already exists in Spark's code, but it is not exposed (yet). Even worse, it's private rather than package private, so we can't get around it with an implicit class extension. For the moment, I'm doing a gross copy/paste to expose that information to our evaluators. The longer-term solution is to open a PR to Spark that exposes it, so we can get rid of RichBinaryClassificationMetrics.
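For intuition about what "confusion matrices by threshold" involves, here is a small local Scala sketch over an in-memory Seq rather than an RDD. It is not Spark's private implementation; it only assumes the common approach of treating each distinct score as a threshold, predicting positive when score >= threshold, and accumulating label counts from the highest score downward. `ThresholdCounts` and `ConfusionByThreshold` are hypothetical names.

```scala
// Hypothetical names; a local sketch, not Spark's (private) implementation.
final case class ThresholdCounts(threshold: Double, tp: Long, fp: Long, tn: Long, fn: Long)

object ConfusionByThreshold {
  /**
   * scoresAndLabels holds (score, label) pairs, label 1.0 = positive, 0.0 = negative.
   * Every distinct score is used as a threshold; a point is predicted positive
   * at threshold t when its score >= t.
   */
  def compute(scoresAndLabels: Seq[(Double, Double)]): Seq[ThresholdCounts] = {
    val totalPos = scoresAndLabels.count(_._2 == 1.0).toLong
    val totalNeg = scoresAndLabels.size.toLong - totalPos

    // (score, positives, negatives) for each distinct score, highest score first.
    val perScore: Seq[(Double, Long, Long)] = scoresAndLabels
      .groupBy(_._1)
      .toSeq
      .map { case (score, points) =>
        val pos = points.count(_._2 == 1.0).toLong
        (score, pos, points.size.toLong - pos)
      }
      .sortWith(_._1 > _._1)

    // Running (positives, negatives) scored at or above each threshold.
    val cumulative = perScore
      .scanLeft((0L, 0L)) { case ((p, n), (_, pos, neg)) => (p + pos, n + neg) }
      .tail

    perScore.zip(cumulative).map { case ((score, _, _), (cumPos, cumNeg)) =>
      ThresholdCounts(
        threshold = score,
        tp = cumPos,            // positives predicted positive
        fp = cumNeg,            // negatives predicted positive
        tn = totalNeg - cumNeg, // negatives predicted negative
        fn = totalPos - cumPos  // positives predicted negative
      )
    }
  }
}
```

With the points (0.9, positive), (0.8, negative), (0.8, positive) and (0.3, negative), the thresholds 0.9, 0.8 and 0.3 yield (TP, FP, TN, FN) = (1, 0, 2, 1), (2, 1, 1, 0) and (2, 2, 0, 0). Computing the counts directly like this sidesteps the TP = 0 ambiguity entirely, since nothing has to be recovered from ratios.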

Describe alternatives you've considered: This is now plan B!

Additional context: A separate PR will add thresholded, discretized metrics for regression as well.

codecov[bot] commented 4 years ago

Codecov Report

Merging #492 into master will increase coverage by 6.88%. The diff coverage is 81.25%.


@@            Coverage Diff             @@
##           master     #492      +/-   ##
==========================================
+ Coverage   80.16%   87.04%   +6.88%     
==========================================
  Files         345      346       +1     
  Lines       11741    11782      +41     
  Branches      392      385       -7     
==========================================
+ Hits         9412    10256     +844     
+ Misses       2329     1526     -803     
Impacted Files                                         Coverage Δ
...c/main/scala/com/salesforce/op/ModelInsights.scala  91.88% <ø> (+18.50%)
...op/evaluators/OpMultiClassificationEvaluator.scala  94.73% <50.00%> (+1.31%)
...p/evaluators/OpBinaryClassificationEvaluator.scala  82.60% <63.63%> (+0.10%)
...b/evaluation/RichBinaryClassificationMetrics.scala  88.57% <88.57%> (ø)
...esforce/op/features/types/FeatureTypeFactory.scala  99.13% <0.00%> (+0.86%)
...la/com/salesforce/op/features/FeatureBuilder.scala  35.17% <0.00%> (+1.37%)
...la/com/salesforce/op/stages/OpPipelineStages.scala  63.88% <0.00%> (+1.38%)
...n/scala/com/salesforce/op/dsl/RichMapFeature.scala  67.64% <0.00%> (+1.47%)
... and 66 more


Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data. Last update d0b1038...1b49a4b.

nicodv commented 4 years ago

LGTM