diff --git a/build.sbt b/build.sbt
index d701adb3..1e6e667e 100644
--- a/build.sbt
+++ b/build.sbt
@@ -32,7 +32,8 @@ lazy val dataset = project
   .settings(publishSettings: _*)
   .settings(libraryDependencies ++= Seq(
     "org.apache.spark" %% "spark-core" % sparkVersion % "provided",
-    "org.apache.spark" %% "spark-sql" % sparkVersion % "provided"
+    "org.apache.spark" %% "spark-sql" % sparkVersion % "provided",
+    "org.apache.spark" %% "spark-mllib" % sparkVersion % "test"
   ))
   .dependsOn(core % "test->test;compile->compile")
 
diff --git a/dataset/src/main/scala/frameless/TypedEncoder.scala b/dataset/src/main/scala/frameless/TypedEncoder.scala
index 7de38ccb..e2e62007 100644
--- a/dataset/src/main/scala/frameless/TypedEncoder.scala
+++ b/dataset/src/main/scala/frameless/TypedEncoder.scala
@@ -27,10 +27,13 @@ abstract class TypedEncoder[T](implicit val classTag: ClassTag[T]) extends Seria
 
 // Waiting on scala 2.12
 // @annotation.implicitAmbiguous(msg =
-// """TypedEncoder[${T}] can be obtained both from automatic type class derivation and using the implicit Injection[${T}, ?] in scope. To desambigious this resolution you need to either:
+// """TypedEncoder[${T}] can be obtained from automatic type class derivation, using the implicit Injection[${T}, ?], or using the implicit UDT[${T}] in scope.
+// To disambiguate this resolution you need to either:
 //  - Remove the implicit Injection[${T}, ?] from scope
+//  - Remove the implicit UDT[${T}] from scope
 //  - import TypedEncoder.usingInjection
 //  - import TypedEncoder.usingDerivation
+//  - import TypedEncoder.usingUDT
 // """)
 object TypedEncoder {
   def apply[T: TypedEncoder]: TypedEncoder[T] = implicitly[TypedEncoder[T]]
@@ -327,4 +330,32 @@
     recordEncoder: Lazy[RecordEncoderFields[G]],
     classTag: ClassTag[F]
   ): TypedEncoder[F] = new RecordEncoder[F, G]
+
+  type UDT[A >: Null] = FramelessInternals.PublicUserDefinedType[A]
+
+  /**
+   * Encodes a value of type A using Spark SQL's User Defined Type (UDT) mechanism when an implicit UDT[A] is in scope.
+   *
+   * Example: to use Spark ML's VectorUDT implementation as a TypedEncoder, add the following to your scope:
+   * {{{
+   * import org.apache.spark.ml.linalg.{Vector, SQLDataTypes}
+   * import frameless.TypedEncoder.UDT
+   * implicit val mLVectorUDT: UDT[Vector] = SQLDataTypes.VectorType.asInstanceOf[UDT[Vector]]
+   * }}}
+   */
+  implicit def usingUDT[A >: Null : UDT : ClassTag]: TypedEncoder[A] = {
+    val udt = implicitly[UDT[A]]
+    val udtInstance = NewInstance(udt.getClass, Nil, dataType = ObjectType(udt.getClass))
+
+    new TypedEncoder[A] {
+      def nullable: Boolean = false
+      def sourceDataType: DataType = ObjectType(udt.userClass)
+      def targetDataType: DataType = udt
+
+      def extractorFor(path: Expression): Expression = Invoke(udtInstance, "serialize", udt, Seq(path))
+
+      def constructorFor(path: Expression): Expression =
+        Invoke(udtInstance, "deserialize", ObjectType(udt.userClass), Seq(path))
+    }
+  }
 }
diff --git a/dataset/src/main/scala/org/apache/spark/sql/FramelessInternals.scala b/dataset/src/main/scala/org/apache/spark/sql/FramelessInternals.scala
index d88dc81f..94660c1c 100644
--- a/dataset/src/main/scala/org/apache/spark/sql/FramelessInternals.scala
+++ b/dataset/src/main/scala/org/apache/spark/sql/FramelessInternals.scala
@@ -3,7 +3,7 @@ package org.apache.spark.sql
 import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.QueryExecution
-import org.apache.spark.sql.types.ObjectType
+import org.apache.spark.sql.types.{ObjectType, UserDefinedType}
 
 import scala.reflect.ClassTag
 
@@ -34,4 +34,6 @@
   def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan): DataFrame = {
     Dataset.ofRows(sparkSession, logicalPlan)
   }
+
+  type PublicUserDefinedType[A >: Null] = UserDefinedType[A]
 }
diff --git a/dataset/src/test/scala/frameless/CollectTests.scala b/dataset/src/test/scala/frameless/CollectTests.scala
index 8aae95bc..4430f845 100644
--- a/dataset/src/test/scala/frameless/CollectTests.scala
+++ b/dataset/src/test/scala/frameless/CollectTests.scala
@@ -4,7 +4,9 @@ import frameless.CollectTests.prop
 import org.apache.spark.sql.SQLContext
 import org.scalacheck.Prop
 import org.scalacheck.Prop._
+
 import scala.reflect.ClassTag
+import org.apache.spark.ml.linalg.{Matrix, Vector => MLVector}
 
 class CollectTests extends TypedDatasetSuite {
   test("collect()") {
@@ -48,6 +50,17 @@
 
     // TODO this doesn't work, and never worked...
     // check(forAll(prop[X1[Option[X1[Option[Int]]]]] _))
+
+    check(forAll(prop[MLVector] _))
+    check(forAll(prop[Option[MLVector]] _))
+    check(forAll(prop[X1[MLVector]] _))
+    check(forAll(prop[X2[Int, MLVector]] _))
+    check(forAll(prop[(Long, MLVector)] _))
+    check(forAll(prop[Matrix] _))
+    check(forAll(prop[Option[Matrix]] _))
+    check(forAll(prop[X1[Matrix]] _))
+    check(forAll(prop[X2[Int, Matrix]] _))
+    check(forAll(prop[(Long, Matrix)] _))
   }
 }
 
diff --git a/dataset/src/test/scala/frameless/SelectTests.scala b/dataset/src/test/scala/frameless/SelectTests.scala
index abcd0107..bc1a169f 100644
--- a/dataset/src/test/scala/frameless/SelectTests.scala
+++ b/dataset/src/test/scala/frameless/SelectTests.scala
@@ -4,7 +4,7 @@ import org.scalacheck.Prop
 import org.scalacheck.Prop._
 import shapeless.test.illTyped
 import frameless.implicits.widen._
-
+import org.apache.spark.ml.linalg.{Vector => MLVector}
 import scala.reflect.ClassTag
 
 class SelectTests extends TypedDatasetSuite {
@@ -27,6 +27,7 @@
     check(forAll(prop[Int, Int, Int, Int] _))
     check(forAll(prop[X2[Int, Int], Int, Int, Int] _))
     check(forAll(prop[String, Int, Int, Int] _))
+    check(forAll(prop[MLVector, Int, Int, Int] _))
   }
 
   test("select('a, 'b) FROM abcd") {
diff --git a/dataset/src/test/scala/frameless/package.scala b/dataset/src/test/scala/frameless/package.scala
index 9525eb05..88b495f6 100644
--- a/dataset/src/test/scala/frameless/package.scala
+++ b/dataset/src/test/scala/frameless/package.scala
@@ -1,4 +1,6 @@
-import org.scalacheck.{Gen, Arbitrary}
+import frameless.TypedEncoder.UDT
+import org.scalacheck.{Arbitrary, Gen}
+import org.apache.spark.ml.linalg.{Matrices, Matrix, SQLDataTypes, Vector => MLVector, Vectors => MLVectors}
 
 package object frameless {
   /** Fixed decimal point to avoid precision problems specific to Spark */
@@ -30,4 +32,24 @@ package object frameless {
   implicit def arbVector[A](implicit A: Arbitrary[A]): Arbitrary[Vector[A]] =
     Arbitrary(Gen.listOf(A.arbitrary).map(_.toVector))
 
+  implicit val arbMLVector: Arbitrary[MLVector] = Arbitrary {
+    val genDenseVector = Gen.listOf(arbDouble.arbitrary).map(doubles => MLVectors.dense(doubles.toArray))
+    val genSparseVector = genDenseVector.map(_.toSparse)
+
+    Gen.oneOf(genDenseVector, genSparseVector)
+  }
+
+  implicit val mLVectorUDT: UDT[MLVector] = SQLDataTypes.VectorType.asInstanceOf[UDT[MLVector]]
+
+  implicit val arbMatrix: Arbitrary[Matrix] = Arbitrary {
+    Gen.sized { nbRows =>
+      Gen.sized { nbCols =>
+        Gen.listOfN(nbRows * nbCols, arbDouble.arbitrary)
+          .map(values => Matrices.dense(nbRows, nbCols, values.toArray))
+      }
+    }
+  }
+
+  implicit val matrixUDT: UDT[Matrix] = SQLDataTypes.MatrixType.asInstanceOf[UDT[Matrix]]
+
 }
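For reviewers who want to exercise the new derivation end to end, here is a minimal sketch, not part of the diff: the UdtEncoderExample object and the Features case class are illustrative names only, and it assumes spark-mllib is on the compile classpath (this change wires it in at "test" scope only). It uses nothing beyond the TypedEncoder.apply summoner, the UDT alias added above, and Spark ML's SQLDataTypes.

import frameless.TypedEncoder
import frameless.TypedEncoder.UDT
import org.apache.spark.ml.linalg.{SQLDataTypes, Vector => MLVector}

object UdtEncoderExample {
  // Expose Spark ML's VectorUDT through the new UDT alias, as the
  // scaladoc on usingUDT suggests.
  implicit val mlVectorUDT: UDT[MLVector] =
    SQLDataTypes.VectorType.asInstanceOf[UDT[MLVector]]

  // Resolved via usingUDT, because an implicit UDT[MLVector] is in scope.
  val vectorEncoder: TypedEncoder[MLVector] = TypedEncoder[MLVector]

  // usingUDT composes with the existing record derivation, so a case class
  // containing an ML vector also gets an encoder for free (this is what the
  // prop[X2[Int, MLVector]] checks in CollectTests rely on).
  case class Features(id: Long, vector: MLVector)
  val featuresEncoder: TypedEncoder[Features] = TypedEncoder[Features]
}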