diff --git a/core/src/main/scala/com/salesforce/op/ModelInsights.scala b/core/src/main/scala/com/salesforce/op/ModelInsights.scala index a913e1a0c2..1caac95574 100644 --- a/core/src/main/scala/com/salesforce/op/ModelInsights.scala +++ b/core/src/main/scala/com/salesforce/op/ModelInsights.scala @@ -31,23 +31,38 @@ package com.salesforce.op +import com.salesforce.op.evaluators._ import com.salesforce.op.features.FeatureLike import com.salesforce.op.features.types.{OPVector, RealNN} +import com.salesforce.op.stages.impl.ModelsToTry +import com.salesforce.op.stages.impl.classification.ClassificationModelsToTry +import com.salesforce.op.stages.impl.classification.ClassificationModelsToTry.{DecisionTree, LogisticRegression, NaiveBayes, RandomForest} +import com.salesforce.op.stages.impl.feature.TransmogrifierDefaults import com.salesforce.op.stages.impl.preparators._ +import com.salesforce.op.stages.impl.regression.RegressionModelsToTry +import com.salesforce.op.stages.impl.regression.RegressionModelsToTry.{DecisionTreeRegression, GBTRegression, LinearRegression, RandomForestRegression} +import com.salesforce.op.stages.impl.selector.ModelSelectorBaseNames._ import com.salesforce.op.stages.impl.selector.{ModelSelectorBase, SelectedModel} import com.salesforce.op.stages.{OPStage, OpPipelineStageParams, OpPipelineStageParamsNames} -import com.salesforce.op.utils.spark.OpVectorMetadata +import com.salesforce.op.utils.json.JsonUtils +import com.salesforce.op.utils.spark.{OpVectorColumnMetadata, OpVectorMetadata} import com.salesforce.op.utils.spark.RichMetadata._ +import com.salesforce.op.utils.table.Table +import enumeratum._ import org.apache.spark.ml.classification._ import org.apache.spark.ml.regression._ import org.apache.spark.ml.{Model, PipelineStage, Transformer} import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel +import org.apache.spark.sql.types.Metadata import org.json4s._ import org.json4s.jackson.JsonMethods._ import 
org.json4s.jackson.Serialization.{write, writePretty} import org.slf4j.LoggerFactory +import com.salesforce.op.utils.table.Alignment._ -import scala.util.Try +import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag +import scala.util.{Failure, Success, Try} /** * Summary of all model insights @@ -68,10 +83,341 @@ case class ModelInsights stageInfo: Map[String, Any] ) { + /** + * Selected model UID + */ + def selectedModelUID: String = selectedModelInfo(BestModelUid).toString + + /** + * Selected model name + */ + def selectedModelName: String = selectedModelInfo(BestModelName).toString + + /** + * Selected model type, i.e. LogisticRegression, RandomForest etc. + */ + def selectedModelType: ModelsToTry = modelType(selectedModelName).get + + /** + * Selected model validation results computed during Cross Validation or Train Validation Split + */ + def selectedModelValidationResults: Map[String, String] = validationResults(selectedModelName) + + /** + * Train set evaluation metrics for selected model + */ + def selectedModelTrainEvalMetrics: EvaluationMetrics = evaluationMetrics(TrainingEval) + + /** + * Test set evaluation metrics (if any) for selected model + */ + def selectedModelTestEvalMetrics: Option[EvaluationMetrics] = { + selectedModelInfo.get(HoldOutEval).map(_ => evaluationMetrics(HoldOutEval)) + } + + /** + * Validation results for all models computed during Cross Validation or Train Validation Split + * + * @return validation results keyed by model name + */ + def validationResults: Map[String, Map[String, String]] = { + val res = for { + results <- getMap[String, Any](selectedModelInfo, TrainValSplitResults).recoverWith { + case e => getMap[String, Any](selectedModelInfo, CrossValResults) + } + } yield results.keys.map(k => k -> getMap[String, String](results, k).getOrElse(Map.empty)) + res match { + case Failure(e) => throw new Exception(s"Failed to extract validation results", e) + case Success(ok) => ok.toMap + } + } + + /** + * 
Validation results for a specified model type computed during Cross Validation or Train Validation Split + * + * @return validation results keyed by model name + */ + def validationResults(mType: ModelsToTry): Map[String, Map[String, String]] = { + validationResults.filter { case (modelName, _) => modelType(modelName).toOption.contains(mType) } + } + + /** + * All validated model types + */ + def validatedModelTypes: Set[ModelsToTry] = + validationResults.keys.flatMap(modelName => modelType(modelName).toOption).toSet + + /** + * Validation type, i.e TrainValidationSplit, CrossValidation + */ + def validationType: ValidationType = { + if (getMap[String, Any](selectedModelInfo, TrainValSplitResults).isSuccess) ValidationType.TrainValidationSplit + else if (getMap[String, Any](selectedModelInfo, CrossValResults).isSuccess) ValidationType.CrossValidation + else throw new Exception(s"Failed to determine validation type") + } + + /** + * Evaluation metric type, i.e. AuPR, AuROC, F1 etc. + */ + def evaluationMetricType: EnumEntry with EvalMetric = { + val knownEvalMetrics = { + (BinaryClassEvalMetrics.values ++ MultiClassEvalMetrics.values ++ RegressionEvalMetrics.values) + .map(m => m.humanFriendlyName -> m).toMap + } + val evalMetrics = validationResults.flatMap(_._2.keys).flatMap(knownEvalMetrics.get).toSet.toList + evalMetrics match { + case evalMetric :: Nil => evalMetric + case Nil => throw new Exception("Unable to determine evaluation metric type: no metrics were found") + case metrics => throw new Exception( + s"Unable to determine evaluation metric type since: multiple metrics were found - " + metrics.mkString(",")) + } + } + + /** + * Problem type, i.e. 
Binary Classification, Multi Classification or Regression + */ + def problemType: ProblemType = selectedModelTrainEvalMetrics match { + case _: BinaryClassificationMetrics => ProblemType.BinaryClassification + case _: MultiClassificationMetrics => ProblemType.MultiClassification + case _: RegressionMetrics => ProblemType.Regression + case _ => ProblemType.Unknown + } + + /** + * Serialize to json string + * + * @param pretty should pretty format + * @return json string + */ def toJson(pretty: Boolean = true): String = { implicit val formats = DefaultFormats if (pretty) writePretty(this) else write(this) } + + /** + * High level model summary in a compact print friendly format containing: + * selected model info, model evaluation results and feature correlations/contributions/cramersV values. + * + * @param topK top K of feature correlations/contributions/cramersV values + * @return high level model summary in a compact print friendly format + */ + def prettyPrint(topK: Int = 15): String = { + val res = new ArrayBuffer[String]() + res ++= prettyValidationResults + res += prettySelectedModelInfo + res += modelEvaluationMetrics + res ++= topKCorrelations(topK) + res ++= topKContributions(topK) + res ++= topKCramersV(topK) + res.mkString("\n") + } + + private def prettyValidationResults: Seq[String] = { + val evalSummary = { + val vModelTypes = validatedModelTypes + "Evaluated %s model%s using %s and %s metric.".format( + vModelTypes.mkString(", "), + if (vModelTypes.size > 1) "s" else "", + validationType.humanFriendlyName, // TODO add number of folds or train/split ratio if possible + evaluationMetricType.humanFriendlyName + ) + } + val modelEvalRes = for { + modelType <- validatedModelTypes + modelValidationResults = validationResults(modelType) + evalMetric = evaluationMetricType.humanFriendlyName + } yield { + val evalMetricValues = modelValidationResults.flatMap { case (_, metrics) => + metrics.get(evalMetric).flatMap(v => Try(v.toDouble).toOption) + } + val 
minMetricValue = evalMetricValues.reduceOption[Double](math.min).getOrElse(Double.NaN) + val maxMetricValue = evalMetricValues.reduceOption[Double](math.max).getOrElse(Double.NaN) + + "Evaluated %d %s model%s with %s metric between [%s, %s].".format( + modelValidationResults.size, + modelType, + if (modelValidationResults.size > 1) "s" else "", + evalMetric, + minMetricValue, + maxMetricValue + ) + } + Seq(evalSummary, modelEvalRes.mkString("\n")) + } + + private def prettySelectedModelInfo: String = { + val bestModelType = selectedModelType + val name = s"Selected Model - $bestModelType" + val validationResults = selectedModelValidationResults.toSeq ++ Seq( + "name" -> selectedModelName, + "uid" -> selectedModelUID, + "modelType" -> selectedModelType + ) + val table = Table(name = name, columns = Seq("Model Param", "Value"), rows = validationResults.sortBy(_._1)) + table.prettyString() + } + + private def modelEvaluationMetrics: String = { + val name = "Model Evaluation Metrics" + val trainEvalMetrics = selectedModelTrainEvalMetrics + val testEvalMetrics = selectedModelTestEvalMetrics + val (metricNameCol, holdOutCol, trainingCol) = ("Metric Name", "Hold Out Set Value", "Training Set Value") + val trainMetrics = trainEvalMetrics.toMap.collect { case (k, v: Double) => k -> v.toString }.toSeq.sortBy(_._1) + val table = testEvalMetrics match { + case Some(testMetrics) => + val testMetricsMap = testMetrics.toMap + val rows = trainMetrics.map { case (k, v) => (k, v, testMetricsMap(k).toString) } + Table(name = name, columns = Seq(metricNameCol, trainingCol, holdOutCol), rows = rows) + case None => + Table(name = name, columns = Seq(metricNameCol, trainingCol), rows = trainMetrics) + } + table.prettyString() + } + + private def topKInsights(s: Seq[(FeatureInsights, Insights, Double)], topK: Int): Seq[(String, Double)] = { + s.foldLeft(Seq.empty[(String, Double)]) { + case (acc, (feature, derived, corr)) => + val insightValue = derived.derivedFeatureGroup -> 
derived.derivedFeatureValue match { + case (Some(group), Some(OpVectorColumnMetadata.NullString)) => s"${feature.featureName}($group = null)" + case (Some(group), Some(TransmogrifierDefaults.OtherString)) => s"${feature.featureName}($group = other)" + case (Some(group), Some(value)) => s"${feature.featureName}($group = $value)" + case (Some(group), None) => s"${feature.featureName}(group = $group)" // should not happen + case (None, Some(value)) => s"${feature.featureName}(value = $value)" // should not happen + case (None, None) => feature.featureName + } + if (acc.exists(_._1 == insightValue)) acc else acc :+ (insightValue, corr) + } take topK + } + + private def topKCorrelations(topK: Int): Seq[String] = { + val corrs = for { + (feature, derived) <- derivedNonExcludedFeatures + } yield (feature, derived, derived.corr.collect { case v if !v.isNaN => v }) + + val corrDsc = corrs.map { case (f, d, corr) => (f, d, corr.getOrElse(Double.MinValue)) }.sortBy(_._3).reverse + val corrAsc = corrs.map { case (f, d, corr) => (f, d, corr.getOrElse(Double.MaxValue)) }.sortBy(_._3) + val topPositiveCorrs = topKInsights(corrDsc, topK) + val topNegativeCorrs = topKInsights(corrAsc, topK).filterNot(topPositiveCorrs.contains) + + val correlationCol = "Correlation Value" + + lazy val topPositive = Table( + name = "Top Model Insights", + columns = Seq("Top Positive Correlations", correlationCol), + rows = topPositiveCorrs + ).prettyString(columnAlignments = Map(correlationCol -> Right)) + + lazy val topNegative = Table( + columns = Seq("Top Negative Correlations", correlationCol), + rows = topNegativeCorrs + ).prettyString(columnAlignments = Map(correlationCol -> Right)) + + if (topNegativeCorrs.isEmpty) Seq(topPositive) else Seq(topPositive, topNegative) + } + + private def topKContributions(topK: Int): Option[String] = { + val contribs = for { + (feature, derived) <- derivedNonExcludedFeatures + contrib = math.abs(derived.contribution.reduceOption[Double](math.max).getOrElse(0.0)) 
+ } yield (feature, derived, contrib) + + val contribDesc = contribs.sortBy(_._3).reverse + val rows = topKInsights(contribDesc, topK) + numericalTable(columns = Seq("Top Contributions", "Contribution Value"), rows) + } + + private def topKCramersV(topK: Int): Option[String] = { + val cramersV = for { + (feature, derived) <- derivedNonExcludedFeatures + group <- derived.derivedFeatureGroup + cramersV <- derived.cramersV + } yield group -> cramersV + + val topCramersV = cramersV.distinct.sortBy(_._2).reverse.take(topK) + numericalTable(columns = Seq("Top CramersV", "CramersV"), rows = topCramersV) + } + + private def derivedNonExcludedFeatures: Seq[(FeatureInsights, Insights)] = { + for { + feature <- features + derived <- feature.derivedFeatures + if !derived.excluded.contains(true) + } yield feature -> derived + } + + private def numericalTable(columns: Seq[String], rows: Seq[(String, Double)]): Option[String] = + if (rows.isEmpty) None else Some(Table(columns, rows).prettyString(columnAlignments = Map(columns.last -> Right))) + + private def modelType(modelName: String): Try[ModelsToTry] = Try { + classificationModelType.orElse(regressionModelType).lift(modelName).getOrElse( + throw new Exception(s"Unsupported model type for best model '$modelName'")) + } + + private def classificationModelType: PartialFunction[String, ClassificationModelsToTry] = { + case v if v.startsWith("logreg") => LogisticRegression + case v if v.startsWith("rfc") => RandomForest + case v if v.startsWith("dtc") => DecisionTree + case v if v.startsWith("nb") => NaiveBayes + } + private def regressionModelType: PartialFunction[String, RegressionModelsToTry] = { + case v if v.startsWith("linReg") => LinearRegression + case v if v.startsWith("rfr") => RandomForestRegression + case v if v.startsWith("dtr") => DecisionTreeRegression + case v if v.startsWith("gbtr") => GBTRegression + } + private def evaluationMetrics(metricsName: String): EvaluationMetrics = { + val res = for { + metricsMap <- 
getMap[String, Any](selectedModelInfo, metricsName) + evalMetrics <- Try(toEvaluationMetrics(metricsMap)) + } yield evalMetrics + res match { + case Failure(e) => throw new Exception(s"Failed to extract '$metricsName' metrics", e) + case Success(ok) => ok + } + } + private def getMap[K, V](m: Map[String, Any], name: String): Try[Map[K, V]] = Try { + m(name) match { + case m: Map[String, Any]@unchecked => m("map").asInstanceOf[Map[K, V]] + case m: Metadata => m.underlyingMap.asInstanceOf[Map[K, V]] + } + } + + private val MetricName = "\\((.*)\\)\\_(.*)".r + + private def toEvaluationMetrics(metrics: Map[String, Any]): EvaluationMetrics = { + import OpEvaluatorNames._ + val metricsType = metrics.keys.headOption match { + case Some(MetricName(t, _)) if Set(binary, multi, regression).contains(t) => t + case v => throw new Exception(s"Invalid model metric '$v'") + } + def parse[T <: EvaluationMetrics : ClassTag] = { + val vals = metrics.map { case (MetricName(_, name), value) => name -> value } + val valsJson = JsonUtils.toJsonString(vals) + JsonUtils.fromString[T](valsJson).get + } + metricsType match { + case `binary` => parse[BinaryClassificationMetrics] + case `multi` => parse[MultiClassificationMetrics] + case `regression` => parse[RegressionMetrics] + case t => throw new Exception(s"Unsupported metrics type '$t'") + } + } +} + +sealed trait ProblemType extends EnumEntry with Serializable + object ProblemType extends Enum[ProblemType] { + val values = findValues + case object BinaryClassification extends ProblemType + case object MultiClassification extends ProblemType + case object Regression extends ProblemType + case object Unknown extends ProblemType +} + +sealed abstract class ValidationType(val humanFriendlyName: String) extends EnumEntry with Serializable +object ValidationType extends Enum[ValidationType] { + val values = findValues + case object CrossValidation extends ValidationType("Cross Validation") + case object TrainValidationSplit extends 
ValidationType("Train Validation Split") } /** diff --git a/core/src/main/scala/com/salesforce/op/OpWorkflowModel.scala b/core/src/main/scala/com/salesforce/op/OpWorkflowModel.scala index 665dde3277..6c23e0acf3 100644 --- a/core/src/main/scala/com/salesforce/op/OpWorkflowModel.scala +++ b/core/src/main/scala/com/salesforce/op/OpWorkflowModel.scala @@ -39,7 +39,6 @@ import com.salesforce.op.stages.{OPStage, OpPipelineStage, OpTransformer} import com.salesforce.op.utils.spark.RichDataset._ import com.salesforce.op.utils.spark.RichMetadata._ import com.salesforce.op.utils.stages.FitStagesUtil -import org.apache.spark.ml.Estimator import org.apache.spark.sql.types.Metadata import org.apache.spark.sql.{DataFrame, SparkSession} import org.json4s.JValue @@ -170,7 +169,7 @@ class OpWorkflowModel(val uid: String = UID[OpWorkflowModel], val trainingParams } /** - * Pulls all summary metadata off of transformers + * Pulls all summary metadata of transformers and puts them in json * * @return json summary */ @@ -182,12 +181,28 @@ class OpWorkflowModel(val uid: String = UID[OpWorkflowModel], val trainingParams ) /** - * Pulls all summary metadata off of transformers and puts them in a pretty json string + * Pulls all summary metadata of transformers and puts them into json string * - * @return string summary + * @return json string summary */ def summary(): String = pretty(render(summaryJson())) + /** + * High level model summary in a compact print friendly format containing: + * selected model info, model evaluation results and feature correlations/contributions/cramersV values. 
+ * + * @param insights model insights to compute the summary against + * @param topK top K of feature correlations/contributions/cramersV values to print + * @return high level model summary in a compact print friendly format + */ + def summaryPretty( + insights: ModelInsights = modelInsights( + resultFeatures.find(f => f.isResponse && !f.isRaw).getOrElse( + throw new IllegalArgumentException("No response feature is defined to compute model insights")) + ), + topK: Int = 15 + ): String = insights.prettyPrint(topK) + /** * Save this model to a path * diff --git a/core/src/main/scala/com/salesforce/op/evaluators/Evaluators.scala b/core/src/main/scala/com/salesforce/op/evaluators/Evaluators.scala index 3645b45480..9d9b839055 100644 --- a/core/src/main/scala/com/salesforce/op/evaluators/Evaluators.scala +++ b/core/src/main/scala/com/salesforce/op/evaluators/Evaluators.scala @@ -57,7 +57,8 @@ object Evaluators { * Area under ROC */ def auROC(): OpBinaryClassificationEvaluator = - new OpBinaryClassificationEvaluator(name = OpMetricsNames.auROC, isLargerBetter = true) { + new OpBinaryClassificationEvaluator( + name = BinaryClassEvalMetrics.AuROC.humanFriendlyName, isLargerBetter = true) { override def evaluate(dataset: Dataset[_]): Double = getBinaryEvaluatorMetric(BinaryClassEvalMetrics.AuROC, dataset) } @@ -66,7 +67,7 @@ object Evaluators { * Area under Precision/Recall curve */ def auPR(): OpBinaryClassificationEvaluator = - new OpBinaryClassificationEvaluator(name = OpMetricsNames.auPR, isLargerBetter = true) { + new OpBinaryClassificationEvaluator(name = BinaryClassEvalMetrics.AuPR.humanFriendlyName, isLargerBetter = true) { override def evaluate(dataset: Dataset[_]): Double = getBinaryEvaluatorMetric(BinaryClassEvalMetrics.AuPR, dataset) } @@ -75,7 +76,8 @@ object Evaluators { * Precision */ def precision(): OpBinaryClassificationEvaluator = - new OpBinaryClassificationEvaluator(name = OpMetricsNames.precision, isLargerBetter = true) { + new 
OpBinaryClassificationEvaluator( + name = MultiClassEvalMetrics.Precision.humanFriendlyName, isLargerBetter = true) { override def evaluate(dataset: Dataset[_]): Double = { import dataset.sparkSession.implicits._ new MulticlassMetrics(dataset.select(getPredictionCol, getLabelCol).as[(Double, Double)].rdd).precision(1.0) @@ -86,7 +88,8 @@ object Evaluators { * Recall */ def recall(): OpBinaryClassificationEvaluator = - new OpBinaryClassificationEvaluator(name = OpMetricsNames.recall, isLargerBetter = true) { + new OpBinaryClassificationEvaluator( + name = MultiClassEvalMetrics.Recall.humanFriendlyName, isLargerBetter = true) { override def evaluate(dataset: Dataset[_]): Double = { import dataset.sparkSession.implicits._ new MulticlassMetrics(dataset.select(getPredictionCol, getLabelCol).as[(Double, Double)].rdd).recall(1.0) @@ -97,7 +100,7 @@ object Evaluators { * F1 score */ def f1(): OpBinaryClassificationEvaluator = - new OpBinaryClassificationEvaluator(name = OpMetricsNames.f1, isLargerBetter = true) { + new OpBinaryClassificationEvaluator(name = MultiClassEvalMetrics.F1.humanFriendlyName, isLargerBetter = true) { override def evaluate(dataset: Dataset[_]): Double = { import dataset.sparkSession.implicits._ new MulticlassMetrics( @@ -109,7 +112,8 @@ object Evaluators { * Prediction error */ def error(): OpBinaryClassificationEvaluator = - new OpBinaryClassificationEvaluator(name = OpMetricsNames.error, isLargerBetter = false) { + new OpBinaryClassificationEvaluator( + name = MultiClassEvalMetrics.Error.humanFriendlyName, isLargerBetter = false) { override def evaluate(dataset: Dataset[_]): Double = 1.0 - getMultiEvaluatorMetric(MultiClassEvalMetrics.Error, dataset) } @@ -162,7 +166,8 @@ object Evaluators { * Weighted Precision */ def precision(): OpMultiClassificationEvaluator = - new OpMultiClassificationEvaluator(name = OpMetricsNames.precision, isLargerBetter = true) { + new OpMultiClassificationEvaluator( + name = 
MultiClassEvalMetrics.Precision.humanFriendlyName, isLargerBetter = true) { override def evaluate(dataset: Dataset[_]): Double = getMultiEvaluatorMetric(MultiClassEvalMetrics.Precision, dataset) } @@ -171,7 +176,7 @@ object Evaluators { * Weighted Recall */ def recall(): OpMultiClassificationEvaluator = - new OpMultiClassificationEvaluator(name = OpMetricsNames.recall, isLargerBetter = true) { + new OpMultiClassificationEvaluator(name = MultiClassEvalMetrics.Recall.humanFriendlyName, isLargerBetter = true) { override def evaluate(dataset: Dataset[_]): Double = getMultiEvaluatorMetric(MultiClassEvalMetrics.Recall, dataset) } @@ -180,7 +185,7 @@ object Evaluators { * F1 Score */ def f1(): OpMultiClassificationEvaluator = - new OpMultiClassificationEvaluator(name = OpMetricsNames.f1, isLargerBetter = true) { + new OpMultiClassificationEvaluator(name = MultiClassEvalMetrics.F1.humanFriendlyName, isLargerBetter = true) { override def evaluate(dataset: Dataset[_]): Double = getMultiEvaluatorMetric(MultiClassEvalMetrics.F1, dataset) } @@ -189,7 +194,7 @@ object Evaluators { * Prediction Error */ def error(): OpMultiClassificationEvaluator = - new OpMultiClassificationEvaluator(name = OpMetricsNames.error, isLargerBetter = false) { + new OpMultiClassificationEvaluator(name = MultiClassEvalMetrics.Error.humanFriendlyName, isLargerBetter = false) { override def evaluate(dataset: Dataset[_]): Double = 1.0 - getMultiEvaluatorMetric(MultiClassEvalMetrics.Error, dataset) } @@ -252,7 +257,8 @@ object Evaluators { * Mean Squared Error */ def mse(): OpRegressionEvaluator = - new OpRegressionEvaluator(name = OpMetricsNames.meanSquaredError, isLargerBetter = false) { + new OpRegressionEvaluator( + name = RegressionEvalMetrics.MeanSquaredError.humanFriendlyName, isLargerBetter = false) { override def evaluate(dataset: Dataset[_]): Double = getRegEvaluatorMetric(RegressionEvalMetrics.MeanSquaredError, dataset) } @@ -261,7 +267,8 @@ object Evaluators { * Mean Absolute Error */ def 
mae(): OpRegressionEvaluator = - new OpRegressionEvaluator(name = OpMetricsNames.meanAbsoluteError, isLargerBetter = false) { + new OpRegressionEvaluator( + name = RegressionEvalMetrics.MeanAbsoluteError.humanFriendlyName, isLargerBetter = false) { override def evaluate(dataset: Dataset[_]): Double = getRegEvaluatorMetric(RegressionEvalMetrics.MeanAbsoluteError, dataset) } @@ -270,7 +277,7 @@ object Evaluators { * R2 */ def r2(): OpRegressionEvaluator = - new OpRegressionEvaluator(name = OpMetricsNames.r2, isLargerBetter = true) { + new OpRegressionEvaluator(name = RegressionEvalMetrics.R2.humanFriendlyName, isLargerBetter = true) { override def evaluate(dataset: Dataset[_]): Double = getRegEvaluatorMetric(RegressionEvalMetrics.R2, dataset) } @@ -279,7 +286,8 @@ object Evaluators { * Root Mean Squared Error */ def rmse(): OpRegressionEvaluator = - new OpRegressionEvaluator(name = OpMetricsNames.rootMeanSquaredError, isLargerBetter = false) { + new OpRegressionEvaluator( + name = RegressionEvalMetrics.RootMeanSquaredError.humanFriendlyName, isLargerBetter = false) { override def evaluate(dataset: Dataset[_]): Double = getRegEvaluatorMetric(RegressionEvalMetrics.RootMeanSquaredError, dataset) } diff --git a/core/src/main/scala/com/salesforce/op/evaluators/OpEvaluatorBase.scala b/core/src/main/scala/com/salesforce/op/evaluators/OpEvaluatorBase.scala index 9fbe841ecd..72d78da2bd 100644 --- a/core/src/main/scala/com/salesforce/op/evaluators/OpEvaluatorBase.scala +++ b/core/src/main/scala/com/salesforce/op/evaluators/OpEvaluatorBase.scala @@ -195,21 +195,41 @@ abstract class OpRegressionEvaluatorBase[T <: EvaluationMetrics] with OpHasLabelCol[RealNN] with OpHasPredictionCol[RealNN] +/** + * Eval metric + */ +trait EvalMetric extends Serializable { + /** + * Spark metric name + */ + def sparkEntryName: String + /** + * Human friendly metric name + */ + def humanFriendlyName: String +} -sealed abstract class ClassificationEvalMetric(val sparkEntryName: String) extends 
EnumEntry with Serializable +/** + * Classification Metrics + */ +sealed abstract class ClassificationEvalMetric +( + val sparkEntryName: String, + val humanFriendlyName: String +) extends EnumEntry with EvalMetric /** * Binary Classification Metrics */ object BinaryClassEvalMetrics extends Enum[ClassificationEvalMetric] { val values = findValues - case object Precision extends ClassificationEvalMetric("precision") - case object Recall extends ClassificationEvalMetric("recall") - case object F1 extends ClassificationEvalMetric("f1") - case object Error extends ClassificationEvalMetric("accuracy") - case object AuROC extends ClassificationEvalMetric("areaUnderROC") - case object AuPR extends ClassificationEvalMetric("areaUnderPR") + case object Precision extends ClassificationEvalMetric("weightedPrecision", "precision") + case object Recall extends ClassificationEvalMetric("weightedRecall", "recall") + case object F1 extends ClassificationEvalMetric("f1", "f1") + case object Error extends ClassificationEvalMetric("accuracy", "error") + case object AuROC extends ClassificationEvalMetric("areaUnderROC", "area under ROC") + case object AuPR extends ClassificationEvalMetric("areaUnderPR", "area under PR") } /** @@ -217,33 +237,29 @@ object BinaryClassEvalMetrics extends Enum[ClassificationEvalMetric] { */ object MultiClassEvalMetrics extends Enum[ClassificationEvalMetric] { val values = findValues - case object Precision extends ClassificationEvalMetric("weightedPrecision") - case object Recall extends ClassificationEvalMetric("weightedRecall") - case object F1 extends ClassificationEvalMetric("f1") - case object Error extends ClassificationEvalMetric("accuracy") - case object ThresholdMetrics extends ClassificationEvalMetric("thresholdMetrics") + case object Precision extends ClassificationEvalMetric("weightedPrecision", "precision") + case object Recall extends ClassificationEvalMetric("weightedRecall", "recall") + case object F1 extends ClassificationEvalMetric("f1", 
"f1") + case object Error extends ClassificationEvalMetric("accuracy", "error") + case object ThresholdMetrics extends ClassificationEvalMetric("thresholdMetrics", "threshold metrics") } /** - * Contains the names of metrics used in logging + * Regression Metrics */ -private[op] case object OpMetricsNames { - val rootMeanSquaredError = "root mean square error" - val meanSquaredError = "mean square error" - val meanAbsoluteError = "mean absolute error" - val r2 = "r2" - val auROC = "area under ROC" - val auPR = "area under PR" - val precision = "precision" - val recall = "recall" - val f1 = "f1" - val accuracy = "accuracy" - val error = "error" - val tp = "true positive" - val tn = "true negative" - val fp = "false positive" - val fn = "false negative" +sealed abstract class RegressionEvalMetric +( + val sparkEntryName: String, + val humanFriendlyName: String +) extends EnumEntry with EvalMetric + +object RegressionEvalMetrics extends Enum[RegressionEvalMetric] { + val values: Seq[RegressionEvalMetric] = findValues + case object RootMeanSquaredError extends RegressionEvalMetric("rmse", "root mean square error") + case object MeanSquaredError extends RegressionEvalMetric("mse", "mean square error") + case object R2 extends RegressionEvalMetric("r2", "r2") + case object MeanAbsoluteError extends RegressionEvalMetric("mae", "mean absolute error") } /** @@ -254,5 +270,3 @@ case object OpEvaluatorNames { val multi = "multiEval" val regression = "regEval" } - - diff --git a/core/src/main/scala/com/salesforce/op/evaluators/OpRegressionEvaluator.scala b/core/src/main/scala/com/salesforce/op/evaluators/OpRegressionEvaluator.scala index 5ed64f3a18..5a4e475e34 100644 --- a/core/src/main/scala/com/salesforce/op/evaluators/OpRegressionEvaluator.scala +++ b/core/src/main/scala/com/salesforce/op/evaluators/OpRegressionEvaluator.scala @@ -101,13 +101,3 @@ case class RegressionMetrics R2: Double, MeanAbsoluteError: Double ) extends EvaluationMetrics - -/* Regression Metrics */ 
-sealed abstract class RegressionEvalMetric(val sparkEntryName: String) extends EnumEntry with Serializable -object RegressionEvalMetrics extends Enum[RegressionEvalMetric] { - val values: Seq[RegressionEvalMetric] = findValues - case object RootMeanSquaredError extends RegressionEvalMetric("rmse") - case object MeanSquaredError extends RegressionEvalMetric("mse") - case object R2 extends RegressionEvalMetric("r2") - case object MeanAbsoluteError extends RegressionEvalMetric("mae") -} diff --git a/core/src/main/scala/com/salesforce/op/stages/impl/classification/SelectorClassifiers.scala b/core/src/main/scala/com/salesforce/op/stages/impl/classification/SelectorClassifiers.scala index 08d62be8f8..6485bdc95e 100644 --- a/core/src/main/scala/com/salesforce/op/stages/impl/classification/SelectorClassifiers.scala +++ b/core/src/main/scala/com/salesforce/op/stages/impl/classification/SelectorClassifiers.scala @@ -31,6 +31,7 @@ package com.salesforce.op.stages.impl.classification +import com.salesforce.op.stages.impl.ModelsToTry import com.salesforce.op.stages.impl.classification.ProbabilisticClassifierType.ProbClassifier import com.salesforce.op.stages.impl.selector._ import org.apache.spark.ml.classification._ @@ -43,7 +44,7 @@ import scala.reflect.ClassTag /** * Enumeration of possible classification models in Model Selector */ -sealed trait ClassificationModelsToTry extends EnumEntry with Serializable +sealed trait ClassificationModelsToTry extends ModelsToTry object ClassificationModelsToTry extends Enum[ClassificationModelsToTry] { val values = findValues diff --git a/core/src/main/scala/com/salesforce/op/stages/impl/package.scala b/core/src/main/scala/com/salesforce/op/stages/impl/package.scala new file mode 100644 index 0000000000..11db127f50 --- /dev/null +++ b/core/src/main/scala/com/salesforce/op/stages/impl/package.scala @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2017, Salesforce.com, Inc. + * All rights reserved. 
+ * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of Salesforce.com nor the names of its contributors may + * be used to endorse or promote products derived from this software without + * specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. 
+ */ + +package com.salesforce.op.stages + +import enumeratum.EnumEntry + +package object impl { + + /** + * Enumeration of possible models in Model Selectors + */ + trait ModelsToTry extends EnumEntry with Serializable + +} diff --git a/core/src/main/scala/com/salesforce/op/stages/impl/regression/SelectorRegressors.scala b/core/src/main/scala/com/salesforce/op/stages/impl/regression/SelectorRegressors.scala index 300aa121a9..84080e8dcc 100644 --- a/core/src/main/scala/com/salesforce/op/stages/impl/regression/SelectorRegressors.scala +++ b/core/src/main/scala/com/salesforce/op/stages/impl/regression/SelectorRegressors.scala @@ -31,6 +31,7 @@ package com.salesforce.op.stages.impl.regression +import com.salesforce.op.stages.impl.ModelsToTry import com.salesforce.op.stages.impl.regression.RegressorType._ import com.salesforce.op.stages.impl.selector._ import org.apache.spark.ml.param.{BooleanParam, Param, Params} @@ -44,7 +45,7 @@ import scala.reflect.ClassTag /** * Enumeration of possible regression models in Model Selector */ -sealed trait RegressionModelsToTry extends EnumEntry with Serializable +sealed trait RegressionModelsToTry extends ModelsToTry object RegressionModelsToTry extends Enum[RegressionModelsToTry] { val values = findValues diff --git a/core/src/test/scala/com/salesforce/op/ModelInsightsTest.scala b/core/src/test/scala/com/salesforce/op/ModelInsightsTest.scala index a523152f30..8069c87d0d 100644 --- a/core/src/test/scala/com/salesforce/op/ModelInsightsTest.scala +++ b/core/src/test/scala/com/salesforce/op/ModelInsightsTest.scala @@ -31,14 +31,16 @@ package com.salesforce.op +import com.salesforce.op.evaluators.{BinaryClassEvalMetrics, BinaryClassificationMetrics} import com.salesforce.op.features.Feature -import com.salesforce.op.features.types.{FeatureTypeDefaults, PickList, Real, RealNN} +import com.salesforce.op.features.types.{PickList, Real, RealNN} import com.salesforce.op.stages.impl.classification.BinaryClassificationModelSelector -import 
com.salesforce.op.stages.impl.classification.ClassificationModelsToTry.LogisticRegression +import com.salesforce.op.stages.impl.classification.ClassificationModelsToTry.{LogisticRegression, NaiveBayes} import com.salesforce.op.stages.impl.preparators._ import com.salesforce.op.stages.impl.regression.RegressionModelSelector import com.salesforce.op.stages.impl.regression.RegressionModelsToTry.LinearRegression import com.salesforce.op.stages.impl.selector.SelectedModel +import com.salesforce.op.stages.impl.tuning.DataSplitter import com.salesforce.op.test.PassengerSparkFixtureTest import com.salesforce.op.utils.spark.{OpVectorColumnMetadata, OpVectorMetadata} import org.junit.runner.RunWith @@ -55,11 +57,11 @@ class ModelInsightsTest extends FlatSpec with PassengerSparkFixtureTest { private val descrVec = description.vectorize(10, false, 1, true) private val features = Seq(density, age, generVec, weight, descrVec).transmogrify() private val label = survived.occurs() - private val checked = label.sanityCheck(features, removeBadFeatures = true, removeFeatureGroup = false, - checkSample = 1.0) + private val checked = + label.sanityCheck(features, removeBadFeatures = true, removeFeatureGroup = false, checkSample = 1.0) val (pred, rawPred, prob) = BinaryClassificationModelSelector - .withCrossValidation(seed = 42, splitter = None) + .withCrossValidation(seed = 42, splitter = Option(DataSplitter(seed = 42, reserveTestFraction = 0.1))) .setModelsToTry(LogisticRegression) .setLogisticRegressionRegParam(0.01, 0.1) .setInput(label, checked) @@ -119,7 +121,7 @@ class ModelInsightsTest extends FlatSpec with PassengerSparkFixtureTest { insights.label.rawFeatureName shouldBe Seq(survived.name) insights.label.rawFeatureType shouldBe Seq(survived.typeName) insights.label.stagesApplied.size shouldBe 1 - insights.label.sampleSize shouldBe Some(6.0) + insights.label.sampleSize shouldBe Some(4.0) insights.features.size shouldBe 5 insights.features.map(_.featureName).toSet shouldEqual 
rawNames ageInsights.derivedFeatures.size shouldBe 2 @@ -170,7 +172,7 @@ class ModelInsightsTest extends FlatSpec with PassengerSparkFixtureTest { insights.label.rawFeatureName shouldBe Seq(survived.name) insights.label.rawFeatureType shouldBe Seq(survived.typeName) insights.label.stagesApplied.size shouldBe 1 - insights.label.sampleSize shouldBe Some(6.0) + insights.label.sampleSize shouldBe Some(4.0) insights.features.size shouldBe 5 insights.features.map(_.featureName).toSet shouldEqual rawNames ageInsights.derivedFeatures.size shouldBe 2 @@ -237,6 +239,48 @@ class ModelInsightsTest extends FlatSpec with PassengerSparkFixtureTest { lin.head.size shouldBe OpVectorMetadata("", checked.originStage.getMetadata()).columns.length } + it should "return best model information" in { + val insights = workflowModel.modelInsights(prob) + insights.selectedModelUID should startWith("logreg_") + insights.selectedModelName should startWith("logreg_") + insights.selectedModelType shouldBe LogisticRegression + val bestModelValidationResults = insights.selectedModelValidationResults + bestModelValidationResults.size shouldBe 15 + bestModelValidationResults.get(BinaryClassEvalMetrics.AuPR.humanFriendlyName) shouldBe Some("0.0") + val validationResults = insights.validationResults + validationResults.size shouldBe 2 + validationResults.get(insights.selectedModelName) shouldBe Some(bestModelValidationResults) + insights.validationResults(LogisticRegression) shouldBe validationResults + insights.validationResults(NaiveBayes) shouldBe Map.empty + } + + it should "return test/train evaluation metrics" in { + val insights = workflowModel.modelInsights(prob) + insights.evaluationMetricType shouldBe BinaryClassEvalMetrics.AuPR + insights.validationType shouldBe ValidationType.CrossValidation + insights.validatedModelTypes shouldBe Set(LogisticRegression) + + insights.problemType shouldBe ProblemType.BinaryClassification + insights.selectedModelTrainEvalMetrics shouldBe + 
BinaryClassificationMetrics(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0, + Seq(0.0), Seq(0.0), Seq(0.0), Seq(1.0)) + insights.selectedModelTestEvalMetrics shouldBe Some( + BinaryClassificationMetrics(0.0, 0.0, 0.0, 0.5, 0.75, 0.5, 0.0, 1.0, 0.0, 1.0, + Seq(0.0), Seq(0.5), Seq(1.0), Seq(1.0)) + ) + } + + it should "pretty print" in { + val insights = workflowModel.modelInsights(prob) + val pretty = insights.prettyPrint() + pretty should include(s"Selected Model - $LogisticRegression") + pretty should include("| area under PR | 0.0") + pretty should include("Model Evaluation Metrics") + pretty should include("Top Model Insights") + pretty should include("Top Positive Correlations") + pretty should include("Top Contributions") + } + it should "correctly serialize and deserialize from json" in { val insights = workflowModel.modelInsights(prob) ModelInsights.fromJson(insights.toJson()) match { diff --git a/core/src/test/scala/com/salesforce/op/OpWorkflowTest.scala b/core/src/test/scala/com/salesforce/op/OpWorkflowTest.scala index 9c5708ea15..40e7f74ca0 100644 --- a/core/src/test/scala/com/salesforce/op/OpWorkflowTest.scala +++ b/core/src/test/scala/com/salesforce/op/OpWorkflowTest.scala @@ -41,17 +41,16 @@ import com.salesforce.op.stages.base.unary._ import com.salesforce.op.stages.impl.classification.ClassificationModelsToTry._ import com.salesforce.op.stages.impl.classification._ import com.salesforce.op.stages.impl.preparators.SanityChecker -import com.salesforce.op.stages.impl.regression.{LossType, RegressionModelSelector, RegressionModelsToTry} -import com.salesforce.op.stages.impl.selector.{ModelSelectorBaseNames, SelectedModel} +import com.salesforce.op.stages.impl.selector.ModelSelectorBaseNames import com.salesforce.op.stages.impl.tuning._ -import com.salesforce.op.test.{Passenger, PassengerCSV, PassengerSparkFixtureTest, TestFeatureBuilder} +import com.salesforce.op.test.{Passenger, PassengerSparkFixtureTest, TestFeatureBuilder} import 
com.salesforce.op.utils.spark.RichDataset._ import com.salesforce.op.utils.spark.{OpVectorColumnMetadata, OpVectorMetadata} import org.apache.spark.ml.param.BooleanParam import org.apache.spark.rdd.RDD import org.apache.spark.sql.types.{DoubleType, StringType} import org.apache.spark.sql.{Dataset, SparkSession} -import org.joda.time.{DateTime, Duration} +import org.joda.time.DateTime import org.junit.runner.RunWith import org.scalatest.FlatSpec import org.scalatest.junit.JUnitRunner @@ -391,12 +390,21 @@ class OpWorkflowTest extends FlatSpec with PassengerSparkFixtureTest { val summary = fittedWorkflow.summary() log.info(summary) - summary.contains(classOf[SanityChecker].getSimpleName) shouldBe true - summary.contains("logreg") shouldBe true - summary.contains(""""regParam" : "0.1"""") shouldBe true - summary.contains(""""regParam" : "0.01"""") shouldBe true - summary.contains(ModelSelectorBaseNames.HoldOutEval) shouldBe true - summary.contains(ModelSelectorBaseNames.TrainingEval) shouldBe true + summary should include(classOf[SanityChecker].getSimpleName) + summary should include("logreg") + summary should include(""""regParam" : "0.1"""") + summary should include(""""regParam" : "0.01"""") + summary should include(ModelSelectorBaseNames.HoldOutEval) + summary should include(ModelSelectorBaseNames.TrainingEval) + + val prettySummary = fittedWorkflow.summaryPretty() + log.info(prettySummary) + prettySummary should include(s"Selected Model - $LogisticRegression") + prettySummary should include("| area under PR | 0.25") + prettySummary should include("Model Evaluation Metrics") + prettySummary should include("Top Model Insights") + prettySummary should include("Top Positive Correlations") + prettySummary should include("Top Contributions") } it should "be able to refit a workflow with calibrated probability" in { diff --git a/utils/src/main/scala/com/salesforce/op/utils/table/Table.scala b/utils/src/main/scala/com/salesforce/op/utils/table/Table.scala new file mode 
100644 index 0000000000..2d0fc817f3 --- /dev/null +++ b/utils/src/main/scala/com/salesforce/op/utils/table/Table.scala @@ -0,0 +1,158 @@ +/* + * Copyright (c) 2017, Salesforce.com, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of Salesforce.com nor the names of its contributors may + * be used to endorse or promote products derived from this software without + * specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. 
+ */ + +package com.salesforce.op.utils.table + +import com.twitter.algebird.{Monoid, Semigroup} +import enumeratum._ + + +object Table { + /** + * Simple factory for creating table instance with rows of [[Product]] types + * + * @param columns non empty sequence of column names + * @param rows non empty sequence of rows + * @param name table name + * @tparam T row type of [[Product]] + */ + def apply[T <: Product](columns: Seq[String], rows: Seq[T], name: String = ""): Table = { + require(columns.nonEmpty, "columns cannot be empty") + require(rows.nonEmpty, "rows cannot be empty") + require(columns.length == rows.head.productArity, + s"columns length must match rows arity (${columns.length}!=${rows.head.productArity})") + val rowVals = rows.map(_.productIterator.map(v => Option(v).map(_.toString).getOrElse("")).toSeq) + new Table(columns, rowVals, name) + } + + private implicit val max = Semigroup.from[Int](math.max) + private implicit val monoid: Monoid[Array[Int]] = Monoid.arrayMonoid[Int] +} + +/** + * Simple table representation consisting of rows, i.e: + * + * +----------------------------------------+ + * | Transactions | + * +----------------------------------------+ + * | date | amount | source | status | + * +------+--------+--------------+---------+ + * | 1 | 4.95 | Cafe Venetia | Success | + * | 2 | 12.65 | Sprout | Success | + * | 3 | 4.75 | Caltrain | Pending | + * +------+--------+--------------+---------+ + * + * @param columns non empty sequence of column names + * @param rows non empty sequence of rows + * @param name table name + */ +class Table private(columns: Seq[String], rows: Seq[Seq[String]], name: String) { + private def formatCell(v: String, size: Int, sep: String, fill: String): PartialFunction[Alignment, String] = { + case Alignment.Left => v + fill * (size - v.length) + case Alignment.Right => fill * (size - v.length) + v + case Alignment.Center => // half of the padding (rounded down) goes left, the rest right; no format strings, so empty cells are safe + fill * ((size - v.length) / 
2) + v + fill * (size - v.length - (size - v.length) / 2) + } + + private def formatRow( + values: Iterable[String], + cellSizes: Iterable[Int], + alignment: String => Alignment, + sep: String = "|", + fill: String = " " + ): String = { + val formatted = values.zipWithIndex.zip(cellSizes).map { case ((v, i), size) => + formatCell(v, size, sep, fill)(alignment(columns(i))) + } + formatted.mkString(s"$sep$fill", s"$fill$sep$fill", s"$fill$sep") + } + + private def sortColumns(ascending: Boolean): Table = { + // Pair every cell with its column name and sort by that name (stable), so rows follow the sorted header + val (columnsSorted, rowsSorted) = (columns.sorted, rows.map(_.zip(columns).sortBy(_._2).unzip._1)) + new Table( + columns = if (ascending) columnsSorted else columnsSorted.reverse, + rows = if (ascending) rowsSorted else rowsSorted.map(_.reverse), + name = name + ) + } + + /** + * Sort table columns in alphabetical order + */ + def sortColumnsAsc: Table = sortColumns(ascending = true) + + /** + * Sort table columns in inverse alphabetical order + */ + def sortColumnsDesc: Table = sortColumns(ascending = false) + + /** + * Pretty print table + * + * @param nameAlignment table name alignment + * @param columnAlignments column name & values alignment + * @param defaultColumnAlignment default column name & values alignment + * @return pretty printed table + */ + def prettyString( + nameAlignment: Alignment = Alignment.Center, + columnAlignments: Map[String, Alignment] = Map.empty, + defaultColumnAlignment: Alignment = Alignment.Left + ): String = { + val columnSizes = columns.map(c => math.max(1, c.length)).toArray + val cellSizes = rows.map(_.map(_.length).toArray).foldLeft(columnSizes)(Table.monoid.plus) + val bracket = formatRow(Seq.fill(cellSizes.length)(""), cellSizes, _ => Alignment.Left, sep = "+", fill = "-") + val rowWidth = bracket.length - 4 + val cleanBracket = formatRow(Seq(""), Seq(rowWidth), _ => Alignment.Left, sep = "+", fill = "-") + val maybeName = Option(name) match { + case Some(n) if n.nonEmpty => Seq(cleanBracket, formatRow(Seq(name),
Seq(rowWidth), _ => nameAlignment)) + case _ => Seq.empty + } + val alignment: String => Alignment = columnAlignments.getOrElse(_, defaultColumnAlignment) + val columnsHeader = formatRow(columns, cellSizes, alignment) + val formattedRows = rows.map(formatRow(_, cellSizes, alignment)) + + (maybeName ++ Seq(cleanBracket, columnsHeader, bracket) ++ formattedRows :+ bracket).mkString("\n") + } + + override def toString: String = prettyString() + +} + +sealed trait Alignment extends EnumEntry +object Alignment extends Enum[Alignment] { + val values = findValues + case object Left extends Alignment + case object Right extends Alignment + case object Center extends Alignment +} + diff --git a/utils/src/test/scala/com/salesforce/op/utils/table/TableTest.scala b/utils/src/test/scala/com/salesforce/op/utils/table/TableTest.scala new file mode 100644 index 0000000000..2b01126a70 --- /dev/null +++ b/utils/src/test/scala/com/salesforce/op/utils/table/TableTest.scala @@ -0,0 +1,188 @@ +/* + * Copyright (c) 2017, Salesforce.com, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of Salesforce.com nor the names of its contributors may + * be used to endorse or promote products derived from this software without + * specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +package com.salesforce.op.utils.table + +import com.salesforce.op.test.TestCommon +import com.salesforce.op.utils.table.Alignment._ +import org.junit.runner.RunWith +import org.scalatest.FlatSpec +import org.scalatest.junit.JUnitRunner + +case class Transaction(date: Long, amount: Double, source: String, status: String) + +@RunWith(classOf[JUnitRunner]) +class TableTest extends FlatSpec with TestCommon { + + // scalastyle:off indentation + + val columns = Seq("date", "amount", "source", "status") + val transactions = Seq( + Transaction(1, 4.95, "Cafe Venetia", "Success"), + Transaction(2, 12.65, "Sprout", "Success"), + Transaction(3, 4.75, "Caltrain", "Pending") + ) + + Spec[Table] should "error on missing columns" in { + intercept[IllegalArgumentException] { + Table(columns = Seq.empty, rows = transactions) + }.getMessage shouldBe "requirement failed: columns cannot be empty" + } + it should "error on empty rows" in { + intercept[IllegalArgumentException] { + Table(columns = columns, rows = Seq.empty[Transaction]) + }.getMessage shouldBe "requirement failed: rows cannot be empty" + } + it should "error on invalid arity" in { + intercept[IllegalArgumentException] { + 
Table(columns = Seq("a"), rows = transactions) + }.getMessage shouldBe "requirement failed: columns length must match rows arity (1!=4)" + } + it should "pretty print a table" in { + Table(columns = columns, rows = transactions).prettyString() shouldBe + """|+----------------------------------------+ + || date | amount | source | status | + |+------+--------+--------------+---------+ + || 1 | 4.95 | Cafe Venetia | Success | + || 2 | 12.65 | Sprout | Success | + || 3 | 4.75 | Caltrain | Pending | + |+------+--------+--------------+---------+""".stripMargin + } + it should "have a pretty toString as well" in { + val table = Table(columns = columns, rows = transactions) + table.prettyString() shouldBe table.toString + } + it should "sort columns in ascending order" in { + Table(columns = columns, rows = transactions).sortColumnsAsc.prettyString() shouldBe + """|+----------------------------------------+ + || amount | date | source | status | + |+--------+------+--------------+---------+ + || 4.95 | 1 | Cafe Venetia | Success | + || 12.65 | 2 | Sprout | Success | + || 4.75 | 3 | Caltrain | Pending | + |+--------+------+--------------+---------+""".stripMargin + } + it should "sort columns in descending order" in { + Table(columns = columns, rows = transactions).sortColumnsDesc.prettyString() shouldBe + """|+----------------------------------------+ + || status | source | date | amount | + |+---------+--------------+------+--------+ + || Success | Cafe Venetia | 1 | 4.95 | + || Success | Sprout | 2 | 12.65 | + || Pending | Caltrain | 3 | 4.75 | + |+---------+--------------+------+--------+""".stripMargin + } + it should "pretty print a table with a name" in { + Table(columns = columns, rows = transactions, name = "Transactions").prettyString() shouldBe + """|+----------------------------------------+ + || Transactions | + |+----------------------------------------+ + || date | amount | source | status | + |+------+--------+--------------+---------+ + || 1 | 4.95 | Cafe 
Venetia | Success | + || 2 | 12.65 | Sprout | Success | + || 3 | 4.75 | Caltrain | Pending | + |+------+--------+--------------+---------+""".stripMargin + } + it should "pretty print a table with a name aligned left" in { + Table(columns = columns, rows = transactions, name = "Transactions").prettyString(nameAlignment = Left) shouldBe + """|+----------------------------------------+ + || Transactions | + |+----------------------------------------+ + || date | amount | source | status | + |+------+--------+--------------+---------+ + || 1 | 4.95 | Cafe Venetia | Success | + || 2 | 12.65 | Sprout | Success | + || 3 | 4.75 | Caltrain | Pending | + |+------+--------+--------------+---------+""".stripMargin + } + it should "pretty print a table with right column alignment" in { + Table(columns = columns, rows = transactions).prettyString(defaultColumnAlignment = Right) shouldBe + """|+----------------------------------------+ + || date | amount | source | status | + |+------+--------+--------------+---------+ + || 1 | 4.95 | Cafe Venetia | Success | + || 2 | 12.65 | Sprout | Success | + || 3 | 4.75 | Caltrain | Pending | + |+------+--------+--------------+---------+""".stripMargin + } + it should "pretty print a table with center column alignment" in { + Table(columns = columns, rows = transactions).prettyString(defaultColumnAlignment = Center) shouldBe + """|+----------------------------------------+ + || date | amount | source | status | + |+------+--------+--------------+---------+ + || 1 | 4.95 | Cafe Venetia | Success | + || 2 | 12.65 | Sprout | Success | + || 3 | 4.75 | Caltrain | Pending | + |+------+--------+--------------+---------+""".stripMargin + } + it should "pretty print a table with custom column alignment" in { + Table(columns = columns, rows = transactions, name = "Transactions") + .prettyString( + nameAlignment = Center, defaultColumnAlignment = Right, + columnAlignments = Map("date" -> Right, "amount" -> Left, "status" -> Center)) shouldBe + 
"""|+----------------------------------------+ + || Transactions | + |+----------------------------------------+ + || date | amount | source | status | + |+------+--------+--------------+---------+ + || 1 | 4.95 | Cafe Venetia | Success | + || 2 | 12.65 | Sprout | Success | + || 3 | 4.75 | Caltrain | Pending | + |+------+--------+--------------+---------+""".stripMargin + } + it should "pretty print a table even if data is bad" in { + val badData1 = Seq(Tuple2(null, "one"), "2" -> "", (null, null), "3" -> Transaction(1, 1.0, "?", "?")) + Table(columns = Seq("c1", "c2"), rows = badData1, name = "Bad Data").prettyString() shouldBe + """|+-----------------------------+ + || Bad Data | + |+-----------------------------+ + || c1 | c2 | + |+----+------------------------+ + || | one | + || 2 | | + || | | + || 3 | Transaction(1,1.0,?,?) | + |+----+------------------------+""".stripMargin + } + it should "pretty print a table even if data is really bad" in { + val badData2 = Seq(null, "", 1).map(Tuple1(_)) + Table(columns = Seq(""), rows = badData2).prettyString() shouldBe + """|+---+ + || | + |+---+ + || | + || | + || 1 | + |+---+""".stripMargin + } + +}