Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Model combiner #385

Merged
merged 19 commits into from
Sep 3, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
addressing comments
  • Loading branch information
leahmcguire committed Aug 29, 2019
commit 6139bd054b58a386dbcfa7e78a6eb9e9eb814894
Original file line number Diff line number Diff line change
Expand Up @@ -228,12 +228,10 @@ object OpEvaluatorNames extends Enum[OpEvaluatorNames] {
override def entryName: String = name.toLowerCase
}
def withName(name: String, isLargerBetter: Boolean): OpEvaluatorNames =
Try(super.withName(name)).getOrElse(Custom(name, name, isLargerBetter))
override def withName(name: String): OpEvaluatorNames = withName(name, true)
super.withNameOption(name).getOrElse(Custom(name, name, isLargerBetter))

/**
 * Case-insensitive lookup of an evaluator name that never fails: when
 * `withNameInsensitiveOption` finds no entry, a `Custom` entry is built from the given name.
 *
 * @param name            name to look up
 * @param isLargerBetter  passed to `Custom` when no existing entry matches
 * @return the matching entry, or a new `Custom(name, name, isLargerBetter)`
 */
def withNameInsensitive(name: String, isLargerBetter: Boolean): OpEvaluatorNames =
  super.withNameInsensitiveOption(name) match {
    case Some(known) => known
    case None => Custom(name, name, isLargerBetter)
  }
override def withNameInsensitive(name: String): OpEvaluatorNames = withNameInsensitive(name, true)

/**
 * Finds the first entry in `values` whose human friendly name equals the given name, ignoring case.
 *
 * @param name human friendly name to search for
 * @return the first matching entry, or None when no entry matches
 */
def withFriendlyNameInsensitive(name: String): Option[OpEvaluatorNames] =
  values.find(_.humanFriendlyName.equalsIgnoreCase(name))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -875,6 +875,13 @@ object CorrelationType extends Enum[CorrelationType] {
* @see https://en.wikipedia.org/wiki/Spearman%27s_rank_correlation_coefficient
*/
case object Spearman extends CorrelationType("spearman")

/**
* Custom correlation type that is not one of the predefined case objects,
* e.g. the type built when correlations computed with different types are combined
*
* @param name  name of the custom correlation type
* @param spark value passed to the [[CorrelationType]] constructor
*/
case class Custom(name: String, spark: String) extends CorrelationType(spark)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -317,8 +317,18 @@ case class Correlations
corrMeta.build()
}

private[op] def +(corr: Correlations): Correlations =
new Correlations(featuresIn ++ corr.featuresIn, values ++ corr.values, nanCorrs ++ corr.nanCorrs, corrType)
/**
 * Combines this set of correlations with another one by concatenating their
 * features, values and nan correlations.
 *
 * When both sides share the same correlation type it is kept as is; otherwise a
 * [[CorrelationType.Custom]] is created whose name and spark name are the
 * concatenation of the two entry names and the two spark names respectively.
 *
 * @param corr correlations to append to this one
 * @return new combined correlations
 */
private[op] def +(corr: Correlations): Correlations = {
  val combinedType =
    if (corrType == corr.corrType) corrType
    else CorrelationType.Custom(
      name = corrType.entryName + corr.corrType.entryName,
      spark = corrType.sparkName + corr.corrType.sparkName
    )
  val allFeatures = featuresIn ++ corr.featuresIn
  val allValues = values ++ corr.values
  val allNanCorrs = nanCorrs ++ corr.nanCorrs
  new Correlations(allFeatures, allValues, allNanCorrs, combinedType)
}
}

case object SanityCheckerSummary {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ trait SelectedCombinerParams extends Params {
isValid = (in: String) => CombinationStrategy.values.map(_.entryName).contains(in)
)
/**
 * Sets the combination strategy param; the strategy is stored by its entry name.
 *
 * @param value combination strategy to use
 * @return this stage
 */
def setCombinationStrategy(value: CombinationStrategy): this.type = set(combinationStrategy, value.entryName)
def getCombinationStrategy(): CombinationStrategy = CombinationStrategy.namesToValuesMap($(combinationStrategy))
/**
 * Reads the combination strategy param and resolves the stored name back to a
 * [[CombinationStrategy]] via `CombinationStrategy.withNameInsensitive`.
 *
 * @return the configured combination strategy
 */
def getCombinationStrategy(): CombinationStrategy = CombinationStrategy.withNameInsensitive($(combinationStrategy))
setDefault(combinationStrategy, CombinationStrategy.Best.entryName)

}
Expand All @@ -68,10 +68,10 @@ trait SelectedCombinerParams extends Params {
* @param operationName name of operation
* @param uid stage uid
*/
class SelectedCombiner
class SelectedModelCombiner
(
val operationName: String = "combineModels",
val uid: String = UID[SelectedCombiner]
val uid: String = UID[SelectedModelCombiner]
)(
implicit val tto: TypeTag[Prediction],
val ttov: TypeTag[Prediction#Value]
Expand Down Expand Up @@ -116,21 +116,6 @@ class SelectedCombiner
s"Cannot combine model selectors for different problem types found ${summary1.problemType}" +
s" and ${summary2.problemType}")

def getMetricValue(metrics: EvaluationMetrics, name: EvalMetric) =
metrics.toMap.collectFirst{
case (k, v) if k.contains(name.humanFriendlyName) || k.contains(name.entryName) => v.asInstanceOf[Double]}

def getWinningModelMetric(summary: ModelSelectorSummary) = {
summary.validationResults.collectFirst {
case r if r.modelUID == summary.bestModelUID =>
getMetricValue(r.metricValues, summary.evaluationMetric)
}.flatten
}

def getMet(optionMet: Option[Double]) = optionMet.getOrElse {
throw new RuntimeException("Evaluation metrics for two model selectors are non-overlapping")
}

val eval1 = summary1.evaluationMetric
val eval2 = summary2.evaluationMetric

Expand All @@ -147,34 +132,6 @@ class SelectedCombiner
} else (None, None, eval1)
}

