Model combiner #385

Merged
merged 19 commits, Sep 3, 2019
Changes from 10 commits
79 changes: 54 additions & 25 deletions core/src/main/scala/com/salesforce/op/ModelInsights.scala
@@ -46,6 +46,7 @@ import com.salesforce.op.utils.spark.RichMetadata._
import com.salesforce.op.utils.spark.{OpVectorColumnMetadata, OpVectorMetadata}
import com.salesforce.op.utils.table.Alignment._
import com.salesforce.op.utils.table.Table
import com.twitter.algebird.Operators._
import com.twitter.algebird.Moments
import ml.dmlc.xgboost4j.scala.spark.OpXGBoost.RichBooster
import ml.dmlc.xgboost4j.scala.spark.{XGBoostClassificationModel, XGBoostRegressionModel}
@@ -443,43 +444,69 @@ case object ModelInsights {
blacklistedMapKeys: Map[String, Set[String]],
rawFeatureFilterResults: RawFeatureFilterResults
): ModelInsights = {
val sanityCheckers = stages.collect { case s: SanityCheckerModel => s }
val sanityChecker = sanityCheckers.lastOption
val checkerSummary = sanityChecker.map(s => SanityCheckerSummary.fromMetadata(s.getMetadata().getSummaryMetadata()))
log.info(
s"Found ${sanityCheckers.length} sanity checkers will " +
s"${sanityChecker.map("use results from the last checker:" + _.uid + "to").getOrElse("not")}" +
s" to fill in model insights"
)

// TODO support other model types?
val models = stages.collect{
case s: SelectedModel => s
case s: OpPredictorWrapperModel[_] => s
} // TODO support other model types?
val model = models.lastOption
case s: SelectedCombinerModel => s
}
val model = models.flatMap{
case s: SelectedCombinerModel if s.strategy == CombinationStrategy.Best =>
val originF = if (s.weight1 > 0.5) s.getInputFeature[Prediction](1) else s.getInputFeature[Prediction](2)
models.find( m => originF.exists(_.originStage.uid == m.uid) )
case s => Option(s)
}.lastOption

log.info(
s"Found ${models.length} models will " +
s"${model.map("use results from the last model:" + _.uid + "to").getOrElse("not")}" +
s" to fill in model insights"
)

val label = model.map(_.getInputFeature[RealNN](0)).orElse(sanityChecker.map(_.getInputFeature[RealNN](0))).flatten
val modelInputStages: Set[String] = model.map { m =>
val stages = m.getInputFeatures().map(_.parentStages().toOption.map(_.keySet.map(_.uid)))
val uid = stages.collect{ case Some(uids) => uids }
uid.fold(Set.empty)(_ + _)
}.getOrElse(Set.empty)

val sanityCheckers = stages.collect { case s: SanityCheckerModel => s }
val sanityCheckersForModel = sanityCheckers.filter(s => modelInputStages.contains(s.uid) &&
model.exists(_.getInputFeature[RealNN](0) == s.getInputFeature[RealNN](0))).toSeq

val sanityChecker = if (sanityCheckersForModel.nonEmpty) sanityCheckersForModel else sanityCheckers.lastOption.toSeq
val checkerSummary = if (sanityChecker.nonEmpty) {
Option(SanityCheckerSummary.flatten(
sanityChecker.map(s => SanityCheckerSummary.fromMetadata(s.getMetadata().getSummaryMetadata()))
))
} else None

log.info(
s"Found ${sanityCheckers.length} sanity checkers" +
s"${sanityChecker.map("will preferentially use results from checkers in model path:" + _.uid +
" to fill in model insights")}"
)


val label = model.map(_.getInputFeature[RealNN](0))
.orElse(sanityChecker.lastOption.map(_.getInputFeature[RealNN](0))).flatten
log.info(s"Found ${label.map(_.name + " as label").getOrElse("no label")} to fill in model insights")

// Recover the vector metadata
val vectorInput: Option[OpVectorMetadata] = {
def makeMeta(s: => OpPipelineStageParams) = Try(OpVectorMetadata(s.getInputSchema().last)).toOption

sanityChecker
// first try out to get vector metadata from sanity checker
.flatMap(s => makeMeta(s.parent.asInstanceOf[SanityChecker]).orElse(makeMeta(s)))
// fall back to model selector stage metadata
.orElse(model.flatMap(m => makeMeta(m.parent.asInstanceOf[ModelSelector[_, _]])))
// finally try to get it from the last vector stage
.orElse(
stages.filter(_.getOutput().isSubtypeOf[OPVector]).lastOption
.map(v => OpVectorMetadata(v.getOutputFeatureName, v.getMetadata()))
)
if (sanityChecker.nonEmpty) { // first try out to get vector metadata from sanity checker
Option(OpVectorMetadata.flatten("",
sanityChecker.flatMap(s => makeMeta(s.parent.asInstanceOf[SanityChecker]).orElse(makeMeta(s))))
)
} else {
model.flatMap(m => makeMeta(m)) // fall back to model selector stage metadata
.orElse( // finally try to get it from the last vector stage
stages.filter(_.getOutput().isSubtypeOf[OPVector]).lastOption
.map(v => OpVectorMetadata(v.getOutputFeatureName, v.getMetadata()))
)
}
}
log.info(
s"Found ${vectorInput.map(_.name + " as feature vector").getOrElse("no feature vector")}" +
@@ -639,7 +666,7 @@ case object ModelInsights {
contribution =
contributions.map(_.applyOrElse(h.index, (_: Int) => 0.0)) // nothing dropped without sanity check
)
}
}
case (None, _) => Seq.empty
}

@@ -712,9 +739,9 @@ case object ModelInsights {
// need to also divide by labelStd for linear regression
// See https://u.demog.berkeley.edu/~andrew/teaching/standard_coeff.pdf
// See https://en.wikipedia.org/wiki/Standardized_coefficient
sparkFtrContrib.map(_ * featureStd / labelStd)
}
else sparkFtrContrib
sparkFtrContrib.map(_ * featureStd / labelStd)
}
else sparkFtrContrib
case _ => sparkFtrContrib
}
}
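For reference, the standardization cited in the comment above scales a raw linear-regression coefficient by the ratio of feature to label standard deviations:

$$\beta_{\text{std}} = \hat{\beta} \cdot \frac{s_{\text{feature}}}{s_{\text{label}}}$$

which is exactly what `sparkFtrContrib.map(_ * featureStd / labelStd)` computes per feature.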
@@ -747,6 +774,8 @@ case object ModelInsights {
model match {
case Some(m: SelectedModel) =>
Try(ModelSelectorSummary.fromMetadata(m.getMetadata().getSummaryMetadata())).toOption
case Some(m: SelectedCombinerModel) =>
Try(ModelSelectorSummary.fromMetadata(m.getMetadata().getSummaryMetadata())).toOption
case _ => None
}
}
@@ -179,6 +179,9 @@ case class SummaryStatistics
meta.build()
}

private[op] def +(sum: SummaryStatistics): SummaryStatistics = new SummaryStatistics(count, sampleFraction,
max ++ sum.max, min ++ sum.min, mean ++ sum.mean, variance ++ sum.variance)

}

/**
@@ -313,10 +316,22 @@ case class Correlations
corrMeta.putString(SanityCheckerNames.CorrelationType, corrType.sparkName)
corrMeta.build()
}

private[op] def +(corr: Correlations): Correlations =
new Correlations(featuresIn ++ corr.featuresIn, values ++ corr.values, nanCorrs ++ corr.nanCorrs, corrType)
Contributor:

Would the Correlations have to have the same correlation type for this to make sense?

Collaborator (Author):

Yes, but I am not sure it makes sense to error out on that... should I maybe log a warning instead?

Collaborator (Author):

Added a custom correlation type to reflect cases where multiple correlation types are combined.
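A minimal, self-contained sketch of the behavior described in this thread, using stand-in types rather than the TransmogrifAI classes (the `Custom` marker here is hypothetical): when the two summaries carry different correlation types, the combined result is tagged with a catch-all type instead of failing.

```scala
// Stand-in types only; sketches the "fall back to a custom type" idea from the thread.
object CorrelationCombineSketch {

  sealed trait CorrType
  case object Pearson extends CorrType
  case object Spearman extends CorrType
  case object Custom extends CorrType // hypothetical marker for mixed correlation types

  final case class Corrs(featuresIn: Seq[String], values: Seq[Double], corrType: CorrType) {
    def +(that: Corrs): Corrs = {
      // Combining summaries computed with different correlation types is not an error;
      // the result is simply tagged as Custom.
      val combined = if (corrType == that.corrType) corrType else Custom
      Corrs(featuresIn ++ that.featuresIn, values ++ that.values, combined)
    }
  }

  def main(args: Array[String]): Unit = {
    val a = Corrs(Seq("f1"), Seq(0.3), Pearson)
    val b = Corrs(Seq("f2"), Seq(0.1), Spearman)
    println((a + b).corrType) // Custom, because the input types disagree
  }
}
```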

}

case object SanityCheckerSummary {

def flatten(checkers: Seq[SanityCheckerSummary]): SanityCheckerSummary = {
val correlationsWLabel: Correlations = checkers.map(_.correlationsWLabel).reduce(_ + _)
val dropped: Seq[String] = checkers.flatMap(_.dropped)
val featuresStatistics: SummaryStatistics = checkers.map(_.featuresStatistics).reduce(_ + _)
val names: Seq[String] = checkers.flatMap(_.names)
val categoricalStats: Array[CategoricalGroupStats] = checkers.flatMap(_.categoricalStats).toArray
new SanityCheckerSummary(correlationsWLabel, dropped, featuresStatistics, names, categoricalStats)
}

private def correlationsFromMetadata(meta: Metadata): Correlations = {
val wrapped = meta.wrapped
Correlations(