updated EinsteinAppArgs to be based on the MlAppArgs base class
leahmcguire committed Aug 22, 2019
1 parent 4f28cca commit 63bf74b
Showing 19 changed files with 172 additions and 109 deletions.
2 changes: 1 addition & 1 deletion core/src/main/scala/com/salesforce/op/ModelInsights.scala
@@ -35,7 +35,7 @@ import com.salesforce.op.features._
import com.salesforce.op.features.types._
import com.salesforce.op.filters._
import com.salesforce.op.stages._
import com.salesforce.op.stages.impl.feature.{TextStats, TransmogrifierDefaults}
import com.salesforce.op.stages.impl.feature.{CombinationStrategy, TextStats, TransmogrifierDefaults}
import com.salesforce.op.stages.impl.preparators._
import com.salesforce.op.stages.impl.selector._
import com.salesforce.op.stages.impl.tuning.{DataBalancerSummary, DataCutterSummary, DataSplitterSummary}
@@ -99,20 +99,25 @@ sealed trait EvalMetric extends EnumEntry with Serializable {
*/
def humanFriendlyName: String

/**
* Is a larger value of this metric better, or a smaller one?
*/
def isLargerBetter: Boolean

}

/**
* Eval metric companion object
*/
object EvalMetric {

def withNameInsensitive(name: String): EvalMetric = {
def withNameInsensitive(name: String, isLargerBetter: Boolean = true): EvalMetric = {
BinaryClassEvalMetrics.withNameInsensitiveOption(name)
.orElse(MultiClassEvalMetrics.withNameInsensitiveOption(name))
.orElse(RegressionEvalMetrics.withNameInsensitiveOption(name))
.orElse(ForecastEvalMetrics.withNameInsensitiveOption(name))
.orElse(OpEvaluatorNames.withNameInsensitiveOption(name))
.getOrElse(OpEvaluatorNames.Custom(name, name))
.getOrElse(OpEvaluatorNames.Custom(name, name, isLargerBetter))
}
}
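A quick sketch of the resolution order above: a known name resolves to its enum entry and keeps that entry's built-in direction, while an unknown name falls through to a Custom metric carrying the caller-supplied flag ("myLoss" is an illustrative, made-up name, and the lookup assumes enumeratum's default case-insensitive entry names):

val known = EvalMetric.withNameInsensitive("f1") // resolves to BinaryClassEvalMetrics.F1
val custom = EvalMetric.withNameInsensitive("myLoss", isLargerBetter = false)
known.isLargerBetter // true, baked into the enum entry
custom.isLargerBetter // false, taken from the argument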

@@ -122,37 +127,38 @@ object EvalMetric {
sealed abstract class ClassificationEvalMetric
(
val sparkEntryName: String,
val humanFriendlyName: String
val humanFriendlyName: String,
val isLargerBetter: Boolean
) extends EvalMetric

/**
* Binary Classification Metrics
*/
object BinaryClassEvalMetrics extends Enum[ClassificationEvalMetric] {
val values = findValues
case object Precision extends ClassificationEvalMetric("weightedPrecision", "precision")
case object Recall extends ClassificationEvalMetric("weightedRecall", "recall")
case object F1 extends ClassificationEvalMetric("f1", "f1")
case object Error extends ClassificationEvalMetric("accuracy", "error")
case object AuROC extends ClassificationEvalMetric("areaUnderROC", "area under ROC")
case object AuPR extends ClassificationEvalMetric("areaUnderPR", "area under precision-recall")
case object TP extends ClassificationEvalMetric("TP", "true positive")
case object TN extends ClassificationEvalMetric("TN", "true negative")
case object FP extends ClassificationEvalMetric("FP", "false positive")
case object FN extends ClassificationEvalMetric("FN", "false negative")
case object BrierScore extends ClassificationEvalMetric("brierScore", "brier score")
case object Precision extends ClassificationEvalMetric("weightedPrecision", "precision", true)
case object Recall extends ClassificationEvalMetric("weightedRecall", "recall", true)
case object F1 extends ClassificationEvalMetric("f1", "f1", true)
case object Error extends ClassificationEvalMetric("accuracy", "error", false)
case object AuROC extends ClassificationEvalMetric("areaUnderROC", "area under ROC", true)
case object AuPR extends ClassificationEvalMetric("areaUnderPR", "area under precision-recall", true)
case object TP extends ClassificationEvalMetric("TP", "true positive", true)
case object TN extends ClassificationEvalMetric("TN", "true negative", true)
case object FP extends ClassificationEvalMetric("FP", "false positive", false)
case object FN extends ClassificationEvalMetric("FN", "false negative", false)
case object BrierScore extends ClassificationEvalMetric("brierScore", "brier score", false)
}
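With the direction now attached to each entry, callers can read it straight off the metric, for example:

BinaryClassEvalMetrics.AuROC.isLargerBetter // true: higher area under ROC is better
BinaryClassEvalMetrics.BrierScore.isLargerBetter // false: lower Brier score is better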

/**
* Multi Classification Metrics
*/
object MultiClassEvalMetrics extends Enum[ClassificationEvalMetric] {
val values = findValues
case object Precision extends ClassificationEvalMetric("weightedPrecision", "precision")
case object Recall extends ClassificationEvalMetric("weightedRecall", "recall")
case object F1 extends ClassificationEvalMetric("f1", "f1")
case object Error extends ClassificationEvalMetric("accuracy", "error")
case object ThresholdMetrics extends ClassificationEvalMetric("thresholdMetrics", "threshold metrics")
case object Precision extends ClassificationEvalMetric("weightedPrecision", "precision", true)
case object Recall extends ClassificationEvalMetric("weightedRecall", "recall", true)
case object F1 extends ClassificationEvalMetric("f1", "f1", true)
case object Error extends ClassificationEvalMetric("accuracy", "error", false)
case object ThresholdMetrics extends ClassificationEvalMetric("thresholdMetrics", "threshold metrics", true)
}


@@ -162,18 +168,19 @@ object MultiClassEvalMetrics extends Enum[ClassificationEvalMetric] {
sealed abstract class RegressionEvalMetric
(
val sparkEntryName: String,
val humanFriendlyName: String
val humanFriendlyName: String,
val isLargerBetter: Boolean
) extends EvalMetric

/**
* Regression Metrics
*/
object RegressionEvalMetrics extends Enum[RegressionEvalMetric] {
val values: Seq[RegressionEvalMetric] = findValues
case object RootMeanSquaredError extends RegressionEvalMetric("rmse", "root mean square error")
case object MeanSquaredError extends RegressionEvalMetric("mse", "mean square error")
case object R2 extends RegressionEvalMetric("r2", "r2")
case object MeanAbsoluteError extends RegressionEvalMetric("mae", "mean absolute error")
case object RootMeanSquaredError extends RegressionEvalMetric("rmse", "root mean square error", false)
case object MeanSquaredError extends RegressionEvalMetric("mse", "mean square error", false)
case object R2 extends RegressionEvalMetric("r2", "r2", true)
case object MeanAbsoluteError extends RegressionEvalMetric("mae", "mean absolute error", false)
}


@@ -183,15 +190,16 @@ object RegressionEvalMetrics extends Enum[RegressionEvalMetric] {
sealed abstract class ForecastEvalMetric
(
val sparkEntryName: String,
val humanFriendlyName: String
val humanFriendlyName: String,
val isLargerBetter: Boolean
) extends EvalMetric


object ForecastEvalMetrics extends Enum[ForecastEvalMetric] {
val values: Seq[ForecastEvalMetric] = findValues
case object SMAPE extends ForecastEvalMetric("smape", "symmetric mean absolute percentage error")
case object MASE extends ForecastEvalMetric("mase", "mean absolute scaled error")
case object SeasonalError extends ForecastEvalMetric("seasonalError", "seasonal error")
case object SMAPE extends ForecastEvalMetric("smape", "symmetric mean absolute percentage error", false)
case object MASE extends ForecastEvalMetric("mase", "mean absolute scaled error", false)
case object SeasonalError extends ForecastEvalMetric("seasonalError", "seasonal error", false)
}


@@ -201,23 +209,30 @@ object ForecastEvalMetrics extends Enum[ForecastEvalMetric] {
sealed abstract class OpEvaluatorNames
(
val sparkEntryName: String,
val humanFriendlyName: String
val humanFriendlyName: String,
val isLargerBetter: Boolean // for default value
) extends EvalMetric

/**
* Contains evaluator names used in logging
*/
object OpEvaluatorNames extends Enum[OpEvaluatorNames] {
val values: Seq[OpEvaluatorNames] = findValues
case object Binary extends OpEvaluatorNames("binEval", "binary evaluation metrics")
case object BinScore extends OpEvaluatorNames("binScoreEval", "bin score evaluation metrics")
case object Multi extends OpEvaluatorNames("multiEval", "multiclass evaluation metrics")
case object Regression extends OpEvaluatorNames("regEval", "regression evaluation metrics")
case object Forecast extends OpEvaluatorNames("regForecast", "regression evaluation metrics")
case class Custom(name: String, humanName: String) extends OpEvaluatorNames(name, humanName) {
case object Binary extends OpEvaluatorNames("binEval", "binary evaluation metrics", true)
case object BinScore extends OpEvaluatorNames("binScoreEval", "bin score evaluation metrics", false)
case object Multi extends OpEvaluatorNames("multiEval", "multiclass evaluation metrics", true)
case object Regression extends OpEvaluatorNames("regEval", "regression evaluation metrics", false)
case object Forecast extends OpEvaluatorNames("regForecast", "regression evaluation metrics", false)
case class Custom(name: String, humanName: String, largeBetter: Boolean) extends
OpEvaluatorNames(name, humanName, largeBetter) {
override def entryName: String = name.toLowerCase
}
override def withName(name: String): OpEvaluatorNames = Try(super.withName(name)).getOrElse(Custom(name, name))
override def withNameInsensitive(name: String): OpEvaluatorNames = super.withNameInsensitiveOption(name)
.getOrElse(Custom(name, name))
def withName(name: String, isLargerBetter: Boolean): OpEvaluatorNames =
Try(super.withName(name)).getOrElse(Custom(name, name, isLargerBetter))
override def withName(name: String): OpEvaluatorNames = withName(name, true)

def withNameInsensitive(name: String, isLargerBetter: Boolean): OpEvaluatorNames =
super.withNameInsensitiveOption(name).getOrElse(Custom(name, name, isLargerBetter))
override def withNameInsensitive(name: String): OpEvaluatorNames = withNameInsensitive(name, true)

}
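A short sketch of the new overloads ("medianAE" is an illustrative name that matches no enum entry):

val m = OpEvaluatorNames.withNameInsensitive("medianAE", isLargerBetter = false)
// m == OpEvaluatorNames.Custom("medianAE", "medianAE", false)
m.isLargerBetter // false
OpEvaluatorNames.withName("medianAE") // single-argument overload defaults to isLargerBetter = true
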
40 changes: 17 additions & 23 deletions core/src/main/scala/com/salesforce/op/evaluators/Evaluators.scala
@@ -59,7 +59,7 @@ object Evaluators {
* Area under ROC
*/
def auROC(): OpBinaryClassificationEvaluator =
new OpBinaryClassificationEvaluator(name = BinaryClassEvalMetrics.AuROC, isLargerBetter = true) {
new OpBinaryClassificationEvaluator(name = BinaryClassEvalMetrics.AuROC) {
override def evaluate(dataset: Dataset[_]): Double =
getBinaryEvaluatorMetric(BinaryClassEvalMetrics.AuROC, dataset, default = 0.0)
}
@@ -68,7 +68,7 @@
* Area under Precision/Recall curve
*/
def auPR(): OpBinaryClassificationEvaluator =
new OpBinaryClassificationEvaluator(name = BinaryClassEvalMetrics.AuPR, isLargerBetter = true) {
new OpBinaryClassificationEvaluator(name = BinaryClassEvalMetrics.AuPR) {
override def evaluate(dataset: Dataset[_]): Double =
getBinaryEvaluatorMetric(BinaryClassEvalMetrics.AuPR, dataset, default = 0.0)
}
@@ -78,7 +78,7 @@
*/
def precision(): OpBinaryClassificationEvaluator =
new OpBinaryClassificationEvaluator(
name = BinaryClassEvalMetrics.Precision, isLargerBetter = true) {
name = BinaryClassEvalMetrics.Precision) {
override def evaluate(dataset: Dataset[_]): Double =
getBinaryEvaluatorMetric(BinaryClassEvalMetrics.Precision, dataset, default = 0.0)
}
@@ -88,7 +88,7 @@
* Recall
*/
def recall(): OpBinaryClassificationEvaluator =
new OpBinaryClassificationEvaluator(name = BinaryClassEvalMetrics.Recall, isLargerBetter = true) {
new OpBinaryClassificationEvaluator(name = BinaryClassEvalMetrics.Recall) {
override def evaluate(dataset: Dataset[_]): Double =
getBinaryEvaluatorMetric(BinaryClassEvalMetrics.Recall, dataset, default = 0.0)
}
@@ -97,7 +97,7 @@
* F1 score
*/
def f1(): OpBinaryClassificationEvaluator =
new OpBinaryClassificationEvaluator(name = BinaryClassEvalMetrics.F1, isLargerBetter = true) {
new OpBinaryClassificationEvaluator(name = BinaryClassEvalMetrics.F1) {
override def evaluate(dataset: Dataset[_]): Double =
getBinaryEvaluatorMetric(BinaryClassEvalMetrics.F1, dataset, default = 0.0)
}
@@ -106,7 +106,7 @@
* Prediction error
*/
def error(): OpBinaryClassificationEvaluator =
new OpBinaryClassificationEvaluator(name = BinaryClassEvalMetrics.Error, isLargerBetter = false) {
new OpBinaryClassificationEvaluator(name = BinaryClassEvalMetrics.Error) {
override def evaluate(dataset: Dataset[_]): Double =
1.0 - getBinaryEvaluatorMetric(BinaryClassEvalMetrics.Error, dataset, default = 1.0)
}
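Because the factory methods above no longer pass isLargerBetter, the direction now comes from the metric itself. A minimal sketch, assuming these factories are exposed as Evaluators.BinaryClassification.*:

val auc = Evaluators.BinaryClassification.auROC()
auc.isLargerBetter // true, read from BinaryClassEvalMetrics.AuROC
val err = Evaluators.BinaryClassification.error()
err.isLargerBetter // false, read from BinaryClassEvalMetrics.Error
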
@@ -128,12 +128,10 @@ object Evaluators {
isLargerBetter: Boolean = true,
evaluateFn: Dataset[(Double, OPVector#Value, OPVector#Value, Double)] => Double
): OpBinaryClassificationEvaluatorBase[SingleMetric] = {
val largerBetter = isLargerBetter
new OpBinaryClassificationEvaluatorBase[SingleMetric](
uid = UID[OpBinaryClassificationEvaluatorBase[SingleMetric]]
) {
override val name: EvalMetric = OpEvaluatorNames.Custom(metricName, metricName)
override val isLargerBetter: Boolean = largerBetter
override val name: EvalMetric = OpEvaluatorNames.Custom(metricName, metricName, isLargerBetter)
override def getDefaultMetric: SingleMetric => Double = _.value

override def evaluateAll(dataset: Dataset[_]): SingleMetric = {
@@ -165,7 +163,7 @@ object Evaluators {
* Weighted Precision
*/
def precision(): OpMultiClassificationEvaluator =
new OpMultiClassificationEvaluator(name = MultiClassEvalMetrics.Precision, isLargerBetter = true) {
new OpMultiClassificationEvaluator(name = MultiClassEvalMetrics.Precision) {
override def evaluate(dataset: Dataset[_]): Double =
getMultiEvaluatorMetric(MultiClassEvalMetrics.Precision, dataset, default = 0.0)
}
@@ -174,7 +172,7 @@
* Weighted Recall
*/
def recall(): OpMultiClassificationEvaluator =
new OpMultiClassificationEvaluator(name = MultiClassEvalMetrics.Recall, isLargerBetter = true) {
new OpMultiClassificationEvaluator(name = MultiClassEvalMetrics.Recall) {
override def evaluate(dataset: Dataset[_]): Double =
getMultiEvaluatorMetric(MultiClassEvalMetrics.Recall, dataset, default = 0.0)
}
@@ -183,7 +181,7 @@
* F1 Score
*/
def f1(): OpMultiClassificationEvaluator =
new OpMultiClassificationEvaluator(name = MultiClassEvalMetrics.F1, isLargerBetter = true) {
new OpMultiClassificationEvaluator(name = MultiClassEvalMetrics.F1) {
override def evaluate(dataset: Dataset[_]): Double =
getMultiEvaluatorMetric(MultiClassEvalMetrics.F1, dataset, default = 0.0)
}
@@ -192,7 +190,7 @@
* Prediction Error
*/
def error(): OpMultiClassificationEvaluator =
new OpMultiClassificationEvaluator(name = MultiClassEvalMetrics.Error, isLargerBetter = false) {
new OpMultiClassificationEvaluator(name = MultiClassEvalMetrics.Error) {
override def evaluate(dataset: Dataset[_]): Double =
1.0 - getMultiEvaluatorMetric(MultiClassEvalMetrics.Error, dataset, default = 1.0)
}
@@ -214,12 +212,10 @@ object Evaluators {
isLargerBetter: Boolean = true,
evaluateFn: Dataset[(Double, OPVector#Value, OPVector#Value, Double)] => Double
): OpMultiClassificationEvaluatorBase[SingleMetric] = {
val largerBetter = isLargerBetter
new OpMultiClassificationEvaluatorBase[SingleMetric](
uid = UID[OpMultiClassificationEvaluatorBase[SingleMetric]]
) {
override val name: EvalMetric = OpEvaluatorNames.Custom(metricName, metricName)
override val isLargerBetter: Boolean = largerBetter
override val name: EvalMetric = OpEvaluatorNames.Custom(metricName, metricName, isLargerBetter)
override def getDefaultMetric: SingleMetric => Double = _.value

override def evaluateAll(dataset: Dataset[_]): SingleMetric = {
@@ -251,7 +247,7 @@ object Evaluators {
* Mean Squared Error
*/
def mse(): OpRegressionEvaluator =
new OpRegressionEvaluator(name = RegressionEvalMetrics.MeanSquaredError, isLargerBetter = false) {
new OpRegressionEvaluator(name = RegressionEvalMetrics.MeanSquaredError) {
override def evaluate(dataset: Dataset[_]): Double =
getRegEvaluatorMetric(RegressionEvalMetrics.MeanSquaredError, dataset, default = 0.0)
}
@@ -260,7 +256,7 @@
* Mean Absolute Error
*/
def mae(): OpRegressionEvaluator =
new OpRegressionEvaluator(name = RegressionEvalMetrics.MeanAbsoluteError, isLargerBetter = false) {
new OpRegressionEvaluator(name = RegressionEvalMetrics.MeanAbsoluteError) {
override def evaluate(dataset: Dataset[_]): Double =
getRegEvaluatorMetric(RegressionEvalMetrics.MeanAbsoluteError, dataset, default = 0.0)
}
@@ -269,7 +265,7 @@
* R2
*/
def r2(): OpRegressionEvaluator =
new OpRegressionEvaluator(name = RegressionEvalMetrics.R2, isLargerBetter = true) {
new OpRegressionEvaluator(name = RegressionEvalMetrics.R2) {
override def evaluate(dataset: Dataset[_]): Double =
getRegEvaluatorMetric(RegressionEvalMetrics.R2, dataset, default = 0.0)
}
@@ -278,7 +274,7 @@
* Root Mean Squared Error
*/
def rmse(): OpRegressionEvaluator =
new OpRegressionEvaluator(name = RegressionEvalMetrics.RootMeanSquaredError, isLargerBetter = false) {
new OpRegressionEvaluator(name = RegressionEvalMetrics.RootMeanSquaredError) {
override def evaluate(dataset: Dataset[_]): Double =
getRegEvaluatorMetric(RegressionEvalMetrics.RootMeanSquaredError, dataset, default = 0.0)
}
@@ -299,12 +295,10 @@ object Evaluators {
isLargerBetter: Boolean = true,
evaluateFn: Dataset[(Double, Double)] => Double
): OpRegressionEvaluatorBase[SingleMetric] = {
val largerBetter = isLargerBetter
new OpRegressionEvaluatorBase[SingleMetric](
uid = UID[OpRegressionEvaluatorBase[SingleMetric]]
) {
override val name: EvalMetric = OpEvaluatorNames.Custom(metricName, metricName)
override val isLargerBetter: Boolean = largerBetter
override val name: EvalMetric = OpEvaluatorNames.Custom(metricName, metricName, isLargerBetter)
override def getDefaultMetric: SingleMetric => Double = _.value

override def evaluateAll(dataset: Dataset[_]): SingleMetric = {
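
For the custom factories, the single isLargerBetter argument now flows into OpEvaluatorNames.Custom rather than a separate override. A hypothetical custom regression metric, assuming Evaluators.Regression.custom as the entry point and a SparkSession in scope:

import org.apache.spark.sql.Dataset

val maxAbsError = Evaluators.Regression.custom(
  metricName = "maxAbsError", // illustrative name
  isLargerBetter = false, // becomes OpEvaluatorNames.Custom("maxAbsError", "maxAbsError", false)
  evaluateFn = (ds: Dataset[(Double, Double)]) => {
    import ds.sparkSession.implicits._
    // largest absolute difference between label and prediction
    ds.map { case (label, pred) => math.abs(label - pred) }.reduce((a, b) => math.max(a, b))
  }
)
maxAbsError.isLargerBetter // false, derived from the custom metric name
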
@@ -56,7 +56,6 @@ import org.slf4j.LoggerFactory
private[op] class OpBinaryClassificationEvaluator
(
override val name: EvalMetric = OpEvaluatorNames.Binary,
override val isLargerBetter: Boolean = true,
override val uid: String = UID[OpBinaryClassificationEvaluator],
val numBins: Int = 100
) extends OpBinaryClassificationEvaluatorBase[BinaryClassificationMetrics](uid = uid) {
@@ -118,8 +118,16 @@ abstract class OpEvaluatorBase[T <: EvaluationMetrics] extends Evaluator
*/
val name: EvalMetric


/**
* Use the metric definition to determine whether a larger value is better
* @return true if a larger metric value is better
*/
override def isLargerBetter: Boolean = name.isLargerBetter

/**
* Evaluate function that returns a class or value with the calculated metric value(s).
*
* @param dataset data to evaluate
* @return metrics
*/
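
With the isLargerBetter override above, an evaluator's direction becomes a pure function of its metric name rather than a separately overridden flag. A sketch (these evaluator classes are private[op], so this only compiles inside the com.salesforce.op package):

val regEval = new OpRegressionEvaluator() // name defaults to OpEvaluatorNames.Regression
regEval.isLargerBetter // false, since Regression is declared with isLargerBetter = false
val binEval = new OpBinaryClassificationEvaluator() // name defaults to OpEvaluatorNames.Binary
binEval.isLargerBetter // true
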
@@ -61,7 +61,6 @@ private[op] class OpForecastEvaluator
val seasonalWindow: Int = 1,
val maxItems: Int = 87660,
override val name: EvalMetric = OpEvaluatorNames.Forecast,
override val isLargerBetter: Boolean = false,
override val uid: String = UID[OpForecastEvaluator]
) extends OpRegressionEvaluatorBase[ForecastMetrics](uid) {

@@ -58,7 +58,6 @@ import org.slf4j.LoggerFactory
private[op] class OpMultiClassificationEvaluator
(
override val name: EvalMetric = OpEvaluatorNames.Multi,
override val isLargerBetter: Boolean = true,
override val uid: String = UID[OpMultiClassificationEvaluator]
) extends OpMultiClassificationEvaluatorBase[MultiClassificationMetrics](uid) {

@@ -50,7 +50,6 @@ import org.slf4j.LoggerFactory
private[op] class OpRegressionEvaluator
(
override val name: EvalMetric = OpEvaluatorNames.Regression,
override val isLargerBetter: Boolean = false,
override val uid: String = UID[OpRegressionEvaluator]
) extends OpRegressionEvaluatorBase[RegressionMetrics](uid) {