def makeMeta(model: SelectedCombinerModel): Unit = {
def updateKeys(map: Map[String, Any], string: String) = map.map{ case (k, v) => k + string -> v }

if (model.strategy == CombinationStrategy.Best && model.weight1 > 0.5) {
setMetadata(summary1.toMetadata().toSummaryMetadata())
} else if (model.strategy == CombinationStrategy.Best) {
setMetadata(summary2.toMetadata().toSummaryMetadata())
} else {
val summary = new ModelSelectorSummary(
validationType = summary1.validationType,
validationParameters = updateKeys(summary1.validationParameters, "_1") ++
updateKeys(summary2.validationParameters, "_2"),
dataPrepParameters = updateKeys(summary1.dataPrepParameters, "_1") ++
updateKeys(summary2.dataPrepParameters, "_2"),
dataPrepResults = summary1.dataPrepResults.orElse(summary2.dataPrepResults),
evaluationMetric = metricName,
problemType = summary1.problemType,
bestModelUID = summary1.bestModelUID + " " + summary2.bestModelUID,
bestModelName = summary1.bestModelName + " " + summary2.bestModelName,
bestModelType = summary1.bestModelType + " " + summary2.bestModelType,
validationResults = summary1.validationResults ++ summary2.validationResults,
trainEvaluation = evaluate(model.transform(dataset)),
holdoutEvaluation = None
)
setMetadata(summary.toMetadata().toSummaryMetadata())
}
}

val (metricValue1, metricValue2) = (getMet(metricValueOpt1), getMet(metricValueOpt2))

val strategy = getCombinationStrategy()
Expand Down Expand Up @@ -202,10 +159,51 @@ class SelectedCombiner
.setInput(in1.asFeatureLike[RealNN], in2.asFeatureLike[Prediction], in3.asFeatureLike[Prediction])
.setOutputFeatureName(getOutputFeatureName)

makeMeta(model)
if (model.strategy == CombinationStrategy.Best && model.weight1 > 0.5) {
setMetadata(summary1.toMetadata().toSummaryMetadata())
} else if (model.strategy == CombinationStrategy.Best) {
setMetadata(summary2.toMetadata().toSummaryMetadata())
} else {
val summary = new ModelSelectorSummary(
validationType = summary1.validationType,
validationParameters = updateKeys(summary1.validationParameters, "_1") ++
updateKeys(summary2.validationParameters, "_2"),
dataPrepParameters = updateKeys(summary1.dataPrepParameters, "_1") ++
updateKeys(summary2.dataPrepParameters, "_2"),
dataPrepResults = summary1.dataPrepResults.orElse(summary2.dataPrepResults),
evaluationMetric = metricName,
problemType = summary1.problemType,
bestModelUID = summary1.bestModelUID + " " + summary2.bestModelUID,
bestModelName = summary1.bestModelName + " " + summary2.bestModelName,
bestModelType = summary1.bestModelType + " " + summary2.bestModelType,
validationResults = summary1.validationResults ++ summary2.validationResults,
trainEvaluation = evaluate(model.transform(dataset)),
holdoutEvaluation = None
)
setMetadata(summary.toMetadata().toSummaryMetadata())
}

model.setMetadata(getMetadata())
}

/**
 * Looks through the metrics map for the first entry whose key contains either the
 * human friendly name or the entry name of the given metric and whose value is a Double.
 *
 * @param metrics evaluation metrics to search
 * @param name    metric to look for
 * @return the first matching Double value, or None if there is none
 */
private def getMetricValue(metrics: EvaluationMetrics, name: EvalMetric) = {
  def matchesMetric(key: String): Boolean =
    key.contains(name.humanFriendlyName) || key.contains(name.entryName)

  metrics.toMap.collectFirst { case (key, value: Double) if matchesMetric(key) => value }
}

