Add confusion matrix metrics #530
Conversation
Thanks for the contribution! Unfortunately we can't verify the commit author(s): FeiFei Jiang <f***@f***.i***.s***.com>. One possible solution is to add that email to your GitHub account. Alternatively you can change your commits to another email and force push the change. After getting your commits associated with your GitHub account, sign the Salesforce.com Contributor License Agreement and this Pull Request will be revalidated.
Codecov Report
@@            Coverage Diff            @@
##           master     #530       +/-   ##
===========================================
+ Coverage        0   86.77%    +86.77%
===========================================
  Files           0      347       +347
  Lines           0    12018     +12018
  Branches        0      621       +621
===========================================
+ Hits            0    10429     +10429
- Misses          0     1589      +1589

Continue to review full report at Codecov.
@@ -94,6 +98,25 @@ private[op] class OpMultiClassificationEvaluator

   def setThresholds(v: Array[Double]): this.type = set(thresholds, v)

+  def setTopKCm(v: Int): this.type = set(topKCm, v)
I think we should call these something different, since we already have confusion between `topK` and `topN` in the multiclass metrics. How about more self-documenting names like `confMatrixNumClasses` for `topKCm` and `confMatrixMinSupport` for `topNCm`?
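For illustration, a rough sketch of what those renames might look like as Spark ML params. The param names come from the review comment, but the doc strings are guesses at the intended semantics, not text from the PR:

```scala
import org.apache.spark.ml.param.{IntParam, Params}

// Illustrative only: names from the review comment, docs are assumptions.
trait ConfusionMatrixParams extends Params {
  final val confMatrixNumClasses = new IntParam(this, "confMatrixNumClasses",
    "number of most frequent classes to keep in the confusion matrix (was topKCm)")
  final val confMatrixMinSupport = new IntParam(this, "confMatrixMinSupport",
    "minimum count a class needs in order to be reported (was topNCm)")

  def setConfMatrixNumClasses(v: Int): this.type = set(confMatrixNumClasses, v)
  def setConfMatrixMinSupport(v: Int): this.type = set(confMatrixMinSupport, v)
}
```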
   */
  def findConfidenceThresholdBin(probability: Double, sortedThresholds: Seq[Double]): Double = {

    if (probability < sortedThresholds(0)) {0.0}
We shouldn't need to implement our own binary search here. If you cast `sortedThresholds` to an `IndexedSeq`, then you can use `scala.collection.Searching`: https://github.com/scala/scala/blob/v2.11.0-M3/src/library/scala/collection/Searching.scala#L61-L99. It can find both exact matches and insertion points for elements not in the thresholds.
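A minimal sketch of that suggestion; the bin semantics (return the largest threshold not exceeding the probability, and 0.0 below the smallest) are assumed from the snippet above, not confirmed by the PR:

```scala
import scala.collection.Searching._

// Sketch only: standard-library binary search instead of a hand-rolled one.
// Assumes sortedThresholds is sorted ascending.
def findConfidenceThresholdBin(probability: Double, sortedThresholds: Seq[Double]): Double = {
  val thresholds = sortedThresholds.toIndexedSeq  // IndexedSeq => binary search
  thresholds.search(probability) match {
    case Found(i)          => thresholds(i)      // exact match on a threshold
    case InsertionPoint(0) => 0.0                // below the smallest threshold
    case InsertionPoint(i) => thresholds(i - 1)  // falls in the bin opened by thresholds(i - 1)
  }
}
```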
  (
    topK: Int,
    thresholds: Seq[Double],
    confusionMatrixTopKByThreshold: Seq[Seq[Seq[Long]]]
I think we should try and keep this close to the way Spark treats the multiclass confusion matrices, if possible. They use their `Matrix` type for the actual confusion matrix, see e.g. https://github.com/apache/spark/blob/v2.4.0/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala#L67. I'd suggest a `Seq[Matrix]` type for the confusion matrices by threshold.
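For context, this is the Spark convention being referenced; `predictionAndLabels` is an assumed input, not a name from the PR:

```scala
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.linalg.Matrix
import org.apache.spark.rdd.RDD

// Spark's multiclass evaluator exposes the confusion matrix as a Matrix,
// with rows indexed by actual class and columns by predicted class.
def sparkConfusionMatrix(predictionAndLabels: RDD[(Double, Double)]): Matrix =
  new MulticlassMetrics(predictionAndLabels).confusionMatrix
```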
I tried returning a `Matrix` type as the confusion matrix, but it causes a JSON deserialization error in both `MultiClassificationModelSelectorTest` and `ModelSelectorSummaryTest` (error message: `Caused by: com.fasterxml.jackson.databind.JsonMappingException: Can not construct instance of org.apache.spark.mllib.linalg.Matrix, problem: abstract types either need to be mapped to concrete types, have custom deserializer, or be instantiated with additional type information`). Adding a JSON serialization annotation doesn't help. I think it needs to be a concrete type in order to be extracted correctly. In the updated PR, I keep the returned type as an Array. @Jauntbox
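Piecing together the diff above, the concrete, Jackson-friendly shape looks roughly like this (field names are from the diff; the comment on the nesting is one reading of it, not confirmed by the PR):

```scala
// Sketch assembled from the diff: a concrete nested-Seq representation
// round-trips through Jackson, unlike the abstract Matrix trait.
case class MulticlassConfusionMatrixTopKByThreshold(
  topK: Int,                                           // number of classes kept
  thresholds: Seq[Double],                             // confidence thresholds
  confusionMatrixTopKByThreshold: Seq[Seq[Seq[Long]]]  // per threshold: topK x topK counts
)
```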
@@ -112,7 +135,11 @@ private[op] class OpMultiClassificationEvaluator
       log.warn("The dataset is empty. Returning empty metrics.")
       MultiClassificationMetrics(0.0, 0.0, 0.0, 0.0,
         MulticlassThresholdMetrics(Seq.empty, Seq.empty, Map.empty, Map.empty, Map.empty),
-        MultiClassificationMetricsTopK(Seq.empty, Seq.empty, Seq.empty, Seq.empty, Seq.empty))
+        MultiClassificationMetricsTopK(Seq.empty, Seq.empty, Seq.empty, Seq.empty, Seq.empty),
+        MulticlassConfusionMatrixTopKByThreshold($(topKCm), $(thresholds), Seq.empty),
I'm not sure we want to use the same set of thresholds for the topK confusion matrix as we do for the existing threshold metrics: those are just 3 (for topN = 3) sets of counts per threshold, while this may end up being somewhere in the 10 × 10 to 30 × 30 range for each threshold. I think we should have a separate set of thresholds for the confusion matrices, and be pretty sparing with them.
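A minimal sketch of what a separate, sparse threshold param might look like; the name, default grid, and doc string are all assumptions, not from the PR:

```scala
import org.apache.spark.ml.param.{DoubleArrayParam, Params}

// Hypothetical: confusion matrices get their own sparse threshold grid,
// independent of the denser `thresholds` used by the threshold metrics.
trait ConfusionMatrixThresholdParams extends Params {
  final val confMatrixThresholds = new DoubleArrayParam(this, "confMatrixThresholds",
    "decision thresholds at which top-K confusion matrices are computed")
  // A few thresholds keep the payload small: at 30 kept classes, each
  // threshold already contributes a 30 x 30 = 900-entry count matrix.
  setDefault(confMatrixThresholds, Array(0.0, 0.25, 0.5, 0.75))

  def setConfMatrixThresholds(v: Array[Double]): this.type = set(confMatrixThresholds, v)
}
```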
@@ -133,20 +169,176 @@ private[op] class OpMultiClassificationEvaluator
       topKs = $(topKs)
     )

+    val rddCm = dataUse.select(col(labelColName), col(predictionColName), col(probabilityColName)).rdd.map{
+      case Row(label: Double, pred: Double, prob: DenseVector) => (label, pred, prob.toArray)
Consider staying in the DataFrame API as long as possible before resorting to RDDs.
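For illustration, a sketch of that suggestion using the typed Dataset API; it assumes a `SparkSession` named `spark` is in scope, and the column names are taken from the diff above:

```scala
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.functions.col
import spark.implicits._

// Sketch: select and decode in the Dataset API, deferring (or avoiding)
// the drop to RDDs. Vector columns decode via Spark's VectorUDT.
val cmData = dataUse
  .select(col(labelColName), col(predictionColName), col(probabilityColName))
  .as[(Double, Double, Vector)]
  .map { case (label, pred, prob) => (label, pred, prob.toArray) }
```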
Added confusion matrix for multi-class classification, including: