Parameterize Spark actions over the effect type used (typelevel#172)
iravid authored and imarios committed Sep 17, 2017
1 parent 50c491e commit 68aa838
Showing 21 changed files with 305 additions and 52 deletions.
20 changes: 17 additions & 3 deletions build.sbt
@@ -1,5 +1,7 @@
val sparkVersion = "2.2.0"
val catsv = "0.9.0"
val catsCoreVersion = "1.0.0-MF"
val catsEffectVersion = "0.4"
val catsMtlVersion = "0.0.2"
val scalatest = "3.0.3"
val shapeless = "2.3.2"
val scalacheck = "1.13.5"
@@ -20,9 +22,17 @@ lazy val cats = project
.settings(framelessSettings: _*)
.settings(warnUnusedImport: _*)
.settings(publishSettings: _*)
.settings(
addCompilerPlugin("org.spire-math" %% "kind-projector" % "0.9.4"),
scalacOptions += "-Ypartial-unification"
)
.settings(libraryDependencies ++= Seq(
"org.typelevel" %% "cats" % catsv,
"org.apache.spark" %% "spark-core" % sparkVersion % "provided"))
"org.typelevel" %% "cats-core" % catsCoreVersion,
"org.typelevel" %% "cats-effect" % catsEffectVersion,
"org.typelevel" %% "cats-mtl-core" % catsMtlVersion,
"org.apache.spark" %% "spark-core" % sparkVersion % "provided",
"org.apache.spark" %% "spark-sql" % sparkVersion % "provided"))
.dependsOn(dataset % "test->test;compile->compile")

lazy val dataset = project
.settings(name := "frameless-dataset")
@@ -62,6 +72,10 @@ lazy val docs = project
"org.apache.spark" %% "spark-sql" % sparkVersion,
"org.apache.spark" %% "spark-mllib" % sparkVersion
))
.settings(
addCompilerPlugin("org.spire-math" %% "kind-projector" % "0.9.4"),
scalacOptions += "-Ypartial-unification"
)
.dependsOn(dataset, cats, ml)

lazy val framelessSettings = Seq(
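The two new compiler settings support the effect-polymorphic API: kind-projector adds the `?` syntax for partially applied type constructors, and -Ypartial-unification lets scalac unify higher-kinded types such as ReaderT[IO, SparkSession, ?] in the tests further down. A rough illustration of the shape these flags make convenient (the object and alias names are made up):

import cats.data.ReaderT
import cats.effect.IO
import org.apache.spark.sql.SparkSession

object EffectTypes {
  // With kind-projector this can be written inline as ReaderT[IO, SparkSession, ?];
  // without it, an explicit alias like this one is needed at each use site.
  type SparkReader[A] = ReaderT[IO, SparkSession, A]
}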
24 changes: 24 additions & 0 deletions cats/src/main/scala/frameless/cats/FramelessSyntax.scala
@@ -0,0 +1,24 @@
package frameless
package cats

import _root_.cats.effect.Sync
import _root_.cats.implicits._
import _root_.cats.mtl.ApplicativeAsk
import org.apache.spark.sql.SparkSession

trait FramelessSyntax extends frameless.FramelessSyntax {
implicit class SparkJobOps[F[_], A](fa: F[A])(implicit S: Sync[F], A: ApplicativeAsk[F, SparkSession]) {
import S._, A._

def withLocalProperty(key: String, value: String): F[A] =
for {
session <- ask
_ <- delay(session.sparkContext.setLocalProperty(key, value))
a <- fa
} yield a

def withGroupId(groupId: String): F[A] = withLocalProperty("spark.jobGroup.id", groupId)

def withDescription(description: String) = withLocalProperty("spark.job.description", description)
}
}
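A rough usage sketch (the object name, helper name, and label are illustrative, not part of the commit): any effect F with both Sync and ApplicativeAsk[F, SparkSession] instances picks up these operations, so a Spark action can be labelled before it runs.

import cats.effect.Sync
import cats.mtl.ApplicativeAsk
import frameless.TypedDataset
import frameless.cats.implicits._
import org.apache.spark.sql.SparkSession

object SyntaxSketch {
  // Count the rows of a dataset, tagging the Spark job with a description first.
  // The SparkDelay[F] required by count is derived from Sync[F] (see SparkDelayInstances below).
  def describedCount[F[_], T](ds: TypedDataset[T])(
    implicit S: Sync[F], A: ApplicativeAsk[F, SparkSession]
  ): F[Long] =
    ds.count[F]().withDescription("count rows for the dashboard")
}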
11 changes: 11 additions & 0 deletions cats/src/main/scala/frameless/cats/SparkDelayInstances.scala
@@ -0,0 +1,11 @@
package frameless
package cats

import _root_.cats.effect.Sync
import org.apache.spark.sql.SparkSession

trait SparkDelayInstances {
implicit def framelessCatsSparkDelayForSync[F[_]](implicit S: Sync[F]): SparkDelay[F] = new SparkDelay[F] {
def delay[A](a: => A)(implicit spark: SparkSession): F[A] = S.delay(a)
}
}
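With this instance in scope, any effect that already has a cats.effect.Sync instance (IO being the obvious example) can serve as the target of frameless actions. A minimal sketch, with an illustrative helper name:

import cats.effect.IO
import frameless.TypedDataset
import frameless.cats.implicits._

object SyncDelaySketch {
  // firstOption needs a SparkDelay[IO]; it is derived here from Sync[IO].
  def firstOrZero(ds: TypedDataset[Long]): IO[Long] =
    ds.firstOption[IO]().map(_.getOrElse(0L))
}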
@@ -7,7 +7,7 @@ import _root_.cats._
import scala.reflect.ClassTag
import org.apache.spark.rdd.RDD

object implicits {
object implicits extends FramelessSyntax with SparkDelayInstances {
implicit class rddOps[A: ClassTag](lhs: RDD[A]) {
def csum(implicit m: Monoid[A]): A = lhs.reduce(_ |+| _)
def cmin(implicit o: Order[A]): A = lhs.reduce(_ min _)
File renamed without changes.
63 changes: 63 additions & 0 deletions cats/src/test/scala/frameless/cats/FramelessSyntaxTests.scala
@@ -0,0 +1,63 @@
package frameless
package cats

import _root_.cats.data.ReaderT
import _root_.cats.effect.{ IO, Sync }
import frameless.{ TypedDataset, TypedDatasetSuite, TypedEncoder, X2 }
import org.apache.spark.sql.SparkSession
import org.scalacheck.Prop, Prop._

class FramelessSyntaxTests extends TypedDatasetSuite {
override val sparkDelay = null

def prop[A, B](data: Vector[X2[A, B]])(
implicit ev: TypedEncoder[X2[A, B]]
): Prop = {
import implicits._

val dataset = TypedDataset.create(data).dataset
val dataframe = dataset.toDF()

val typedDataset = dataset.typed
val typedDatasetFromDataFrame = dataframe.unsafeTyped[X2[A, B]]

typedDataset.collect[IO]().unsafeRunSync().toVector ?= typedDatasetFromDataFrame.collect[IO]().unsafeRunSync().toVector
}

test("dataset typed - toTyped") {
check(forAll(prop[Int, String] _))
}

test("properties can be read back") {
import implicits._
import _root_.cats.implicits._
import _root_.cats.mtl.implicits._

// We need this instance here because there is no cats.effect.Sync instance for ReaderT.
// Hopefully the instance will be back before cats 1.0.0 and we'll be able to get rid of this.
implicit val sync: Sync[ReaderT[IO, SparkSession, ?]] = new Sync[ReaderT[IO, SparkSession, ?]] {
def suspend[A](thunk: => ReaderT[IO, SparkSession, A]) = thunk
def pure[A](x: A): ReaderT[IO, SparkSession, A] = ReaderT.pure(x)
def handleErrorWith[A](fa: ReaderT[IO, SparkSession, A])(f: Throwable => ReaderT[IO, SparkSession, A]): ReaderT[IO, SparkSession, A] =
ReaderT(r => fa.run(r).handleErrorWith(e => f(e).run(r)))
def raiseError[A](e: Throwable): ReaderT[IO, SparkSession, A] = ReaderT.lift(IO.raiseError(e))
def flatMap[A, B](fa: ReaderT[IO, SparkSession, A])(f: A => ReaderT[IO, SparkSession, B]): ReaderT[IO, SparkSession, B] = fa.flatMap(f)
def tailRecM[A, B](a: A)(f: A => ReaderT[IO, SparkSession, Either[A, B]]): ReaderT[IO, SparkSession, B] =
ReaderT.catsDataMonadForKleisli[IO, SparkSession].tailRecM(a)(f)
}

check {
forAll { (k:String, v: String) =>
val scopedKey = "frameless.tests." + k
1.pure[ReaderT[IO, SparkSession, ?]].withLocalProperty(scopedKey,v).run(session).unsafeRunSync()
sc.getLocalProperty(scopedKey) ?= v

1.pure[ReaderT[IO, SparkSession, ?]].withGroupId(v).run(session).unsafeRunSync()
sc.getLocalProperty("spark.jobGroup.id") ?= v

1.pure[ReaderT[IO, SparkSession, ?]].withDescription(v).run(session).unsafeRunSync()
sc.getLocalProperty("spark.job.description") ?= v
}
}
}
}
@@ -1,7 +1,9 @@
package frameless
package cats
package bec

import cats.implicits._
import _root_.cats.implicits._
import org.apache.spark.SparkContext
import org.apache.spark.sql.SparkSession
import org.scalatest.Matchers
import org.scalacheck.Arbitrary
import Arbitrary._
@@ -15,18 +17,20 @@ import org.scalactic.anyvals.PosInt
import scala.reflect.ClassTag

trait SparkTests {
val appID: String = new java.util.Date().toString + math.floor(math.random * 10E4).toLong.toString

implicit lazy val sc: SC =
new SC(conf)
val conf: SparkConf = new SparkConf()
.setMaster("local[*]")
.setAppName("test")
.set("spark.ui.enabled", "false")
.set("spark.app.id", appID)

implicit def session: SparkSession = SparkSession.builder().config(conf).getOrCreate()
implicit def sc: SparkContext = session.sparkContext

implicit class seqToRdd[A: ClassTag](seq: Seq[A])(implicit sc: SC) {
def toRdd: RDD[A] = sc.makeRDD(seq)
}

lazy val conf: SparkConf =
new SparkConf()
.setMaster("local[4]")
.setAppName("cats.bec test")
}

object Tests {
4 changes: 4 additions & 0 deletions dataset/src/main/scala/frameless/Job.scala
@@ -37,4 +37,8 @@ object Job {
def apply[A](a: => A)(implicit spark: SparkSession): Job[A] = new Job[A] {
def run(): A = a
}

implicit val framelessSparkDelayForJob: SparkDelay[Job] = new SparkDelay[Job] {
def delay[A](a: => A)(implicit spark: SparkSession): Job[A] = Job(a)
}
}
7 changes: 7 additions & 0 deletions dataset/src/main/scala/frameless/SparkDelay.scala
@@ -0,0 +1,7 @@
package frameless

import org.apache.spark.sql.SparkSession

trait SparkDelay[F[_]] {
def delay[A](a: => A)(implicit spark: SparkSession): F[A]
}
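The contract is deliberately small: an instance only has to say how to suspend a by-name Spark action in F. As a purely illustrative sketch (not part of this commit), a hand-rolled instance for an eager Either-based effect could look like this:

import frameless.SparkDelay
import org.apache.spark.sql.SparkSession
import scala.util.control.NonFatal

object EagerEitherDelay {
  type Result[A] = Either[Throwable, A]

  // Runs the action immediately and captures any failure in the Left channel.
  implicit val sparkDelayForResult: SparkDelay[Result] = new SparkDelay[Result] {
    def delay[A](a: => A)(implicit spark: SparkSession): Result[A] =
      try Right(a) catch { case NonFatal(e) => Left(e) }
  }
}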
44 changes: 22 additions & 22 deletions dataset/src/main/scala/frameless/TypedDataset.scala
@@ -136,10 +136,10 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val

/** Returns the number of elements in the [[TypedDataset]].
*
* Differs from `Dataset#count` by wrapping its result into a [[Job]].
* Differs from `Dataset#count` by wrapping its result into an effect-suspending `F[_]`.
*/
def count(): Job[Long] =
Job(dataset.count)
def count[F[_]]()(implicit F: SparkDelay[F]): F[Long] =
F.delay(dataset.count)

/** Returns `TypedColumn` of type `A` given its name.
*
@@ -188,20 +188,20 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val

/** Returns a `Seq` that contains all the elements in this [[TypedDataset]].
*
* Running this [[Job]] requires moving all the data into the application's driver process, and
* Running this operation requires moving all the data into the application's driver process, and
* doing so on a very large [[TypedDataset]] can crash the driver process with OutOfMemoryError.
*
* Differs from `Dataset#collect` by wrapping its result into a [[Job]].
* Differs from `Dataset#collect` by wrapping its result into an effect-suspending `F[_]`.
*/
def collect(): Job[Seq[T]] =
Job(dataset.collect())
def collect[F[_]]()(implicit F: SparkDelay[F]): F[Seq[T]] =
F.delay(dataset.collect())

/** Optionally returns the first element in this [[TypedDataset]].
*
* Differs from `Dataset#first` by wrapping its result into an `Option` and a [[Job]].
* Differs from `Dataset#first` by wrapping its result into an `Option` and an effect-suspending `F[_]`.
*/
def firstOption(): Job[Option[T]] =
Job {
def firstOption[F[_]]()(implicit F: SparkDelay[F]): F[Option[T]] =
F.delay {
try {
Option(dataset.first())
} catch {
@@ -214,12 +214,12 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val
* Running take requires moving data into the application's driver process, and doing so with
* a very large `num` can crash the driver process with OutOfMemoryError.
*
* Differs from `Dataset#take` by wrapping its result into a [[Job]].
* Differs from `Dataset#take` by wrapping its result into an effect-suspending `F[_]`.
*
* apache/spark
*/
def take(num: Int): Job[Seq[T]] =
Job(dataset.take(num))
def take[F[_]](num: Int)(implicit F: SparkDelay[F]): F[Seq[T]] =
F.delay(dataset.take(num))

/** Displays the content of this [[TypedDataset]] in a tabular form. Strings more than 20 characters
* will be truncated, and all cells will be aligned right. For example:
@@ -235,12 +235,12 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val
* @param truncate Whether to truncate long strings. If true, strings more than 20 characters will
* be truncated and all cells will be aligned right
*
* Differs from `Dataset#show` by wrapping its result into a [[Job]].
* Differs from `Dataset#show` by wrapping its result into an effect-suspending `F[_]`.
*
* apache/spark
*/
def show(numRows: Int = 20, truncate: Boolean = true): Job[Unit] =
Job(dataset.show(numRows, truncate))
def show[F[_]](numRows: Int = 20, truncate: Boolean = true)(implicit F: SparkDelay[F]): F[Unit] =
F.delay(dataset.show(numRows, truncate))

/** Returns a new [[frameless.TypedDataset]] that only contains elements where `column` is `true`.
*
@@ -258,17 +258,17 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val

/** Runs `func` on each element of this [[TypedDataset]].
*
* Differs from `Dataset#foreach` by wrapping its result into a [[Job]].
* Differs from `Dataset#foreach` by wrapping its result into an effect-suspending `F[_]`.
*/
def foreach(func: T => Unit): Job[Unit] =
Job(dataset.foreach(func))
def foreach[F[_]](func: T => Unit)(implicit F: SparkDelay[F]): F[Unit] =
F.delay(dataset.foreach(func))

/** Runs `func` on each partition of this [[TypedDataset]].
*
* Differs from `Dataset#foreachPartition` by wrapping its result into a [[Job]].
* Differs from `Dataset#foreachPartition` by wrapping its result into an effect-suspending `F[_]`.
*/
def foreachPartition(func: Iterator[T] => Unit): Job[Unit] =
Job(dataset.foreachPartition(func))
def foreachPartition[F[_]](func: Iterator[T] => Unit)(implicit F: SparkDelay[F]): F[Unit] =
F.delay(dataset.foreachPartition(func))

def groupBy[K1](
c1: TypedColumn[T, K1]
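Taken together, these signature changes keep existing Job-based call sites working through the implicit instance added to Job above, while callers on cats-effect can ask for IO instead. A sketch with illustrative names:

import cats.effect.IO
import frameless.{ Job, TypedDataset }
import frameless.cats.implicits._

object EffectChoiceSketch {
  def sketch(ds: TypedDataset[(Int, String)]): Seq[(Int, String)] = {
    val viaJob: Job[Seq[(Int, String)]] = ds.collect[Job]() // lazy Job, executed with .run(), as before
    val viaIO: IO[Seq[(Int, String)]]   = ds.collect[IO]()  // suspended until .unsafeRunSync()
    viaJob.run()
  }
}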
8 changes: 4 additions & 4 deletions dataset/src/main/scala/frameless/TypedDatasetForwarded.scala
@@ -183,7 +183,7 @@ trait TypedDatasetForwarded[T] { self: TypedDataset[T] =>
deserialized.filter(func)

@deprecated("deserialized methods have moved to a separate section to highlight their runtime overhead", "0.4")
def reduceOption(func: (T, T) => T): Job[Option[T]] =
def reduceOption[F[_]: SparkDelay](func: (T, T) => T): F[Option[T]] =
deserialized.reduceOption(func)
// $COVERAGE-ON$

@@ -243,10 +243,10 @@ trait TypedDatasetForwarded[T] { self: TypedDataset[T] =>
/** Optionally reduces the elements of this [[TypedDataset]] using the specified binary function. The given
* `func` must be commutative and associative or the result may be non-deterministic.
*
* Differs from `Dataset#reduce` by wrapping its result into an `Option` and a [[Job]].
* Differs from `Dataset#reduce` by wrapping its result into an `Option` and an effect-suspending `F`.
*/
def reduceOption(func: (T, T) => T): Job[Option[T]] =
Job {
def reduceOption[F[_]](func: (T, T) => T)(implicit F: SparkDelay[F]): F[Option[T]] =
F.delay {
try {
Option(self.dataset.reduce(func))
} catch {
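A quick sketch of the new shape (illustrative names; Job preserves the previous behaviour):

import frameless.{ Job, TypedDataset }

object ReduceSketch {
  // reduceOption now returns whichever effect the caller asks for.
  def maxValue(ds: TypedDataset[Int]): Job[Option[Int]] =
    ds.deserialized.reduceOption[Job](_ max _)
}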
4 changes: 3 additions & 1 deletion dataset/src/main/scala/frameless/syntax/package.scala
@@ -1,3 +1,5 @@
package frameless

package object syntax extends FramelessSyntax
package object syntax extends FramelessSyntax {
implicit val DefaultSparkDelay: SparkDelay[Job] = Job.framelessSparkDelayForJob
}
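This default instance is what keeps effect-agnostic code source-compatible: with frameless.syntax._ imported, the compiler falls back to Job when no effect type is named, which is exactly what the CollectTests change below relies on. A tiny sketch (names illustrative):

import frameless.TypedDataset
import frameless.syntax._ // brings the default SparkDelay[Job] into scope

object DefaultEffectSketch {
  // No effect type is written: F is inferred as Job, so .run() is still available.
  def firstTwo(ds: TypedDataset[String]): Seq[String] =
    ds.take(2).run()
}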
2 changes: 2 additions & 0 deletions dataset/src/test/scala/frameless/CollectTests.scala
@@ -58,6 +58,8 @@ class CollectTests extends TypedDatasetSuite {
}

object CollectTests {
import frameless.syntax._

def prop[A: TypedEncoder : ClassTag](data: Vector[A])(implicit c: SQLContext): Prop =
TypedDataset.create(data).collect().run().toVector ?= data
}
1 change: 1 addition & 0 deletions dataset/src/test/scala/frameless/TypedDatasetSuite.scala
@@ -40,4 +40,5 @@ class TypedDatasetSuite extends FunSuite with Checkers with BeforeAndAfterAll wi
// Limit size of generated collections and number of checks because Travis
implicit override val generatorDrivenConfig =
PropertyCheckConfiguration(sizeRange = PosZInt(10), minSize = PosZInt(10))
implicit val sparkDelay: SparkDelay[Job] = Job.framelessSparkDelayForJob
}
@@ -5,14 +5,19 @@ import org.scalacheck.Prop
import org.scalacheck.Prop._

class FramelessSyntaxTests extends TypedDatasetSuite {
// Hide the implicit SparkDelay[Job] on TypedDatasetSuite to avoid ambiguous implicits
override val sparkDelay = null

def prop[A, B](data: Vector[X2[A, B]])(
implicit ev: TypedEncoder[X2[A, B]]
): Prop = {
val dataset = TypedDataset.create(data)
val dataset = TypedDataset.create(data).dataset
val dataframe = dataset.toDF()

dataset.collect().run().toVector ?= dataframe.unsafeTyped[X2[A, B]].collect().run().toVector
val typedDataset = dataset.typed
val typedDatasetFromDataFrame = dataframe.unsafeTyped[X2[A, B]]

typedDataset.collect().run().toVector ?= typedDatasetFromDataFrame.collect().run().toVector
}

test("dataset typed - toTyped") {