Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pretty print model summaries #25

Closed
wants to merge 23 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
A handful of model insight helper methods
  • Loading branch information
tovbinm committed Jun 22, 2018
commit 35dcd14e5809a68fa43c7bd90492763d89f3846c
122 changes: 120 additions & 2 deletions core/src/main/scala/com/salesforce/op/ModelInsights.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,23 +31,33 @@

package com.salesforce.op

import com.salesforce.op.evaluators._
import com.salesforce.op.features.FeatureLike
import com.salesforce.op.features.types.{OPVector, RealNN}
import com.salesforce.op.stages.impl.classification.ClassificationModelsToTry
import com.salesforce.op.stages.impl.classification.ClassificationModelsToTry.{DecisionTree, LogisticRegression, NaiveBayes, RandomForest}
import com.salesforce.op.stages.impl.preparators._
import com.salesforce.op.stages.impl.selector.{ModelSelectorBase, SelectedModel}
import com.salesforce.op.stages.impl.regression.RegressionModelsToTry
import com.salesforce.op.stages.impl.regression.RegressionModelsToTry.{DecisionTreeRegression, GBTRegression, LinearRegression, RandomForestRegression}
import com.salesforce.op.stages.impl.selector.{ModelSelectorBase, ModelSelectorBaseNames, SelectedModel}
import com.salesforce.op.stages.impl.selector.ModelSelectorBaseNames._
import com.salesforce.op.stages.{OPStage, OpPipelineStageParams, OpPipelineStageParamsNames}
import com.salesforce.op.utils.json.JsonUtils
import com.salesforce.op.utils.spark.OpVectorMetadata
import com.salesforce.op.utils.spark.RichMetadata._
import enumeratum.EnumEntry
import org.apache.spark.ml.classification._
import org.apache.spark.ml.regression._
import org.apache.spark.ml.{Model, PipelineStage, Transformer}
import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel
import org.apache.spark.sql.types.Metadata
import org.json4s._
import org.json4s.jackson.JsonMethods._
import org.json4s.jackson.Serialization.{write, writePretty}
import org.slf4j.LoggerFactory

import scala.util.Try
import scala.reflect.ClassTag
import scala.util.{Failure, Success, Try}

/**
* Summary of all model insights
Expand All @@ -68,10 +78,118 @@ case class ModelInsights
stageInfo: Map[String, Any]
) {

/**
 * Best model UID
 */
def bestModelUid: String = {
  val uid = selectedModelInfo(BestModelUid)
  uid.toString
}

/**
 * Best model name
 */
def bestModelName: String = {
  val name = selectedModelInfo(BestModelName)
  name.toString
}

/**
 * Best model type, i.e. LogisticRegression, RandomForest etc.
 */
def bestModelType: EnumEntry = {
  val uid = bestModelUid
  // Classification UIDs are tried first, then regression; unknown prefixes are an error
  val maybeType = classificationModelTypeOfUID.orElse(regressionModelTypeOfUID).lift(uid)
  maybeType.getOrElse(throw new Exception(s"Unsupported model type for best model '$uid'"))
}

/**
 * Best model validation results computed during Cross Validation or Train Validation Split
 */
def bestModelValidationResults: Map[String, String] = {
  val name = bestModelName
  validationResults(name)
}

/**
 * Validation results computed during Cross Validation or Train Validation Split
 *
 * @return validation results keyed by model name
 */
def validationResults: Map[String, Map[String, String]] = {
  // Prefer train/validation split results; fall back to cross validation results
  val res = getMap[String, Any](selectedModelInfo, TrainValSplitResults)
    .orElse(getMap[String, Any](selectedModelInfo, CrossValResults))
    .map { results =>
      results.keys
        .map(modelName => modelName -> getMap[String, String](results, modelName).getOrElse(Map.empty))
        .toMap
    }
  res match {
    case Failure(e) => throw new Exception(s"Failed to extract validation results", e)
    case Success(ok) => ok
  }
}

