Skip to content

Commit

Permalink
add UDT support for TypedEncoder
Browse files Browse the repository at this point in the history
  • Loading branch information
atamborrino committed Aug 21, 2017
1 parent 920d167 commit e8de534
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 5 deletions.
3 changes: 2 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ lazy val dataset = project
.settings(publishSettings: _*)
.settings(libraryDependencies ++= Seq(
"org.apache.spark" %% "spark-core" % sparkVersion % "provided",
"org.apache.spark" %% "spark-sql" % sparkVersion % "provided"
"org.apache.spark" %% "spark-sql" % sparkVersion % "provided",
"org.apache.spark" %% "spark-mllib" % sparkVersion % "test"
))
.dependsOn(core % "test->test;compile->compile")

Expand Down
33 changes: 32 additions & 1 deletion dataset/src/main/scala/frameless/TypedEncoder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,13 @@ abstract class TypedEncoder[T](implicit val classTag: ClassTag[T]) extends Seria

// Waiting on scala 2.12
// @annotation.implicitAmbiguous(msg =
// """TypedEncoder[${T}] can be obtained both from automatic type class derivation and using the implicit Injection[${T}, ?] in scope. To desambigious this resolution you need to either:
// """TypedEncoder[${T}] can be obtained from automatic type class derivation, using the implicit Injection[${T}, ?] or using the implicit UDT[${T]] in scope.
// To desambigious this resolution you need to either:
// - Remove the implicit Injection[${T}, ?] from scope
// - Remove the implicit UDT[${T]] from scope
// - import TypedEncoder.usingInjection
// - import TypedEncoder.usingDerivation
// - import TypedEncoder.usingUDT
// """)
object TypedEncoder {
def apply[T: TypedEncoder]: TypedEncoder[T] = implicitly[TypedEncoder[T]]
Expand Down Expand Up @@ -327,4 +330,32 @@ object TypedEncoder {
recordEncoder: Lazy[RecordEncoderFields[G]],
classTag: ClassTag[F]
): TypedEncoder[F] = new RecordEncoder[F, G]

type UDT[A >: Null] = FramelessInternals.PublicUserDefinedType[A]

/**
* Encodes things using a Spark SQL's User Defined Type (UDT) if there is one defined in implicit.
*
* Example: to use Spark ML's VectorUDT implementation as a TypedEncoder, add the following in your scope:
* {{{
* import org.apache.spark.ml.linalg.{Vector, SQLDataTypes}
* import frameless.TypedEncoder.UDT
* implicit val mLVectorUDT: UDT[Vector] = SQLDataTypes.VectorType.asInstanceOf[UDT[Vector]]
* }}}
* */
implicit def usingUDT[A >: Null : UDT : ClassTag]: TypedEncoder[A] = {
val udt = implicitly[UDT[A]]
val udtInstance = NewInstance(udt.getClass, Nil, dataType = ObjectType(udt.getClass))

new TypedEncoder[A] {
def nullable: Boolean = false
def sourceDataType: DataType = ObjectType(udt.userClass)
def targetDataType: DataType = udt

def extractorFor(path: Expression): Expression = Invoke(udtInstance, "serialize", udt, Seq(path))

def constructorFor(path: Expression): Expression =
Invoke(udtInstance, "deserialize", ObjectType(udt.userClass), Seq(path))
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package org.apache.spark.sql
import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.types.ObjectType
import org.apache.spark.sql.types.{ObjectType, UserDefinedType}

import scala.reflect.ClassTag

Expand Down Expand Up @@ -34,4 +34,6 @@ object FramelessInternals {
def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan): DataFrame = {
Dataset.ofRows(sparkSession, logicalPlan)
}

type PublicUserDefinedType[A >: Null] = UserDefinedType[A]
}
13 changes: 13 additions & 0 deletions dataset/src/test/scala/frameless/CollectTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ import frameless.CollectTests.prop
import org.apache.spark.sql.SQLContext
import org.scalacheck.Prop
import org.scalacheck.Prop._

import scala.reflect.ClassTag
import org.apache.spark.ml.linalg.{Matrix, Vector => MLVector}

class CollectTests extends TypedDatasetSuite {
test("collect()") {
Expand Down Expand Up @@ -48,6 +50,17 @@ class CollectTests extends TypedDatasetSuite {

// TODO this doesn't work, and never worked...
// check(forAll(prop[X1[Option[X1[Option[Int]]]]] _))

check(forAll(prop[MLVector] _))
check(forAll(prop[Option[MLVector]] _))
check(forAll(prop[X1[MLVector]] _))
check(forAll(prop[X2[Int, MLVector]] _))
check(forAll(prop[(Long, MLVector)] _))
check(forAll(prop[Matrix] _))
check(forAll(prop[Option[Matrix]] _))
check(forAll(prop[X1[Matrix]] _))
check(forAll(prop[X2[Int, Matrix]] _))
check(forAll(prop[(Long, Matrix)] _))
}
}

Expand Down
3 changes: 2 additions & 1 deletion dataset/src/test/scala/frameless/SelectTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import org.scalacheck.Prop
import org.scalacheck.Prop._
import shapeless.test.illTyped
import frameless.implicits.widen._

import org.apache.spark.ml.linalg.{Vector => MLVector}
import scala.reflect.ClassTag

class SelectTests extends TypedDatasetSuite {
Expand All @@ -27,6 +27,7 @@ class SelectTests extends TypedDatasetSuite {
check(forAll(prop[Int, Int, Int, Int] _))
check(forAll(prop[X2[Int, Int], Int, Int, Int] _))
check(forAll(prop[String, Int, Int, Int] _))
check(forAll(prop[MLVector, Int, Int, Int] _))
}

test("select('a, 'b) FROM abcd") {
Expand Down
24 changes: 23 additions & 1 deletion dataset/src/test/scala/frameless/package.scala
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import org.scalacheck.{Gen, Arbitrary}
import frameless.TypedEncoder.UDT
import org.scalacheck.{Arbitrary, Gen}
import org.apache.spark.ml.linalg.{Matrices, Matrix, SQLDataTypes, Vector => MLVector, Vectors => MLVectors}

package object frameless {
/** Fixed decimal point to avoid precision problems specific to Spark */
Expand Down Expand Up @@ -30,4 +32,24 @@ package object frameless {
implicit def arbVector[A](implicit A: Arbitrary[A]): Arbitrary[Vector[A]] =
Arbitrary(Gen.listOf(A.arbitrary).map(_.toVector))

implicit val arbMLVector: Arbitrary[MLVector] = Arbitrary {
val genDenseVector = Gen.listOf(arbDouble.arbitrary).map(doubles => MLVectors.dense(doubles.toArray))
val genSparseVector = genDenseVector.map(_.toSparse)

Gen.oneOf(genDenseVector, genSparseVector)
}

implicit val mLVectorUDT: UDT[MLVector] = SQLDataTypes.VectorType.asInstanceOf[UDT[MLVector]]

implicit val arbMatrix: Arbitrary[Matrix] = Arbitrary {
Gen.sized { nbRows =>
Gen.sized { nbCols =>
Gen.listOfN(nbRows * nbCols, arbDouble.arbitrary)
.map(values => Matrices.dense(nbRows, nbCols, values.toArray))
}
}
}

implicit val matrixUDT: UDT[Matrix] = SQLDataTypes.MatrixType.asInstanceOf[UDT[Matrix]]

}

0 comments on commit e8de534

Please sign in to comment.