Skip to content

Commit

Permalink
Fixing equality on nullable types (typelevel#152)
Browse files Browse the repository at this point in the history
* Fixing equality on nullable types

* Addressing code review comments. Equality is now handled differently when the target type is a Spark struct.
  • Loading branch information
imarios committed Sep 5, 2017
1 parent 41d6ee2 commit 919cdd5
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 5 deletions.
37 changes: 32 additions & 5 deletions dataset/src/main/scala/frameless/TypedColumn.scala
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,21 @@ sealed class TypedColumn[T, U](
*/
// Exposes this typed column as a plain Spark `Column` built from the same expression.
def untyped: Column = withExpr(expr)

/** Wraps a Catalyst expression in an untyped Spark `Column`.
  *
  * @param newExpr the expression the resulting column should evaluate
  * @return a new `Column` backed by `newExpr`
  */
private def withExpr(newExpr: Expression): Column = {
  val wrapped = new Column(newExpr)
  wrapped
}

/** Builds the equality predicate between this column and `other`.
  *
  * Selection of the Catalyst comparison:
  *  - struct-typed targets always use `EqualTo`;
  *  - otherwise, nullable encoders use `EqualNullSafe`, so two nulls compare as equal;
  *  - non-nullable encoders use `EqualTo`.
  *
  * The struct case is detected by matching on the Catalyst `StructType` itself rather
  * than comparing `typeName` with the string "struct", so the check cannot be fooled
  * by an unrelated data type whose derived type name happens to be "struct".
  *
  * @param other the column to compare against
  * @return a boolean column holding the comparison result
  */
private def equalsTo(other: TypedColumn[T, U]): TypedColumn[T, Boolean] = withExpr {
  uencoder.targetDataType match {
    case _: org.apache.spark.sql.types.StructType => EqualTo(self.expr, other.expr)
    case _ if uencoder.nullable => EqualNullSafe(self.expr, other.expr)
    case _ => EqualTo(self.expr, other.expr)
  }
}.typed

/** Equality test against a literal value of the column's type.
* {{{
* df.filter( df.col('a) === 1 )
* }}}
*
* NOTE(review): this text is a rendered diff, so two definitions with the same
* signature follow. The first appears to be the line removed by the commit and
* the second the line that replaces it.
*
* apache/spark
*/
// Removed side: compares through the untyped Column `===` operator.
def ===(other: U): TypedColumn[T, Boolean] = (untyped === lit(other).untyped).typed
// Added side: lifts `other` with `lit` and delegates to `equalsTo`, which picks
// EqualNullSafe or EqualTo from the column's encoder.
def ===(other: U): TypedColumn[T, Boolean] = equalsTo(lit(other))

/** Equality test.
* {{{
Expand All @@ -60,7 +67,7 @@ sealed class TypedColumn[T, U](
*
* apache/spark
*/
// NOTE(review): rendered diff; the first definition appears to be the removed line and
// the second its replacement.
// Removed side: compares the two columns through the untyped Column `===` operator.
def ===(other: TypedColumn[T, U]): TypedColumn[T, Boolean] = (untyped === other.untyped).typed
// Added side: delegates to `equalsTo`, which picks EqualNullSafe or EqualTo from the
// column's encoder.
def ===(other: TypedColumn[T, U]): TypedColumn[T, Boolean] = equalsTo(other)

/** Inequality test.
* {{{
Expand All @@ -69,7 +76,9 @@ sealed class TypedColumn[T, U](
*
* apache/spark
*/
// NOTE(review): rendered diff; the first definition appears to be the removed line and
// the second its replacement.
// Removed side: compares the two columns through the untyped Column `=!=` operator.
def =!=(other: TypedColumn[T, U]): TypedColumn[T, Boolean] = (self.untyped =!= other.untyped).typed
// Added side: negates the expression produced by `equalsTo`, so inequality uses the
// same comparison choice as `===`.
def =!=(other: TypedColumn[T, U]): TypedColumn[T, Boolean] = withExpr {
Not(equalsTo(other).expr)
}.typed

/** Inequality test.
* {{{
Expand All @@ -78,7 +87,24 @@ sealed class TypedColumn[T, U](
*
* apache/spark
*/
// NOTE(review): rendered diff; the first definition appears to be the removed line and
// the second its replacement.
// Removed side: lifts `other` with `lit` and compares through the untyped Column `=!=`.
def =!=(other: U): TypedColumn[T, Boolean] = (self.untyped =!= lit(other).untyped).typed
// Added side: lifts `other` with `lit`, builds the equality via `equalsTo`, and wraps
// its expression in `Not`.
def =!=(other: U): TypedColumn[T, Boolean] = withExpr {
Not(equalsTo(lit(other)).expr)
}.typed

/** Tests whether this `Option`-typed column holds `None`.
  *
  * The check is built by comparing the column, through `equalsTo`, with a literal
  * `None` cast to the column's value type `U`.
  *
  * apache/spark
  *
  * @param isOption evidence that `U` is an `Option`; it restricts where the method
  *                 can be called and is not otherwise used in the body
  */
def isNone(implicit isOption: U <:< Option[_]): TypedColumn[T, Boolean] = {
  val noneLiteral: TypedColumn[T, U] = lit[U, T](None.asInstanceOf[U])
  equalsTo(noneLiteral)
}

/** Tests whether this `Option`-typed column holds something other than `None`.
  *
  * The check is the negation (`Not`) of the equality, built through `equalsTo`,
  * between the column and a literal `None` cast to the column's value type `U`.
  *
  * apache/spark
  *
  * @param isOption evidence that `U` is an `Option`; it restricts where the method
  *                 can be called and is not otherwise used in the body
  */
def isNotNone(implicit isOption: U <:< Option[_]): TypedColumn[T, Boolean] = {
  val noneLiteral: TypedColumn[T, U] = lit(None.asInstanceOf[U])
  val equalsNone = equalsTo(noneLiteral).expr
  withExpr(Not(equalsNone)).typed
}

/** Sum of this expression and another expression.
* {{{
Expand Down Expand Up @@ -195,7 +221,8 @@ sealed class TypedColumn[T, U](
* @param u another column of the same type
* apache/spark
*/
// NOTE(review): rendered diff; the first definition appears to be the removed line and
// the second its replacement.
// Removed side: single-line form.
def divide(u: TypedColumn[T, U])(implicit n: CatalystNumeric[U]): TypedColumn[T, Double] = self.untyped.divide(u.untyped).typed
// Added side: the same expression, only wrapped onto a second line.
def divide(u: TypedColumn[T, U])(implicit n: CatalystNumeric[U]): TypedColumn[T, Double] =
self.untyped.divide(u.untyped).typed

/**
* Division this expression by another expression.
Expand Down
44 changes: 44 additions & 0 deletions dataset/src/test/scala/frameless/FilterTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -79,4 +79,48 @@ class FilterTests extends TypedDatasetSuite {
assert(tds.filter(tds('a) * 2 === 2).collect().run().toVector === Vector(X1(1)))
assert(tds.filter(tds('a) * 3 === 3).collect().run().toVector === Vector(X1(1)))
}

test("Option equality/inequality for columns") {
  // Property: filtering the typed dataset with column-to-column predicates must select
  // the same rows as the equivalent filter on the plain Scala collection.
  def prop[A <: Option[_] : TypedEncoder](a: A, b: A): Prop = {
    // One row with arbitrary (a, b) and one row whose two fields are guaranteed equal.
    val data = X2(a, b) :: X2(a, a) :: Nil
    val dataset = TypedDataset.create(data)
    val A = dataset.col('a)
    val B = dataset.col('b)

    (data.filter(x => x.a == x.b).toSet ?= dataset.filter(A === B).collect().run().toSet).
      &&(data.filter(x => x.a != x.b).toSet ?= dataset.filter(A =!= B).collect().run().toSet).
      &&(data.filter(x => x.a == None).toSet ?= dataset.filter(A.isNone).collect().run().toSet).
      &&(data.filter(x => x.a == None).toSet ?= dataset.filter(A.isNotNone === false).collect().run().toSet)
  }

  check(forAll(prop[Option[Int]] _))
  check(forAll(prop[Option[Boolean]] _))
  check(forAll(prop[Option[SQLDate]] _))
  check(forAll(prop[Option[SQLTimestamp]] _))
  // Added for parity with the "for lit" test, which already exercises Option[String].
  check(forAll(prop[Option[String]] _))
  check(forAll(prop[Option[X1[String]]] _))
  check(forAll(prop[Option[X1[X1[String]]]] _))
  check(forAll(prop[Option[X1[X1[Vector[Option[Int]]]]]] _))
}

test("Option equality/inequality for lit") {
  // Property: filtering the typed dataset with column-to-literal predicates must select
  // the same rows as the equivalent filter on the plain Scala collection.
  def prop[A <: Option[_] : TypedEncoder](a: A, b: A, cLit: A): Prop = {
    // One row with arbitrary (a, b) and one row whose second field is the literal itself.
    val rows = List(X2(a, b), X2(a, cLit))
    val tds = TypedDataset.create(rows)
    val left = tds.col('a)

    // Sub-properties are local defs so that, as operands of `&&`, they are evaluated
    // exactly when the combined property asks for them.
    def equalToLit: Prop =
      rows.filter(_.a == cLit).toSet ?= tds.filter(left === cLit).collect().run().toSet

    def differentFromLit: Prop =
      rows.filter(_.a != cLit).toSet ?= tds.filter(left =!= cLit).collect().run().toSet

    def noneRows: Prop =
      rows.filter(_.a == None).toSet ?= tds.filter(left.isNone).collect().run().toSet

    def notNotNoneRows: Prop =
      rows.filter(_.a == None).toSet ?= tds.filter(left.isNotNone === false).collect().run().toSet

    equalToLit && differentFromLit && noneRows && notNotNoneRows
  }

  check(forAll(prop[Option[Int]] _))
  check(forAll(prop[Option[Boolean]] _))
  check(forAll(prop[Option[SQLDate]] _))
  check(forAll(prop[Option[SQLTimestamp]] _))
  check(forAll(prop[Option[String]] _))
  check(forAll(prop[Option[X1[String]]] _))
  check(forAll(prop[Option[X1[X1[String]]]] _))
  check(forAll(prop[Option[X1[X1[Vector[Option[Int]]]]]] _))
}
}

0 comments on commit 919cdd5

Please sign in to comment.