/**
 * Train set evaluation metrics
 */
def trainEvaluationMetrics: EvaluationMetrics = {
  evaluationMetrics(TrainingEval)
}

/**
 * Test set evaluation metrics (if any)
 */
def testEvaluationMetrics: Option[EvaluationMetrics] = {
  // Hold-out metrics only exist if a test split was reserved during model selection
  if (selectedModelInfo.contains(HoldOutEval)) Some(evaluationMetrics(HoldOutEval)) else None
}

/**
 * Serialize to json string
 *
 * @param pretty should pretty format
 * @return json string
 */
def toJson(pretty: Boolean = true): String = {
  implicit val formats: Formats = DefaultFormats
  if (!pretty) write(this) else writePretty(this)
}

// Maps a Spark classification model UID prefix to its model type enum entry
private def classificationModelTypeOfUID: PartialFunction[String, ClassificationModelsToTry] = {
  val byPrefix = Seq(
    "logreg" -> LogisticRegression,
    "rfc" -> RandomForest,
    "dtc" -> DecisionTree,
    "nb" -> NaiveBayes
  )
  Function.unlift((uid: String) => byPrefix.collectFirst { case (prefix, model) if uid.startsWith(prefix) => model })
}
// Maps a Spark regression model UID prefix to its model type enum entry
private def regressionModelTypeOfUID: PartialFunction[String, RegressionModelsToTry] = {
  val byPrefix = Seq(
    "linReg" -> LinearRegression,
    "rfr" -> RandomForestRegression,
    "dtr" -> DecisionTreeRegression,
    "gbtr" -> GBTRegression
  )
  Function.unlift((uid: String) => byPrefix.collectFirst { case (prefix, model) if uid.startsWith(prefix) => model })
}
// Extracts the named metrics map from the selected model info and decodes it
// into a typed EvaluationMetrics instance; any failure is rethrown with context
private def evaluationMetrics(metricsName: String): EvaluationMetrics = {
  val extracted = getMap[String, Double](selectedModelInfo, metricsName)
    .flatMap(metricsMap => Try(toEvaluationMetrics(metricsMap)))
  extracted match {
    case Failure(e) => throw new Exception(s"Failed to extract '$metricsName' metrics", e)
    case Success(metrics) => metrics
  }
}
// Looks up `name` in `m` and extracts a typed map from it. The stored value may be
// either a plain Scala map (with the payload nested under a "map" key — presumably
// produced by metadata serialization; TODO confirm against RichMetadata) or a Spark
// Metadata object. The casts are unchecked: a value of the wrong type surfaces as a
// Failure (ClassCastException) from the wrapping Try rather than at the cast site.
private def getMap[K, V](m: Map[String, Any], name: String): Try[Map[K, V]] = Try {
  m(name) match {
    case m: Map[String, Any]@unchecked => m("map").asInstanceOf[Map[K, V]]
    case m: Metadata => m.underlyingMap.asInstanceOf[Map[K, V]]
  }
}

// Regex for metric keys of the form "(metricsType)_metricName": group 1 captures
// the metrics type, group 2 the metric name
private val MetricName = "\\((.*)\\)\\_(.*)".r

// Decodes a flat "(type)_name -> value" metrics map into a typed EvaluationMetrics.
// The metrics type is taken from the first key; all keys are assumed to share it.
private def toEvaluationMetrics(metrics: Map[String, Double]): EvaluationMetrics = {
  import OpEvaluatorNames._
  val knownTypes = Set(binary, multi, regression)
  val metricsType = metrics.keys.headOption match {
    case Some(MetricName(t, _)) if knownTypes.contains(t) => t
    case v => throw new Exception(s"Invalid model metric '$v'")
  }
  // Strips the "(type)_" prefix from each key, then round-trips through json
  // to populate the target metrics case class
  def parse[T <: EvaluationMetrics : ClassTag]: T = {
    val stripped = metrics.map { case (MetricName(_, name), value) => name -> value }
    JsonUtils.fromString[T](JsonUtils.toJsonString(stripped)).get
  }
  metricsType match {
    case `binary` => parse[BinaryClassificationMetrics]
    case `multi` => parse[MultiClassificationMetrics]
    case `regression` => parse[RegressionMetrics]
    case t => throw new Exception(s"Unsupported metrics type '$t'")
  }
}
}

