Skip to content

Commit

Permalink
Model combiner (#385)
Browse files Browse the repository at this point in the history
  • Loading branch information
leahmcguire committed Sep 3, 2019
1 parent b91ffe3 commit 51037a8
Show file tree
Hide file tree
Showing 26 changed files with 680 additions and 128 deletions.
81 changes: 55 additions & 26 deletions core/src/main/scala/com/salesforce/op/ModelInsights.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import com.salesforce.op.features._
import com.salesforce.op.features.types._
import com.salesforce.op.filters._
import com.salesforce.op.stages._
import com.salesforce.op.stages.impl.feature.{TextStats, TransmogrifierDefaults}
import com.salesforce.op.stages.impl.feature.{CombinationStrategy, TextStats, TransmogrifierDefaults}
import com.salesforce.op.stages.impl.preparators._
import com.salesforce.op.stages.impl.selector._
import com.salesforce.op.stages.impl.tuning.{DataBalancerSummary, DataCutterSummary, DataSplitterSummary}
Expand All @@ -46,6 +46,7 @@ import com.salesforce.op.utils.spark.RichMetadata._
import com.salesforce.op.utils.spark.{OpVectorColumnMetadata, OpVectorMetadata}
import com.salesforce.op.utils.table.Alignment._
import com.salesforce.op.utils.table.Table
import com.twitter.algebird.Operators._
import com.twitter.algebird.Moments
import ml.dmlc.xgboost4j.scala.spark.OpXGBoost.RichBooster
import ml.dmlc.xgboost4j.scala.spark.{XGBoostClassificationModel, XGBoostRegressionModel}
Expand Down Expand Up @@ -443,43 +444,69 @@ case object ModelInsights {
blacklistedMapKeys: Map[String, Set[String]],
rawFeatureFilterResults: RawFeatureFilterResults
): ModelInsights = {
val sanityCheckers = stages.collect { case s: SanityCheckerModel => s }
val sanityChecker = sanityCheckers.lastOption
val checkerSummary = sanityChecker.map(s => SanityCheckerSummary.fromMetadata(s.getMetadata().getSummaryMetadata()))
log.info(
s"Found ${sanityCheckers.length} sanity checkers will " +
s"${sanityChecker.map("use results from the last checker:" + _.uid + "to").getOrElse("not")}" +
s" to fill in model insights"
)

// TODO support other model types?
val models = stages.collect{
case s: SelectedModel => s
case s: OpPredictorWrapperModel[_] => s
} // TODO support other model types?
val model = models.lastOption
case s: SelectedCombinerModel => s
}
val model = models.flatMap{
case s: SelectedCombinerModel if s.strategy == CombinationStrategy.Best =>
val originF = if (s.weight1 > 0.5) s.getInputFeature[Prediction](1) else s.getInputFeature[Prediction](2)
models.find( m => originF.exists(_.originStage.uid == m.uid) )
case s => Option(s)
}.lastOption

log.info(
s"Found ${models.length} models will " +
s"${model.map("use results from the last model:" + _.uid + "to").getOrElse("not")}" +
s" to fill in model insights"
)

val label = model.map(_.getInputFeature[RealNN](0)).orElse(sanityChecker.map(_.getInputFeature[RealNN](0))).flatten
val modelInputStages: Set[String] = model.map { m =>
val stages = m.getInputFeatures().map(_.parentStages().toOption.map(_.keySet.map(_.uid)))
val uid = stages.collect{ case Some(uids) => uids }
uid.fold(Set.empty)(_ + _)
}.getOrElse(Set.empty)

val sanityCheckers = stages.collect { case s: SanityCheckerModel => s }
val sanityCheckersForModel = sanityCheckers.filter(s => modelInputStages.contains(s.uid) &&
model.exists(_.getInputFeature[RealNN](0) == s.getInputFeature[RealNN](0))).toSeq

val sanityChecker = if (sanityCheckersForModel.nonEmpty) sanityCheckersForModel else sanityCheckers.lastOption.toSeq
val checkerSummary = if (sanityChecker.nonEmpty) {
Option(SanityCheckerSummary.flatten(
sanityChecker.map(s => SanityCheckerSummary.fromMetadata(s.getMetadata().getSummaryMetadata()))
))
} else None

log.info(
s"Found ${sanityCheckers.length} sanity checkers" +
s"${sanityChecker.map("will preferentially use results from checkers in model path:" + _.uid +
" to fill in model insights")}"
)


