-
Notifications
You must be signed in to change notification settings - Fork 393
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Update confusion_matrix #545
Conversation
Thanks for the contribution! It looks like @swkasula is an internal user so signing the CLA is not required. However, we need to confirm this. |
Codecov Report
@@ Coverage Diff @@
## master #545 +/- ##
===========================================
+ Coverage 0 86.78% +86.78%
===========================================
Files 0 347 +347
Lines 0 12026 +12026
Branches 0 403 +403
===========================================
+ Hits 0 10437 +10437
- Misses 0 1589 +1589
Continue to review full report at Codecov.
|
Thanks, @swkasula Please confirm the OSS agreement. |
Threshold: Double, | ||
@JsonDeserialize(contentAs = classOf[java.lang.Long]) | ||
ConfusionMatrixCounts: Seq[Long] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
minor style issue: params are typically lower camel case:
Threshold: Double, | |
@JsonDeserialize(contentAs = classOf[java.lang.Long]) | |
ConfusionMatrixCounts: Seq[Long] | |
threshold: Double, | |
@JsonDeserialize(contentAs = classOf[java.lang.Long]) | |
confusionMatrixCounts: Seq[Long] |
@@ -418,12 +418,12 @@ class OpMultiClassificationEvaluatorTest extends FlatSpec with TestSparkContext | |||
outputMetrics.ConfMatrixThresholds shouldEqual testThresholds | |||
outputMetrics.ConfMatrices.length shouldEqual testThresholds.length | |||
// topK confusion matrix for p >= 0.4 | |||
outputMetrics.ConfMatrices(0) shouldEqual | |||
outputMetrics.ConfMatrices(0).ConfusionMatrixCounts shouldEqual |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add a line here verifying the threshold? Something like outputMetrics.ConfMatrices(0).Threshold shouldEqual....
Seq( | ||
6L, 6L, | ||
4L, 4L) | ||
// topK confusion matrix for p >= 0.7 | ||
outputMetrics.ConfMatrices(1).toArray shouldEqual | ||
outputMetrics.ConfMatrices(1).ConfusionMatrixCounts.toArray shouldEqual |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same as above. Otherwise lgtm
closing this as I created a different PR: #549 to merge from the tested dev branch. |
Related issues
De-serializing the ConfMatrices: Seq[Seq[Long]] resulting in ClassCastException as there was no @JsonDeserialize
annotation used for this attribute: https://github.com/salesforce/TransmogrifAI/pull/533/files
Describe the proposed solution
Fix is it update the ConfMatrices format from Seq[Seq[Long]] -> Seq[ConfusionMatrixPerThreshold] to de-serialize in the right format