/**
Expand Down
22 changes: 19 additions & 3 deletions core/src/main/scala/com/salesforce/op/OpWorkflowModel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import com.salesforce.op.evaluators.{EvaluationMetrics, OpEvaluatorBase}
import com.salesforce.op.features.types.FeatureType
import com.salesforce.op.features.{FeatureLike, OPFeature}
import com.salesforce.op.readers.DataFrameFieldNames._
import com.salesforce.op.stages.impl.selector.StageParamNames
import com.salesforce.op.stages.{OPStage, OpPipelineStage, OpTransformer}
import com.salesforce.op.utils.spark.RichDataset._
import com.salesforce.op.utils.spark.RichMetadata._
Expand Down Expand Up @@ -165,7 +166,7 @@ class OpWorkflowModel(val uid: String = UID[OpWorkflowModel], val trainingParams
}

/**
* Pulls all summary metadata off of transformers
* Pulls all summary metadata from the transformers and renders it as json
*
* @return json summary
*/
Expand All @@ -177,12 +178,27 @@ class OpWorkflowModel(val uid: String = UID[OpWorkflowModel], val trainingParams
)

/**
* Pulls all summary metadata off of transformers and puts them in a pretty json string
* Pulls all summary metadata of transformers and puts them into json string
*
* @return string summary
* @return json string summary
*/
def summary(): String = {
  // Render the collected stage summary metadata as a pretty-printed json string
  val json = summaryJson()
  pretty(render(json))
}

/**
 * Pulls all summary metadata of transformers and puts them into compact print friendly string
 *
 * @throws Exception if no prediction feature is defined
 * @return compact print friendly string
 */
def summaryPretty(): String = {
  // The prediction feature is found either among the declared result features or,
  // failing that, among the outputs of the fitted stages
  val prediction = resultFeatures.find(_.name == StageParamNames.outputParam1Name).orElse(
    stages.map(_.getOutput()).find(_.name == StageParamNames.outputParam1Name)
  ).getOrElse(
    throw new Exception("No prediction feature is defined")
  )
  val insights = modelInsights(prediction)
  // Fix: the original body ended in `???`, so calling this method always threw
  // NotImplementedError. Until a dedicated compact text format is implemented,
  // render the model insights as pretty json so callers get useful output.
  insights.toJson(pretty = true)
}