val label = model.map(_.getInputFeature[RealNN](0))
.orElse(sanityChecker.lastOption.map(_.getInputFeature[RealNN](0))).flatten
log.info(s"Found ${label.map(_.name + " as label").getOrElse("no label")} to fill in model insights")

// Recover the vector metadata
val vectorInput: Option[OpVectorMetadata] = {
def makeMeta(s: => OpPipelineStageParams) = Try(OpVectorMetadata(s.getInputSchema().last)).toOption

sanityChecker
// first try out to get vector metadata from sanity checker
.flatMap(s => makeMeta(s.parent.asInstanceOf[SanityChecker]).orElse(makeMeta(s)))
// fall back to model selector stage metadata
.orElse(model.flatMap(m => makeMeta(m.parent.asInstanceOf[ModelSelector[_, _]])))
// finally try to get it from the last vector stage
.orElse(
stages.filter(_.getOutput().isSubtypeOf[OPVector]).lastOption
.map(v => OpVectorMetadata(v.getOutputFeatureName, v.getMetadata()))
)
if (sanityChecker.nonEmpty) { // first try out to get vector metadata from sanity checker
Option(OpVectorMetadata.flatten("",
sanityChecker.flatMap(s => makeMeta(s.parent.asInstanceOf[SanityChecker]).orElse(makeMeta(s))))
)
} else {
model.flatMap(m => makeMeta(m)) // fall back to model selector stage metadata
.orElse( // finally try to get it from the last vector stage
stages.filter(_.getOutput().isSubtypeOf[OPVector]).lastOption
.map(v => OpVectorMetadata(v.getOutputFeatureName, v.getMetadata()))
)
}
}
log.info(
s"Found ${vectorInput.map(_.name + " as feature vector").getOrElse("no feature vector")}" +
Expand Down Expand Up @@ -639,7 +666,7 @@ case object ModelInsights {
contribution =
contributions.map(_.applyOrElse(h.index, (_: Int) => 0.0)) // nothing dropped without sanity check
)
}
}
case (None, _) => Seq.empty
}

