Model combiner #385

Merged
merged 19 commits on Sep 3, 2019
Changes from 15 commits
81 changes: 55 additions & 26 deletions core/src/main/scala/com/salesforce/op/ModelInsights.scala
@@ -35,7 +35,7 @@ import com.salesforce.op.features._
import com.salesforce.op.features.types._
import com.salesforce.op.filters._
import com.salesforce.op.stages._
import com.salesforce.op.stages.impl.feature.{TextStats, TransmogrifierDefaults}
import com.salesforce.op.stages.impl.feature.{CombinationStrategy, TextStats, TransmogrifierDefaults}
import com.salesforce.op.stages.impl.preparators._
import com.salesforce.op.stages.impl.selector._
import com.salesforce.op.stages.impl.tuning.{DataBalancerSummary, DataCutterSummary, DataSplitterSummary}
@@ -46,6 +46,7 @@ import com.salesforce.op.utils.spark.RichMetadata._
import com.salesforce.op.utils.spark.{OpVectorColumnMetadata, OpVectorMetadata}
import com.salesforce.op.utils.table.Alignment._
import com.salesforce.op.utils.table.Table
import com.twitter.algebird.Operators._
import com.twitter.algebird.Moments
import ml.dmlc.xgboost4j.scala.spark.OpXGBoost.RichBooster
import ml.dmlc.xgboost4j.scala.spark.{XGBoostClassificationModel, XGBoostRegressionModel}
@@ -443,43 +444,69 @@ case object ModelInsights {
blacklistedMapKeys: Map[String, Set[String]],
rawFeatureFilterResults: RawFeatureFilterResults
): ModelInsights = {
val sanityCheckers = stages.collect { case s: SanityCheckerModel => s }
val sanityChecker = sanityCheckers.lastOption
val checkerSummary = sanityChecker.map(s => SanityCheckerSummary.fromMetadata(s.getMetadata().getSummaryMetadata()))
log.info(
s"Found ${sanityCheckers.length} sanity checkers will " +
s"${sanityChecker.map("use results from the last checker:" + _.uid + "to").getOrElse("not")}" +
s" to fill in model insights"
)

// TODO support other model types?
val models = stages.collect{
case s: SelectedModel => s
case s: OpPredictorWrapperModel[_] => s
} // TODO support other model types?
val model = models.lastOption
case s: SelectedCombinerModel => s
}
val model = models.flatMap{
case s: SelectedCombinerModel if s.strategy == CombinationStrategy.Best =>
val originF = if (s.weight1 > 0.5) s.getInputFeature[Prediction](1) else s.getInputFeature[Prediction](2)
models.find( m => originF.exists(_.originStage.uid == m.uid) )
case s => Option(s)
}.lastOption

log.info(
s"Found ${models.length} models will " +
s"${model.map("use results from the last model:" + _.uid + "to").getOrElse("not")}" +
s" to fill in model insights"
)

val label = model.map(_.getInputFeature[RealNN](0)).orElse(sanityChecker.map(_.getInputFeature[RealNN](0))).flatten
val modelInputStages: Set[String] = model.map { m =>
val stages = m.getInputFeatures().map(_.parentStages().toOption.map(_.keySet.map(_.uid)))
val uid = stages.collect{ case Some(uids) => uids }
uid.fold(Set.empty)(_ + _)
}.getOrElse(Set.empty)

val sanityCheckers = stages.collect { case s: SanityCheckerModel => s }
val sanityCheckersForModel = sanityCheckers.filter(s => modelInputStages.contains(s.uid) &&
model.exists(_.getInputFeature[RealNN](0) == s.getInputFeature[RealNN](0))).toSeq

val sanityChecker = if (sanityCheckersForModel.nonEmpty) sanityCheckersForModel else sanityCheckers.lastOption.toSeq
val checkerSummary = if (sanityChecker.nonEmpty) {
Option(SanityCheckerSummary.flatten(
sanityChecker.map(s => SanityCheckerSummary.fromMetadata(s.getMetadata().getSummaryMetadata()))
))
} else None

log.info(
s"Found ${sanityCheckers.length} sanity checkers" +
s"${sanityChecker.map("will preferentially use results from checkers in model path:" + _.uid +
" to fill in model insights")}"
)


val label = model.map(_.getInputFeature[RealNN](0))
.orElse(sanityChecker.lastOption.map(_.getInputFeature[RealNN](0))).flatten
log.info(s"Found ${label.map(_.name + " as label").getOrElse("no label")} to fill in model insights")

// Recover the vector metadata
val vectorInput: Option[OpVectorMetadata] = {
def makeMeta(s: => OpPipelineStageParams) = Try(OpVectorMetadata(s.getInputSchema().last)).toOption

sanityChecker
// first try out to get vector metadata from sanity checker
.flatMap(s => makeMeta(s.parent.asInstanceOf[SanityChecker]).orElse(makeMeta(s)))
// fall back to model selector stage metadata
.orElse(model.flatMap(m => makeMeta(m.parent.asInstanceOf[ModelSelector[_, _]])))
// finally try to get it from the last vector stage
.orElse(
stages.filter(_.getOutput().isSubtypeOf[OPVector]).lastOption
.map(v => OpVectorMetadata(v.getOutputFeatureName, v.getMetadata()))
)
if (sanityChecker.nonEmpty) { // first try out to get vector metadata from sanity checker
Option(OpVectorMetadata.flatten("",
sanityChecker.flatMap(s => makeMeta(s.parent.asInstanceOf[SanityChecker]).orElse(makeMeta(s))))
)
} else {
model.flatMap(m => makeMeta(m)) // fall back to model selector stage metadata
.orElse( // finally try to get it from the last vector stage
stages.filter(_.getOutput().isSubtypeOf[OPVector]).lastOption
.map(v => OpVectorMetadata(v.getOutputFeatureName, v.getMetadata()))
)
}
}
log.info(
s"Found ${vectorInput.map(_.name + " as feature vector").getOrElse("no feature vector")}" +
@@ -639,7 +666,7 @@ case object ModelInsights {
contribution =
contributions.map(_.applyOrElse(h.index, (_: Int) => 0.0)) // nothing dropped without sanity check
)
}
}
case (None, _) => Seq.empty
}

@@ -712,9 +739,9 @@ case object ModelInsights {
// need to also divide by labelStd for linear regression
// See https://u.demog.berkeley.edu/~andrew/teaching/standard_coeff.pdf
// See https://en.wikipedia.org/wiki/Standardized_coefficient
sparkFtrContrib.map(_ * featureStd / labelStd)
}
else sparkFtrContrib
sparkFtrContrib.map(_ * featureStd / labelStd)
}
else sparkFtrContrib
case _ => sparkFtrContrib
}
}
@@ -747,6 +774,8 @@ case object ModelInsights {
model match {
case Some(m: SelectedModel) =>
Try(ModelSelectorSummary.fromMetadata(m.getMetadata().getSummaryMetadata())).toOption
case Some(m: SelectedCombinerModel) =>
Try(ModelSelectorSummary.fromMetadata(m.getMetadata().getSummaryMetadata())).toOption
case _ => None
}
}
@@ -99,20 +99,25 @@ sealed trait EvalMetric extends EnumEntry with Serializable {
*/
def humanFriendlyName: String

/**
* Indicates whether a larger value of this metric is better (true) or a smaller value is better (false)
*/
def isLargerBetter: Boolean

}

/**
* Eval metric companion object
*/
object EvalMetric {

def withNameInsensitive(name: String): EvalMetric = {
def withNameInsensitive(name: String, isLargerBetter: Boolean = true): EvalMetric = {
BinaryClassEvalMetrics.withNameInsensitiveOption(name)
.orElse(MultiClassEvalMetrics.withNameInsensitiveOption(name))
.orElse(RegressionEvalMetrics.withNameInsensitiveOption(name))
.orElse(ForecastEvalMetrics.withNameInsensitiveOption(name))
.orElse(OpEvaluatorNames.withNameInsensitiveOption(name))
.getOrElse(OpEvaluatorNames.Custom(name, name))
.getOrElse(OpEvaluatorNames.Custom(name, name, isLargerBetter))
}
}

@@ -122,37 +127,38 @@ object EvalMetric {
sealed abstract class ClassificationEvalMetric
(
val sparkEntryName: String,
val humanFriendlyName: String
val humanFriendlyName: String,
val isLargerBetter: Boolean
) extends EvalMetric

/**
* Binary Classification Metrics
*/
object BinaryClassEvalMetrics extends Enum[ClassificationEvalMetric] {
val values = findValues
case object Precision extends ClassificationEvalMetric("weightedPrecision", "precision")
case object Recall extends ClassificationEvalMetric("weightedRecall", "recall")
case object F1 extends ClassificationEvalMetric("f1", "f1")
case object Error extends ClassificationEvalMetric("accuracy", "error")
case object AuROC extends ClassificationEvalMetric("areaUnderROC", "area under ROC")
case object AuPR extends ClassificationEvalMetric("areaUnderPR", "area under precision-recall")
case object TP extends ClassificationEvalMetric("TP", "true positive")
case object TN extends ClassificationEvalMetric("TN", "true negative")
case object FP extends ClassificationEvalMetric("FP", "false positive")
case object FN extends ClassificationEvalMetric("FN", "false negative")
case object BrierScore extends ClassificationEvalMetric("brierScore", "brier score")
case object Precision extends ClassificationEvalMetric("weightedPrecision", "precision", true)
case object Recall extends ClassificationEvalMetric("weightedRecall", "recall", true)
case object F1 extends ClassificationEvalMetric("f1", "f1", true)
case object Error extends ClassificationEvalMetric("accuracy", "error", false)
case object AuROC extends ClassificationEvalMetric("areaUnderROC", "area under ROC", true)
case object AuPR extends ClassificationEvalMetric("areaUnderPR", "area under precision-recall", true)
case object TP extends ClassificationEvalMetric("TP", "true positive", true)
case object TN extends ClassificationEvalMetric("TN", "true negative", true)
case object FP extends ClassificationEvalMetric("FP", "false positive", false)
case object FN extends ClassificationEvalMetric("FN", "false negative", false)
case object BrierScore extends ClassificationEvalMetric("brierScore", "brier score", false)
}

/**
* Multi Classification Metrics
*/
object MultiClassEvalMetrics extends Enum[ClassificationEvalMetric] {
val values = findValues
case object Precision extends ClassificationEvalMetric("weightedPrecision", "precision")
case object Recall extends ClassificationEvalMetric("weightedRecall", "recall")
case object F1 extends ClassificationEvalMetric("f1", "f1")
case object Error extends ClassificationEvalMetric("accuracy", "error")
case object ThresholdMetrics extends ClassificationEvalMetric("thresholdMetrics", "threshold metrics")
case object Precision extends ClassificationEvalMetric("weightedPrecision", "precision", true)
case object Recall extends ClassificationEvalMetric("weightedRecall", "recall", true)
case object F1 extends ClassificationEvalMetric("f1", "f1", true)
case object Error extends ClassificationEvalMetric("accuracy", "error", false)
case object ThresholdMetrics extends ClassificationEvalMetric("thresholdMetrics", "threshold metrics", true)
}


@@ -162,18 +168,19 @@ object MultiClassEvalMetrics extends Enum[ClassificationEvalMetric] {
sealed abstract class RegressionEvalMetric
(
val sparkEntryName: String,
val humanFriendlyName: String
val humanFriendlyName: String,
val isLargerBetter: Boolean
) extends EvalMetric

/**
* Regression Metrics
*/
object RegressionEvalMetrics extends Enum[RegressionEvalMetric] {
val values: Seq[RegressionEvalMetric] = findValues
case object RootMeanSquaredError extends RegressionEvalMetric("rmse", "root mean square error")
case object MeanSquaredError extends RegressionEvalMetric("mse", "mean square error")
case object R2 extends RegressionEvalMetric("r2", "r2")
case object MeanAbsoluteError extends RegressionEvalMetric("mae", "mean absolute error")
case object RootMeanSquaredError extends RegressionEvalMetric("rmse", "root mean square error", false)
case object MeanSquaredError extends RegressionEvalMetric("mse", "mean square error", false)
case object R2 extends RegressionEvalMetric("r2", "r2", true)
case object MeanAbsoluteError extends RegressionEvalMetric("mae", "mean absolute error", false)
}


@@ -183,15 +190,16 @@ object RegressionEvalMetrics extends Enum[RegressionEvalMetric] {
sealed abstract class ForecastEvalMetric
(
val sparkEntryName: String,
val humanFriendlyName: String
val humanFriendlyName: String,
val isLargerBetter: Boolean
) extends EvalMetric


object ForecastEvalMetrics extends Enum[ForecastEvalMetric] {
val values: Seq[ForecastEvalMetric] = findValues
case object SMAPE extends ForecastEvalMetric("smape", "symmetric mean absolute percentage error")
case object MASE extends ForecastEvalMetric("mase", "mean absolute scaled error")
case object SeasonalError extends ForecastEvalMetric("seasonalError", "seasonal error")
case object SMAPE extends ForecastEvalMetric("smape", "symmetric mean absolute percentage error", false)
case object MASE extends ForecastEvalMetric("mase", "mean absolute scaled error", false)
case object SeasonalError extends ForecastEvalMetric("seasonalError", "seasonal error", false)
}


@@ -201,25 +209,32 @@ object ForecastEvalMetrics extends Enum[ForecastEvalMetric] {
sealed abstract class OpEvaluatorNames
(
val sparkEntryName: String,
val humanFriendlyName: String
val humanFriendlyName: String,
val isLargerBetter: Boolean // for default value
) extends EvalMetric

/**
* Contains evaluator names used in logging
*/
object OpEvaluatorNames extends Enum[OpEvaluatorNames] {
val values: Seq[OpEvaluatorNames] = findValues
case object Binary extends OpEvaluatorNames("binEval", "binary evaluation metrics")
case object BinScore extends OpEvaluatorNames("binScoreEval", "bin score evaluation metrics")
case object Multi extends OpEvaluatorNames("multiEval", "multiclass evaluation metrics")
case object Regression extends OpEvaluatorNames("regEval", "regression evaluation metrics")
case object Forecast extends OpEvaluatorNames("regForecast", "forecast evaluation metrics")
case class Custom(name: String, humanName: String) extends OpEvaluatorNames(name, humanName) {
case object Binary extends OpEvaluatorNames("binEval", "binary evaluation metrics", true)
case object BinScore extends OpEvaluatorNames("binScoreEval", "bin score evaluation metrics", false)
case object Multi extends OpEvaluatorNames("multiEval", "multiclass evaluation metrics", true)
case object Regression extends OpEvaluatorNames("regEval", "regression evaluation metrics", false)
case object Forecast extends OpEvaluatorNames("regForecast", "regression evaluation metrics", false)
case class Custom(name: String, humanName: String, largeBetter: Boolean) extends
OpEvaluatorNames(name, humanName, largeBetter) {
override def entryName: String = name.toLowerCase
}
def withName(name: String, isLargerBetter: Boolean): OpEvaluatorNames =
Try(super.withName(name)).getOrElse(Custom(name, name, isLargerBetter))
Collaborator: Avoid Try and exceptions in simple program flows, use super.withNameOption instead.
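
A minimal sketch of the reviewer's suggestion, not part of this diff: it assumes enumeratum's withNameOption is available on this enum and swaps the Try-based fallback for an Option-based one.

// Sketch only: look the entry up with withNameOption instead of catching the
// NoSuchElementException thrown by super.withName, then fall back to Custom.
def withName(name: String, isLargerBetter: Boolean): OpEvaluatorNames =
  super.withNameOption(name).getOrElse(Custom(name, name, isLargerBetter))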

override def withName(name: String): OpEvaluatorNames = withName(name, true)

def withNameInsensitive(name: String, isLargerBetter: Boolean): OpEvaluatorNames =
super.withNameInsensitiveOption(name).getOrElse(Custom(name, name, isLargerBetter))
override def withNameInsensitive(name: String): OpEvaluatorNames = withNameInsensitive(name, true)
Collaborator: Why do we want the default of true for the isLargerBetter value? Perhaps let's avoid specifying any defaults for the withName and withNameInsensitive methods.

Instead it's better to have the method:

// Reviewer's sketch, lightly fixed to compile; the default factory shown is illustrative.
def withNameOrDefault(name: String, default: String => OpEvaluatorNames = n => Custom(n, n, isLargerBetter = true)): OpEvaluatorNames = {
    super.withNameInsensitiveOption(name).getOrElse(default(name))
}

Collaborator Author: These are base methods @tovbinm, hence the override. If they want the fallback behavior your method provides they can use my definitions - but what happens when people call the base methods and there is no default?
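
A hypothetical usage sketch of the trade-off under discussion, assuming the definitions in this diff; every metric name other than binEval is made up for illustration.

// Illustrative calls only (not from this PR):
OpEvaluatorNames.withName("binEval")                              // resolves to the Binary entry
OpEvaluatorNames.withName("someCustomEval")                       // base enumeratum signature: falls back to Custom with the baked-in default, isLargerBetter = true
OpEvaluatorNames.withName("someLossEval", isLargerBetter = false) // the overload lets the caller state the direction explicitly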


def withFriendlyNameInsensitive(name: String): Option[OpEvaluatorNames] =
values.collectFirst { case n if n.humanFriendlyName.equalsIgnoreCase(name) => n }
override def withName(name: String): OpEvaluatorNames = Try(super.withName(name)).getOrElse(Custom(name, name))
override def withNameInsensitive(name: String): OpEvaluatorNames = super.withNameInsensitiveOption(name)
.getOrElse(Custom(name, name))
}