/**
* Save this model to a path
*
Expand Down
37 changes: 31 additions & 6 deletions core/src/test/scala/com/salesforce/op/ModelInsightsTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,16 @@

package com.salesforce.op

import com.salesforce.op.evaluators.BinaryClassificationMetrics
import com.salesforce.op.features.Feature
import com.salesforce.op.features.types.{FeatureTypeDefaults, PickList, Real, RealNN}
import com.salesforce.op.features.types.{PickList, Real, RealNN}
import com.salesforce.op.stages.impl.classification.BinaryClassificationModelSelector
import com.salesforce.op.stages.impl.classification.ClassificationModelsToTry.LogisticRegression
import com.salesforce.op.stages.impl.preparators._
import com.salesforce.op.stages.impl.regression.RegressionModelSelector
import com.salesforce.op.stages.impl.regression.RegressionModelsToTry.LinearRegression
import com.salesforce.op.stages.impl.selector.SelectedModel
import com.salesforce.op.stages.impl.tuning.DataSplitter
import com.salesforce.op.test.PassengerSparkFixtureTest
import com.salesforce.op.utils.spark.{OpVectorColumnMetadata, OpVectorMetadata}
import org.junit.runner.RunWith
Expand All @@ -55,11 +57,11 @@ class ModelInsightsTest extends FlatSpec with PassengerSparkFixtureTest {
private val descrVec = description.vectorize(10, false, 1, true)
private val features = Seq(density, age, generVec, weight, descrVec).transmogrify()
private val label = survived.occurs()
private val checked = label.sanityCheck(features, removeBadFeatures = true, removeFeatureGroup = false,
checkSample = 1.0)
private val checked =
label.sanityCheck(features, removeBadFeatures = true, removeFeatureGroup = false, checkSample = 1.0)

val (pred, rawPred, prob) = BinaryClassificationModelSelector
.withCrossValidation(seed = 42, splitter = None)
.withCrossValidation(seed = 42, splitter = Option(DataSplitter(seed = 42, reserveTestFraction = 0.1)))
.setModelsToTry(LogisticRegression)
.setLogisticRegressionRegParam(0.01, 0.1)
.setInput(label, checked)
Expand Down Expand Up @@ -119,7 +121,7 @@ class ModelInsightsTest extends FlatSpec with PassengerSparkFixtureTest {
insights.label.rawFeatureName shouldBe Seq(survived.name)
insights.label.rawFeatureType shouldBe Seq(survived.typeName)
insights.label.stagesApplied.size shouldBe 1
insights.label.sampleSize shouldBe Some(6.0)
insights.label.sampleSize shouldBe Some(4.0)
insights.features.size shouldBe 5
insights.features.map(_.featureName).toSet shouldEqual rawNames
ageInsights.derivedFeatures.size shouldBe 2
Expand Down Expand Up @@ -170,7 +172,7 @@ class ModelInsightsTest extends FlatSpec with PassengerSparkFixtureTest {
insights.label.rawFeatureName shouldBe Seq(survived.name)
insights.label.rawFeatureType shouldBe Seq(survived.typeName)
insights.label.stagesApplied.size shouldBe 1
insights.label.sampleSize shouldBe Some(6.0)
insights.label.sampleSize shouldBe Some(4.0)
insights.features.size shouldBe 5
insights.features.map(_.featureName).toSet shouldEqual rawNames
ageInsights.derivedFeatures.size shouldBe 2
Expand Down Expand Up @@ -237,6 +239,29 @@ class ModelInsightsTest extends FlatSpec with PassengerSparkFixtureTest {
lin.head.size shouldBe OpVectorMetadata("", checked.originStage.getMetadata()).columns.length
}

it should "return best model information" in {
  val insights = workflowModel.modelInsights(prob)
  insights.bestModelUid should startWith("logreg_")
  insights.bestModelName should startWith("logreg_")
  insights.bestModelType shouldBe LogisticRegression
  val bestModelValidationResults = insights.bestModelValidationResults
  bestModelValidationResults.size shouldBe 15
  // Fix: removed leftover debug println(bestModelValidationResults) — test output
  // should come from assertion failures, not stdout noise
  bestModelValidationResults.get("area under PR") shouldBe Some("0.0")
  val validationResults = insights.validationResults
  validationResults.size shouldBe 2
  // The best model's entry in the overall results must match the dedicated accessor
  validationResults.get(insights.bestModelName) shouldBe Some(bestModelValidationResults)
}

it should "return test/train evaluation metrics" in {
  val insights = workflowModel.modelInsights(prob)
  // Expected metrics for the fitted logistic regression on the fixture data
  val expectedTrain = BinaryClassificationMetrics(1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 5.0, 0.0, 0.0)
  val expectedTest = BinaryClassificationMetrics(0.0, 0.0, 0.0, 0.5, 0.75, 0.5, 0.0, 1.0, 0.0, 1.0)
  insights.trainEvaluationMetrics shouldBe expectedTrain
  insights.testEvaluationMetrics shouldBe Some(expectedTest)
}

it should "correctly serialize and deserialize from json" in {
val insights = workflowModel.modelInsights(prob)
ModelInsights.fromJson(insights.toJson()) match {
Expand Down