Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve json serde error in evalMetFromJson #380

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,16 @@ package com.salesforce.op.stages.impl.selector

import com.salesforce.op.evaluators._
import com.salesforce.op.stages.impl.MetadataLike
import com.salesforce.op.stages.impl.selector.ModelSelectorSummary._
import com.salesforce.op.stages.impl.tuning.{OpCrossValidation, OpTrainValidationSplit, OpValidator, SplitterSummary}
import com.salesforce.op.utils.json.JsonUtils
import com.salesforce.op.utils.reflection.ReflectionUtils
import com.salesforce.op.utils.spark.RichMetadata._
import enumeratum._
import org.apache.spark.sql.types.{Metadata, MetadataBuilder}
import com.salesforce.op.stages.impl.selector.ModelSelectorSummary._
import com.salesforce.op.utils.json.JsonUtils
import com.salesforce.op.utils.reflection.ReflectionUtils

import scala.reflect.ClassTag
import scala.util.{Failure, Try}

/**
* This is used to store all information about fitting and model selection generated by the model selector class
Expand Down Expand Up @@ -171,7 +173,7 @@ case object ModelSelectorSummary {
val modelName: String = wrapped.get[String](ModelName)
val modelType: String = wrapped.get[String](ModelTypeName)
val Array(metName, metJson) = wrapped.get[Array[String]](MetricValues)
val metricValues: EvaluationMetrics = evalMetFromJson(metName, metJson)
val metricValues: EvaluationMetrics = evalMetFromJson(metName, metJson).get
val modelParameters: Map[String, Any] = wrapped.get[Metadata](ModelParameters).wrapped.underlyingMap

ModelEvaluation(
Expand Down Expand Up @@ -203,11 +205,10 @@ case object ModelSelectorSummary {
val validationResults: Seq[ModelEvaluation] = wrapped.get[Array[Metadata]](ValidationResults)
.map(modelEvalFromMetadata)
val Array(metName, metJson) = wrapped.get[Array[String]](TrainEvaluation)
val trainEvaluation: EvaluationMetrics = evalMetFromJson(metName, metJson)
val holdoutEvaluation: Option[EvaluationMetrics] =
if (wrapped.contains(HoldoutEvaluation)) {
val Array(metNameHold, metJsonHold) = wrapped.get[Array[String]](HoldoutEvaluation)
Option(evalMetFromJson(metNameHold, metJsonHold))
evalMetFromJson(metNameHold, metJsonHold).toOption
} else None

ModelSelectorSummary(
Expand All @@ -221,7 +222,7 @@ case object ModelSelectorSummary {
bestModelName = bestModelName,
bestModelType = bestModelType,
validationResults = validationResults,
trainEvaluation = trainEvaluation,
trainEvaluation = evalMetFromJson(metName, metJson).get,
holdoutEvaluation = holdoutEvaluation)

}
Expand All @@ -231,12 +232,12 @@ case object ModelSelectorSummary {
*
* @param json encoded metrics
*/
private def evalMetFromJson(className: String, json: String): EvaluationMetrics = {
def error(c: Class[_]) = throw new IllegalArgumentException(
s"Could not extract metrics of type $c from ${json.mkString(",")}"
)
val classZZ = ReflectionUtils.classForName(className)
classZZ match {
private[selector] def evalMetFromJson(className: String, json: String): Try[EvaluationMetrics] = {
def error(c: Class[_], t: Throwable): Try[MultiMetrics] = Failure[MultiMetrics] {
new IllegalArgumentException(s"Could not extract metrics of type $c from: $json", t)
}

ReflectionUtils.classForName(className) match {
case n if n == classOf[MultiMetrics] =>
JsonUtils.fromString[Map[String, Map[String, Any]]](json).map{ d =>
val asMetrics = d.flatMap{ case (_, values) => values.map{
Expand All @@ -255,11 +256,11 @@ case object ModelSelectorSummary {
}}
}
MultiMetrics(asMetrics)
}.getOrElse(error(classOf[MultiMetrics]))
case n => JsonUtils.fromString(json)(ClassTag(n)).getOrElse(error(n))
}.recoverWith { case t: Throwable => error(n, t) }
case n => JsonUtils.fromString(json)(ClassTag(n))
.recoverWith { case t: Throwable => error(n, t) }
}
}

}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,4 +111,24 @@ class ModelSelectorSummaryTest extends FlatSpec with TestSparkContext {
decoded.holdoutEvaluation shouldEqual summary.holdoutEvaluation
}

it should "not hide the root cause of JSON parsing errors" in {
val evalMetrics = MultiClassificationMetrics(Precision = 0.1, Recall = 0.2, F1 = 0.3, Error = 0.4,
ThresholdMetrics = ThresholdMetrics(topNs = Seq(1, 2), thresholds = Seq(1.1, 1.2),
correctCounts = Map(1 -> Seq(100L)), incorrectCounts = Map(2 -> Seq(200L)),
noPredictionCounts = Map(3 -> Seq(300L))))

val evalMetricsJson = evalMetrics.toJson()
val roundTripEvalMetrics = ModelSelectorSummary.evalMetFromJson(
classOf[MultiClassificationMetrics].getName, evalMetricsJson).get
roundTripEvalMetrics shouldBe evalMetrics

val corruptJson = evalMetricsJson.replace(":", "=")
val thr = intercept[IllegalArgumentException](ModelSelectorSummary.evalMetFromJson(
classOf[MultiClassificationMetrics].getName, corruptJson).get)

thr.getMessage should startWith ("Could not extract metrics of type class " +
"com.salesforce.op.evaluators.MultiClassificationMetrics from: {")

thr.getCause.getMessage shouldEqual "Unsupported format. Supported formats: json, yaml"
}
}