diff --git a/core/src/main/scala/com/salesforce/op/stages/impl/feature/OpScalarStandardScaler.scala b/core/src/main/scala/com/salesforce/op/stages/impl/feature/OpScalarStandardScaler.scala index fcf1e762f9..1036b3f00b 100644 --- a/core/src/main/scala/com/salesforce/op/stages/impl/feature/OpScalarStandardScaler.scala +++ b/core/src/main/scala/com/salesforce/op/stages/impl/feature/OpScalarStandardScaler.scala @@ -60,9 +60,19 @@ class OpScalarStandardScaler val internalScaler = new MLLibStandardScaler(withMean = estimator.getWithMean, withStd = estimator.getWithStd) val scalerModel = internalScaler.fit(vecData) + val std = scalerModel.std.toArray + val mean = scalerModel.mean.toArray + + // Since is a UnaryEstimator from RealNN to RealNN, exactly one value will be in std and mean + val stdVal = std.head + val meanVal = mean.head + val scalingArgs = LinearScalerArgs(1 / stdVal, - meanVal / stdVal) + val meta = ScalerMetadata(ScalingType.Linear, scalingArgs).toMetadata() + setMetadata(meta) + new OpScalarStandardScalerModel( - std = scalerModel.std.toArray, - mean = scalerModel.mean.toArray, + std = std, + mean = mean, withStd = scalerModel.withStd, withMean = scalerModel.withMean, operationName = operationName, diff --git a/core/src/test/scala/com/salesforce/op/stages/impl/feature/OpScalarStandardScalerTest.scala b/core/src/test/scala/com/salesforce/op/stages/impl/feature/OpScalarStandardScalerTest.scala index 91754c7fb8..d5bf1f3fd1 100644 --- a/core/src/test/scala/com/salesforce/op/stages/impl/feature/OpScalarStandardScalerTest.scala +++ b/core/src/test/scala/com/salesforce/op/stages/impl/feature/OpScalarStandardScalerTest.scala @@ -44,6 +44,8 @@ import org.junit.runner.RunWith import org.scalatest.FlatSpec import org.scalatest.junit.JUnitRunner +import scala.util.{Failure, Success} + @RunWith(classOf[JUnitRunner]) class OpScalarStandardScalerTest extends OpEstimatorSpec[RealNN, UnaryModel[RealNN, RealNN], OpScalarStandardScaler] { @@ -59,6 +61,8 @@ class OpScalarStandardScalerTest extends OpEstimatorSpec[RealNN, UnaryModel[Real 1.150792911137501.toRealNN ) + val descaleValues = Seq(10.0, 100.0, 1000.0) + val (descaleData, testD) = TestFeatureBuilder(descaleValues.map(_.toRealNN)) val (inputData, testF) = TestFeatureBuilder(Seq(10, 100, 1000).map(_.toRealNN)) @@ -148,6 +152,39 @@ class OpScalarStandardScalerTest extends OpEstimatorSpec[RealNN, UnaryModel[Real assert(sumSqDist <= 0.000001, "===> the sum of squared distances between actual and expected should be zero.") } + it should "descale and work in standardized workflow" in { + val featureNormalizer = new OpScalarStandardScaler().setInput(testD) + val normedOutput = featureNormalizer.getOutput() + val metadata = featureNormalizer.fit(descaleData).getMetadata() + + val expectedMean = descaleValues.sum / descaleValues.length + val expectedStd = math.sqrt(descaleValues.map(value => math.pow(expectedMean - value, 2)).sum + / (descaleValues.length - 1)) + val expectedSlope = 1 / expectedStd + val expectedIntercept = - expectedMean / expectedStd + ScalerMetadata(metadata) match { + case Failure(err) => fail(err) + case Success(meta) => + meta shouldBe a[ScalerMetadata] + meta.scalingType shouldBe ScalingType.Linear + meta.scalingArgs shouldBe a[LinearScalerArgs] + math.abs((meta.scalingArgs.asInstanceOf[LinearScalerArgs].slope - expectedSlope) + / expectedSlope) should be < 0.001 + math.abs((meta.scalingArgs.asInstanceOf[LinearScalerArgs].intercept - expectedIntercept) + / expectedIntercept) should be < 0.001 + } + + val descaledResponse = new DescalerTransformer[RealNN, RealNN, RealNN]() + .setInput(normedOutput, normedOutput).getOutput() + val workflow = new OpWorkflow().setResultFeatures(descaledResponse) + val wfModel = workflow.setInputDataset(descaleData).train() + val transformed = wfModel.score() + + val actual = transformed.collect().map(_.getAs[Double](1)) + val expected : Seq[Double] = descaleValues + all(actual.zip(expected).map(x => math.abs(x._2 - x._1))) should be < 0.0001 + } + private def validateDataframeDoubleColumn(normalizedFeatureDF: DataFrame, scaledFeatureName: String, targetColumnName: String): Double = { val sqDistUdf = udf { (leftCol: Double, rightCol: Double) => Math.pow(leftCol - rightCol, 2) }