/**
 * Finds the first validation result belonging to the best model of the summary and
 * extracts the value of the summary's evaluation metric from it.
 *
 * @param summary model selector summary to inspect
 * @return metric value of the best model, or None if the model or the metric is not found
 */
private def getWinningModelMetric(summary: ModelSelectorSummary) =
  summary.validationResults
    .find(_.modelUID == summary.bestModelUID)
    .flatMap(winner => getMetricValue(winner.metricValues, summary.evaluationMetric))

/**
 * Unwraps an optional metric value.
 *
 * @param optionMet metric value, if one was found
 * @return the contained metric value
 * @throws RuntimeException when no metric value is present
 */
private def getMet(optionMet: Option[Double]) = optionMet match {
  case Some(metricValue) => metricValue
  case None => throw new RuntimeException("Evaluation metrics for two model selectors are non-overlapping")
}

/**
 * Appends a suffix to every key of the map, leaving the values untouched.
 *
 * @param map    map whose keys should be renamed
 * @param string suffix appended to each key
 * @return map with the suffixed keys
 */
private def updateKeys(map: Map[String, Any], string: String) =
  for ((key, value) <- map) yield s"$key$string" -> value

}

final class SelectedCombinerModel private[op]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ import com.salesforce.op.stages.impl.classification._
import com.salesforce.op.stages.impl.preparators._
import com.salesforce.op.stages.impl.regression.{OpLinearRegression, OpXGBoostRegressor, RegressionModelSelector}
import com.salesforce.op.stages.impl.selector.ModelSelectorNames.EstimatorType
import com.salesforce.op.stages.impl.selector.{SelectedCombiner, SelectedCombinerModel, SelectedModel}
import com.salesforce.op.stages.impl.selector.{SelectedModelCombiner, SelectedCombinerModel, SelectedModel}
import com.salesforce.op.stages.impl.selector.ValidationType._
import com.salesforce.op.stages.impl.tuning.{DataCutter, DataSplitter}
import com.salesforce.op.test.{PassengerSparkFixtureTest, TestFeatureBuilder}
Expand Down Expand Up @@ -802,7 +802,7 @@ class ModelInsightsTest extends FlatSpec with PassengerSparkFixtureTest with Dou
}

it should "return correct insights when a model combiner equal is used as the final feature" in {
val predComb = new SelectedCombiner().setCombinationStrategy(CombinationStrategy.Equal)
val predComb = new SelectedModelCombiner().setCombinationStrategy(CombinationStrategy.Equal)
.setInput(label, pred, predWithMaps).getOutput()
val workflowModel = new OpWorkflow().setResultFeatures(pred, predComb)
.setParameters(params).setReader(dataReader).train()
Expand All @@ -816,7 +816,7 @@ class ModelInsightsTest extends FlatSpec with PassengerSparkFixtureTest with Dou
}

it should "return correct insights when a model combiner best is used as the final feature" in {
val predComb = new SelectedCombiner().setCombinationStrategy(CombinationStrategy.Best)
val predComb = new SelectedModelCombiner().setCombinationStrategy(CombinationStrategy.Best)
.setInput(label, pred, predWithMaps).getOutput()
val workflowModel = new OpWorkflow().setResultFeatures(pred, predComb)
.setParameters(params).setReader(dataReader).train()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class SelectedCombinerTest extends OpEstimatorSpec[Prediction, SelectedCombinerModel, SelectedCombiner]
class SelectedModelCombinerTest extends OpEstimatorSpec[Prediction, SelectedCombinerModel, SelectedModelCombiner]
with PredictionEquality {

val (seed, smallCount, bigCount) = (1234L, 20, 80)
Expand Down Expand Up @@ -96,7 +96,7 @@ class SelectedCombinerTest extends OpEstimatorSpec[Prediction, SelectedCombinerM
.setResultFeatures(ms1, ms2)
.transform(data)

override val estimator: SelectedCombiner = new SelectedCombiner().setInput(label, ms1, ms2)
override val estimator: SelectedModelCombiner = new SelectedModelCombiner().setInput(label, ms1, ms2)

override val expectedResult: Seq[Prediction] = inputData.collect(ms1)

Expand Down Expand Up @@ -140,7 +140,7 @@ class SelectedCombinerTest extends OpEstimatorSpec[Prediction, SelectedCombinerM
.setResultFeatures(ms1, ms2)
.transform(data)

val comb = new SelectedCombiner().setInput(label, ms1, ms2)
val comb = new SelectedModelCombiner().setInput(label, ms1, ms2)
val combFit = comb.fit(inputData)
combFit.transform(inputData).collect(comb.getOutput()) shouldBe inputData.collect(ms1)
combFit.strategy shouldBe CombinationStrategy.Best
Expand Down