Expand Down Expand Up @@ -712,9 +739,9 @@ case object ModelInsights {
// need to also divide by labelStd for linear regression
// See https://u.demog.berkeley.edu/~andrew/teaching/standard_coeff.pdf
// See https://en.wikipedia.org/wiki/Standardized_coefficient
sparkFtrContrib.map(_ * featureStd / labelStd)
}
else sparkFtrContrib
sparkFtrContrib.map(_ * featureStd / labelStd)
}
else sparkFtrContrib
case _ => sparkFtrContrib
}
}
Expand Down Expand Up @@ -747,6 +774,8 @@ case object ModelInsights {
model match {
case Some(m: SelectedModel) =>
Try(ModelSelectorSummary.fromMetadata(m.getMetadata().getSummaryMetadata())).toOption
case Some(m: SelectedCombinerModel) =>
Try(ModelSelectorSummary.fromMetadata(m.getMetadata().getSummaryMetadata())).toOption
case _ => None
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,20 +99,25 @@ sealed trait EvalMetric extends EnumEntry with Serializable {
*/
def humanFriendlyName: String

/**
* Is this metric being larger better or is smaller better
*/
def isLargerBetter: Boolean

}

/**
* Eval metric companion object
*/
object EvalMetric {

  /**
   * Resolves a metric by case-insensitive name, searching every known metric enum in order
   * (binary, multiclass, regression, forecast, evaluator names) and falling back to a
   * custom evaluator name when nothing matches.
   *
   * @param name           metric name to look up
   * @param isLargerBetter direction used only for the Custom fallback (defaults to true)
   * @return the matching metric, or [[OpEvaluatorNames.Custom]] built from the name
   */
  def withNameInsensitive(name: String, isLargerBetter: Boolean = true): EvalMetric = {
    BinaryClassEvalMetrics.withNameInsensitiveOption(name)
      .orElse(MultiClassEvalMetrics.withNameInsensitiveOption(name))
      .orElse(RegressionEvalMetrics.withNameInsensitiveOption(name))
      .orElse(ForecastEvalMetrics.withNameInsensitiveOption(name))
      .orElse(OpEvaluatorNames.withNameInsensitiveOption(name))
      .getOrElse(OpEvaluatorNames.Custom(name, name, isLargerBetter))
  }
}

Expand All @@ -122,37 +127,38 @@ object EvalMetric {
/**
 * Base type for classification metrics.
 *
 * @param sparkEntryName    name of the metric as understood by Spark evaluators
 * @param humanFriendlyName human readable name of the metric
 * @param isLargerBetter    whether a larger value of this metric is better (see EvalMetric)
 */
sealed abstract class ClassificationEvalMetric
(
  val sparkEntryName: String,
  val humanFriendlyName: String,
  val isLargerBetter: Boolean
) extends EvalMetric

/**
* Binary Classification Metrics
*/
object BinaryClassEvalMetrics extends Enum[ClassificationEvalMetric] {
  val values = findValues
  // Boolean literal args are named for readability (direction per EvalMetric.isLargerBetter)
  case object Precision extends ClassificationEvalMetric("weightedPrecision", "precision", isLargerBetter = true)
  case object Recall extends ClassificationEvalMetric("weightedRecall", "recall", isLargerBetter = true)
  case object F1 extends ClassificationEvalMetric("f1", "f1", isLargerBetter = true)
  // NOTE(review): sparkEntryName is "accuracy" (larger better) but the friendly name is "error"
  // and the flag is false — confirm which quantity is actually reported under this entry
  case object Error extends ClassificationEvalMetric("accuracy", "error", isLargerBetter = false)
  case object AuROC extends ClassificationEvalMetric("areaUnderROC", "area under ROC", isLargerBetter = true)
  case object AuPR extends ClassificationEvalMetric("areaUnderPR", "area under precision-recall",
    isLargerBetter = true)
  case object TP extends ClassificationEvalMetric("TP", "true positive", isLargerBetter = true)
  case object TN extends ClassificationEvalMetric("TN", "true negative", isLargerBetter = true)
  case object FP extends ClassificationEvalMetric("FP", "false positive", isLargerBetter = false)
  case object FN extends ClassificationEvalMetric("FN", "false negative", isLargerBetter = false)
  case object BrierScore extends ClassificationEvalMetric("brierScore", "brier score", isLargerBetter = false)
}

/**
* Multi Classification Metrics
*/
object MultiClassEvalMetrics extends Enum[ClassificationEvalMetric] {
  val values = findValues
  // Boolean literal args are named for readability (direction per EvalMetric.isLargerBetter)
  case object Precision extends ClassificationEvalMetric("weightedPrecision", "precision", isLargerBetter = true)
  case object Recall extends ClassificationEvalMetric("weightedRecall", "recall", isLargerBetter = true)
  case object F1 extends ClassificationEvalMetric("f1", "f1", isLargerBetter = true)
  // NOTE(review): sparkEntryName is "accuracy" but the flag is false — confirm intended direction
  case object Error extends ClassificationEvalMetric("accuracy", "error", isLargerBetter = false)
  case object ThresholdMetrics extends ClassificationEvalMetric("thresholdMetrics", "threshold metrics",
    isLargerBetter = true)
}


Expand All @@ -162,18 +168,19 @@ object MultiClassEvalMetrics extends Enum[ClassificationEvalMetric] {
/**
 * Base type for regression metrics.
 *
 * @param sparkEntryName    name of the metric as understood by Spark evaluators
 * @param humanFriendlyName human readable name of the metric
 * @param isLargerBetter    whether a larger value of this metric is better (see EvalMetric)
 */
sealed abstract class RegressionEvalMetric
(
  val sparkEntryName: String,
  val humanFriendlyName: String,
  val isLargerBetter: Boolean
) extends EvalMetric

/**
* Regression Metrics
*/
object RegressionEvalMetrics extends Enum[RegressionEvalMetric] {
  val values: Seq[RegressionEvalMetric] = findValues
  // Error metrics: smaller is better; R2: larger is better
  case object RootMeanSquaredError extends RegressionEvalMetric("rmse", "root mean square error",
    isLargerBetter = false)
  case object MeanSquaredError extends RegressionEvalMetric("mse", "mean square error", isLargerBetter = false)
  case object R2 extends RegressionEvalMetric("r2", "r2", isLargerBetter = true)
  case object MeanAbsoluteError extends RegressionEvalMetric("mae", "mean absolute error", isLargerBetter = false)
}


Expand All @@ -183,15 +190,16 @@ object RegressionEvalMetrics extends Enum[RegressionEvalMetric] {
/**
 * Base type for forecast metrics.
 *
 * @param sparkEntryName    name of the metric as understood by Spark evaluators
 * @param humanFriendlyName human readable name of the metric
 * @param isLargerBetter    whether a larger value of this metric is better (see EvalMetric)
 */
sealed abstract class ForecastEvalMetric
(
  val sparkEntryName: String,
  val humanFriendlyName: String,
  val isLargerBetter: Boolean
) extends EvalMetric


object ForecastEvalMetrics extends Enum[ForecastEvalMetric] {
  val values: Seq[ForecastEvalMetric] = findValues
  // All forecast metrics here are error measures: smaller is better
  case object SMAPE extends ForecastEvalMetric("smape", "symmetric mean absolute percentage error",
    isLargerBetter = false)
  case object MASE extends ForecastEvalMetric("mase", "mean absolute scaled error", isLargerBetter = false)
  case object SeasonalError extends ForecastEvalMetric("seasonalError", "seasonal error", isLargerBetter = false)
}


Expand All @@ -201,25 +209,30 @@ object ForecastEvalMetrics extends Enum[ForecastEvalMetric] {
/**
 * Base type for evaluator names used in logging.
 *
 * @param sparkEntryName    name of the evaluator
 * @param humanFriendlyName human readable name of the evaluator
 * @param isLargerBetter    default direction assumed for this evaluator's primary metric
 */
sealed abstract class OpEvaluatorNames
(
  val sparkEntryName: String,
  val humanFriendlyName: String,
  val isLargerBetter: Boolean // for default value
) extends EvalMetric

/**
* Contains evaluator names used in logging
*/
object OpEvaluatorNames extends Enum[OpEvaluatorNames] {
  val values: Seq[OpEvaluatorNames] = findValues

  case object Binary extends OpEvaluatorNames("binEval", "binary evaluation metrics", isLargerBetter = true)
  case object BinScore extends OpEvaluatorNames("binScoreEval", "bin score evaluation metrics",
    isLargerBetter = false)
  case object Multi extends OpEvaluatorNames("multiEval", "multiclass evaluation metrics", isLargerBetter = true)
  case object Regression extends OpEvaluatorNames("regEval", "regression evaluation metrics",
    isLargerBetter = false)
  // Fixed copy-paste: the friendly name was "regression evaluation metrics", identical to Regression's,
  // which made Forecast unreachable via withFriendlyNameInsensitive (collectFirst would match Regression)
  case object Forecast extends OpEvaluatorNames("regForecast", "forecast evaluation metrics",
    isLargerBetter = false)

  /**
   * Ad-hoc evaluator name for evaluators not covered by the fixed entries above.
   * entryName is lower-cased so enum lookups on custom names are case-insensitive.
   */
  case class Custom(name: String, humanName: String, largeBetter: Boolean) extends
    OpEvaluatorNames(name, humanName, largeBetter) {
    override def entryName: String = name.toLowerCase
  }

  /** Resolve by exact entry name, falling back to a Custom entry with the given direction */
  def withName(name: String, isLargerBetter: Boolean): OpEvaluatorNames =
    super.withNameOption(name).getOrElse(Custom(name, name, isLargerBetter))

  /** Resolve by case-insensitive entry name, falling back to a Custom entry with the given direction */
  def withNameInsensitive(name: String, isLargerBetter: Boolean): OpEvaluatorNames =
    super.withNameInsensitiveOption(name).getOrElse(Custom(name, name, isLargerBetter))

  /** Case-insensitive lookup by human friendly name; None when no fixed entry matches */
  def withFriendlyNameInsensitive(name: String): Option[OpEvaluatorNames] =
    values.collectFirst { case n if n.humanFriendlyName.equalsIgnoreCase(name) => n }
}
Loading

0 comments on commit 51037a8

Please sign in to comment.