diff --git a/core/src/main/scala/com/salesforce/op/stages/impl/tuning/DataBalancer.scala b/core/src/main/scala/com/salesforce/op/stages/impl/tuning/DataBalancer.scala index 13ba9aea22..5970b4cc70 100644 --- a/core/src/main/scala/com/salesforce/op/stages/impl/tuning/DataBalancer.scala +++ b/core/src/main/scala/com/salesforce/op/stages/impl/tuning/DataBalancer.scala @@ -338,24 +338,6 @@ trait DataBalancerParams extends Params { def getSampleFraction: Double = $(sampleFraction) - /** - * Maximum size of dataset want to train on. - * Value should be > 0. - * Default is 5000. - * - * @group param - */ - final val maxTrainingSample = new IntParam(this, "maxTrainingSample", - "maximum size of dataset want to train on", ParamValidators.inRange( - lowerBound = 0, upperBound = 1 << 30, lowerInclusive = false, upperInclusive = true - ) - ) - setDefault(maxTrainingSample, SplitterParamsDefault.MaxTrainingSampleDefault) - - def setMaxTrainingSample(value: Int): this.type = set(maxTrainingSample, value) - - def getMaxTrainingSample: Int = $(maxTrainingSample) - /** * Fraction to sample minority data * Value should be > 0.0 diff --git a/core/src/main/scala/com/salesforce/op/stages/impl/tuning/Splitter.scala b/core/src/main/scala/com/salesforce/op/stages/impl/tuning/Splitter.scala index 99734a1fa1..3808e28a03 100644 --- a/core/src/main/scala/com/salesforce/op/stages/impl/tuning/Splitter.scala +++ b/core/src/main/scala/com/salesforce/op/stages/impl/tuning/Splitter.scala @@ -129,6 +129,24 @@ trait SplitterParams extends Params { def setReserveTestFraction(value: Double): this.type = set(reserveTestFraction, value) def getReserveTestFraction: Double = $(reserveTestFraction) + /** + * Maximum size of dataset want to train on. + * Value should be > 0. + * Default is 1000000. + * + * @group param + */ + final val maxTrainingSample = new IntParam(this, "maxTrainingSample", + "maximum size of dataset want to train on", ParamValidators.inRange( + lowerBound = 0, upperBound = 1 << 30, lowerInclusive = false, upperInclusive = true + ) + ) + setDefault(maxTrainingSample, SplitterParamsDefault.MaxTrainingSampleDefault) + + def setMaxTrainingSample(value: Int): this.type = set(maxTrainingSample, value) + + def getMaxTrainingSample: Int = $(maxTrainingSample) + final val labelColumnName = new Param[String](this, "labelColumnName", "label column name, column 0 if not specified") private[op] def getLabelColumnName = $(labelColumnName)