Skip to content

Commit

Permalink
Add metadata to OpStandardScaler to allow for descaling (#378)
Browse files Browse the repository at this point in the history
* Add metadata to OpScalarStandardScaler

* Add test for descaling to OpScalarStandardScaler tests

* Fix scalastyle

* Improve tests

* Add explanation to selection of std and mean

* Add clarification
  • Loading branch information
erica-chiu committed Aug 5, 2019
1 parent e1bab3b commit 9cf80fd
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,19 @@ class OpScalarStandardScaler
val internalScaler = new MLLibStandardScaler(withMean = estimator.getWithMean, withStd = estimator.getWithStd)
val scalerModel = internalScaler.fit(vecData)

val std = scalerModel.std.toArray
val mean = scalerModel.mean.toArray

// Since this is a UnaryEstimator from RealNN to RealNN, std and mean each contain exactly one value
val stdVal = std.head
val meanVal = mean.head
val scalingArgs = LinearScalerArgs(1 / stdVal, - meanVal / stdVal)
val meta = ScalerMetadata(ScalingType.Linear, scalingArgs).toMetadata()
setMetadata(meta)

new OpScalarStandardScalerModel(
std = scalerModel.std.toArray,
mean = scalerModel.mean.toArray,
std = std,
mean = mean,
withStd = scalerModel.withStd,
withMean = scalerModel.withMean,
operationName = operationName,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner

import scala.util.{Failure, Success}


@RunWith(classOf[JUnitRunner])
class OpScalarStandardScalerTest extends OpEstimatorSpec[RealNN, UnaryModel[RealNN, RealNN], OpScalarStandardScaler] {
Expand All @@ -59,6 +61,8 @@ class OpScalarStandardScalerTest extends OpEstimatorSpec[RealNN, UnaryModel[Real
1.150792911137501.toRealNN
)

// Raw values for the descaling test, kept as plain Doubles so expected statistics can be derived from them
val descaleValues = Seq(10.0, 100.0, 1000.0)
// Dataset and feature built from descaleValues; the descaling test fits and scores against these
val (descaleData, testD) = TestFeatureBuilder(descaleValues.map(_.toRealNN))

// Dataset and feature built from the same three values, written here as Int literals
val (inputData, testF) = TestFeatureBuilder(Seq(10, 100, 1000).map(_.toRealNN))

Expand Down Expand Up @@ -148,6 +152,39 @@ class OpScalarStandardScalerTest extends OpEstimatorSpec[RealNN, UnaryModel[Real
assert(sumSqDist <= 0.000001, "===> the sum of squared distances between actual and expected should be zero.")
}

it should "descale and work in standardized workflow" in {
  // Fit the scaler on its own first so the metadata it records can be inspected directly
  val scaler = new OpScalarStandardScaler().setInput(testD)
  val scaledFeature = scaler.getOutput()
  val fittedMetadata = scaler.fit(descaleData).getMetadata()

  // Reference statistics: mean and sample standard deviation (n - 1 denominator) of the raw values
  val expectedMean = descaleValues.sum / descaleValues.length
  val squaredDeviations = descaleValues.map(v => math.pow(expectedMean - v, 2))
  val expectedStd = math.sqrt(squaredDeviations.sum / (descaleValues.length - 1))

  // Standardization (x - mean) / std written as the linear map slope * x + intercept
  val expectedSlope = 1 / expectedStd
  val expectedIntercept = - expectedMean / expectedStd

  // Relative deviation of an observed value from its reference value
  def relativeError(observed: Double, reference: Double): Double =
    math.abs((observed - reference) / reference)

  ScalerMetadata(fittedMetadata) match {
    case Success(scalerMeta) =>
      scalerMeta shouldBe a[ScalerMetadata]
      scalerMeta.scalingType shouldBe ScalingType.Linear
      scalerMeta.scalingArgs shouldBe a[LinearScalerArgs]
      // The cast is guarded by the type assertion on the line above
      val linearArgs = scalerMeta.scalingArgs.asInstanceOf[LinearScalerArgs]
      relativeError(linearArgs.slope, expectedSlope) should be < 0.001
      relativeError(linearArgs.intercept, expectedIntercept) should be < 0.001
    case Failure(err) => fail(err)
  }

  // Round trip: the scaled feature is passed both as the value to descale and as the
  // feature whose scaler metadata drives the descaling
  val descaledFeature = new DescalerTransformer[RealNN, RealNN, RealNN]()
    .setInput(scaledFeature, scaledFeature)
    .getOutput()
  val trainedModel = new OpWorkflow()
    .setResultFeatures(descaledFeature)
    .setInputDataset(descaleData)
    .train()
  val scored = trainedModel.score()

  // Column 1 of the scored frame is read as the descaled value and compared with the raw inputs
  val recovered = scored.collect().map(_.getAs[Double](1))
  val absoluteErrors = recovered.zip(descaleValues).map { case (got, want) => math.abs(want - got) }
  all(absoluteErrors) should be < 0.0001
}

private def validateDataframeDoubleColumn(normalizedFeatureDF: DataFrame, scaledFeatureName: String,
targetColumnName: String): Double = {
val sqDistUdf = udf { (leftCol: Double, rightCol: Double) => Math.pow(leftCol - rightCol, 2) }
Expand Down

0 comments on commit 9cf80fd

Please sign in to comment.