Skip to content

Commit

Permalink
lifted out the common pattern for bivariate statistics properties
Browse files Browse the repository at this point in the history
  • Loading branch information
GrafBlutwurst committed Jul 10, 2017
1 parent 78ff784 commit bfeb1e7
Showing 1 changed file with 48 additions and 51 deletions.
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package frameless
package functions

import frameless.{TypedAggregate, TypedColumn}
import frameless.functions.aggregate._
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.{Column, Encoder}
import org.scalacheck.{Gen, Prop}
import org.scalacheck.Prop._

Expand Down Expand Up @@ -359,6 +360,50 @@ class AggregateFunctionsTests extends TypedDatasetSuite {
check(forAll(prop[BigDecimal] _))
}


/**
  * Shared property body for bivariate statistics aggregates: evaluates the typed
  * (frameless) aggregate and the untyped Spark aggregate on the same data, grouped
  * by the first field, and checks that both yield the same per-group results.
  *
  * Both result sets are passed through `DoubleBehaviourUtils.nanNullHandler` before
  * the comparison (presumably to normalise NaN/null results - confirm against that helper).
  *
  * @param xs           input rows; `a` is the grouping key, `b` and `c` are the two variables
  * @param framelessFun typed aggregate under test, applied to columns `b` and `c`
  * @param sparkFun     reference Spark column function, applied to columns `_2` and `_3`
  * @param encEv        encoder used to build the plain Spark dataset of `(a, b, c)` tuples
  * @param encEv2       encoder for the `(group, result)` pairs produced on the Spark side
  * @param evCanBeDoubleA not referenced in this body; part of the signature expected by callers
  * @param evCanBeDoubleB not referenced in this body; part of the signature expected by callers
  * @tparam A type of the first variable
  * @tparam B type of the second variable
  * @return a property that holds when both implementations agree for every group
  */
def biVariatePropTemplate[A: TypedEncoder, B: TypedEncoder]
(
xs: List[X3[Int,A,B]]
)
(
framelessFun: (TypedColumn[X3[Int,A,B], A], TypedColumn[X3[Int,A,B], B]) => TypedAggregate[X3[Int,A,B], Option[Double]],
sparkFun: (Column, Column) => Column
)
(
implicit
encEv: Encoder[(Int, A, B)],
encEv2: Encoder[(Int,Option[Double])],
evCanBeDoubleA: CatalystCast[A, Double],
evCanBeDoubleB: CatalystCast[B, Double]
): Prop = {
  // Typed side: run the frameless aggregate per group and collect the results locally.
  val typedDs = TypedDataset.create(xs)
  val typedAggregate = framelessFun(typedDs('b), typedDs('c))
  val typedPerGroup = typedDs
    .groupBy(typedDs('a))
    .agg(typedAggregate)
    .map(pair => (pair._1, pair._2.flatMap(DoubleBehaviourUtils.nanNullHandler)))
    .collect()
    .run()

  // Reference side: same data as plain tuples, aggregated with the vanilla Spark function.
  val plainTuples = xs.map(x => (x.a, x.b, x.c))
  val plainDs = session.createDataset(plainTuples)
  val referenceAggregated = plainDs
    .groupBy(plainDs("_1"))
    .agg(sparkFun(plainDs("_2"), plainDs("_3")))
  val referencePerGroup = referenceAggregated
    .map(row => (row.getInt(0), DoubleBehaviourUtils.nanNullHandler(row.get(1))))

  // Both sides must produce the same group -> statistic mapping.
  typedPerGroup.toMap ?= referencePerGroup.collect().toMap
}

test("corr") {
val spark = session
import spark.implicits._
Expand All @@ -368,31 +413,7 @@ class AggregateFunctionsTests extends TypedDatasetSuite {
encEv: Encoder[(Int, A, B)],
evCanBeDoubleA: CatalystCast[A, Double],
evCanBeDoubleB: CatalystCast[B, Double]
): Prop = {


val tds = TypedDataset.create(xs)
val tdCorrelation = tds.groupBy(tds('a)).agg(corr(tds('b), tds('c)))
.map(
kv => (kv._1, kv._2.flatMap(DoubleBehaviourUtils.nanNullHandler))
).collect().run()



val cDF = session.createDataset(xs.map(x => (x.a,x.b,x.c)))
val compCorrelation = cDF
.groupBy(cDF("_1"))
.agg(org.apache.spark.sql.functions.corr(cDF("_2"), cDF("_3")))
.map(
row => {
val grp = row.getInt(0)
(grp, DoubleBehaviourUtils.nanNullHandler(row.get(1)))
}
)


tdCorrelation.toMap ?= compCorrelation.collect().toMap
}
): Prop = biVariatePropTemplate(xs)(corr[A,B,X3[Int, A, B]],org.apache.spark.sql.functions.corr)

check(forAll(prop[Double, Double] _))
check(forAll(prop[Double, Int] _))
Expand All @@ -410,31 +431,7 @@ class AggregateFunctionsTests extends TypedDatasetSuite {
encEv: Encoder[(Int, A, B)],
evCanBeDoubleA: CatalystCast[A, Double],
evCanBeDoubleB: CatalystCast[B, Double]
): Prop = {


val tds = TypedDataset.create(xs)
val tdCovar = tds.groupBy(tds('a)).agg(covar_pop(tds('b), tds('c)))
.map(
kv => (kv._1, kv._2.flatMap(DoubleBehaviourUtils.nanNullHandler))
).collect().run()



val cDF = session.createDataset(xs.map(x => (x.a,x.b,x.c)))
val compCovar = cDF
.groupBy(cDF("_1"))
.agg(org.apache.spark.sql.functions.covar_pop(cDF("_2"), cDF("_3")))
.map(
row => {
val grp = row.getInt(0)
(grp, DoubleBehaviourUtils.nanNullHandler(row.get(1)))
}
)


tdCovar.toMap ?= compCovar.collect().toMap
}
): Prop = biVariatePropTemplate(xs)(covar_pop[A,B,X3[Int, A, B]],org.apache.spark.sql.functions.covar_pop)

check(forAll(prop[Double, Double] _))
check(forAll(prop[Double, Int] _))
Expand Down

0 comments on commit bfeb1e7

Please sign in to comment.