From 43adc66c25254e646a9bd91c91c15df76045b5b4 Mon Sep 17 00:00:00 2001 From: detrevid Date: Mon, 23 Mar 2015 01:44:59 +0200 Subject: [PATCH] Add support for serializing synchronized collections --- core/build.sbt | 37 +++++++++--------- .../io/prediction/controller/Utils.scala | 13 +++---- .../io/prediction/workflow/CoreWorkflow.scala | 10 +++-- .../io/prediction/workflow/CreateServer.scala | 39 ++++++++++++------- 4 files changed, 54 insertions(+), 45 deletions(-) diff --git a/core/build.sbt b/core/build.sbt index 6450581d79..2b283754a5 100644 --- a/core/build.sbt +++ b/core/build.sbt @@ -15,25 +15,26 @@ name := "core" libraryDependencies ++= Seq( - "com.github.scopt" %% "scopt" % "3.2.0", - "com.google.code.gson" % "gson" % "2.2.4", - "com.google.guava" % "guava" % "18.0", - "com.twitter" %% "chill" % "0.5.0" + "com.github.scopt" %% "scopt" % "3.2.0", + "com.google.code.gson" % "gson" % "2.2.4", + "com.google.guava" % "guava" % "18.0", + "com.twitter" %% "chill" % "0.5.0" exclude("com.esotericsoftware.minlog", "minlog"), - "com.twitter" %% "chill-bijection" % "0.5.0", - "commons-io" % "commons-io" % "2.4", - "io.spray" %% "spray-can" % "1.3.2", - "io.spray" %% "spray-routing" % "1.3.2", - "net.jodah" % "typetools" % "0.3.1", - "org.apache.spark" %% "spark-core" % sparkVersion.value % "provided", - "org.clapper" %% "grizzled-slf4j" % "1.0.2", - "org.elasticsearch" % "elasticsearch" % elasticsearchVersion.value, - "org.json4s" %% "json4s-native" % json4sVersion.value, - "org.json4s" %% "json4s-ext" % json4sVersion.value, - "org.scalaj" %% "scalaj-http" % "1.1.0", - "org.scalatest" %% "scalatest" % "2.1.6" % "test", - "org.slf4j" % "slf4j-log4j12" % "1.7.7", - "org.specs2" %% "specs2" % "2.3.13" % "test") + "com.twitter" %% "chill-bijection" % "0.5.0", + "de.javakaffee" % "kryo-serializers" % "0.28", + "commons-io" % "commons-io" % "2.4", + "io.spray" %% "spray-can" % "1.3.2", + "io.spray" %% "spray-routing" % "1.3.2", + "net.jodah" % "typetools" % "0.3.1", + "org.apache.spark" %% "spark-core" % sparkVersion.value % "provided", + "org.clapper" %% "grizzled-slf4j" % "1.0.2", + "org.elasticsearch" % "elasticsearch" % elasticsearchVersion.value, + "org.json4s" %% "json4s-native" % json4sVersion.value, + "org.json4s" %% "json4s-ext" % json4sVersion.value, + "org.scalaj" %% "scalaj-http" % "1.1.0", + "org.scalatest" %% "scalatest" % "2.1.6" % "test", + "org.slf4j" % "slf4j-log4j12" % "1.7.7", + "org.specs2" %% "specs2" % "2.3.13" % "test") net.virtualvoid.sbt.graph.Plugin.graphSettings diff --git a/core/src/main/scala/io/prediction/controller/Utils.scala b/core/src/main/scala/io/prediction/controller/Utils.scala index d4f4229423..dda67ac364 100644 --- a/core/src/main/scala/io/prediction/controller/Utils.scala +++ b/core/src/main/scala/io/prediction/controller/Utils.scala @@ -17,7 +17,6 @@ package io.prediction.controller import io.prediction.workflow.KryoInstantiator -import com.twitter.chill.KryoInjection import org.json4s._ import org.json4s.ext.JodaTimeSerializers @@ -42,12 +41,12 @@ object Utils { * @param model Model object. */ def save(id: String, model: Any): Unit = { - val tmpdir = sys.env.get("PIO_FS_TMPDIR").getOrElse( - System.getProperty("java.io.tmpdir")) + val tmpdir = sys.env.getOrElse("PIO_FS_TMPDIR", System.getProperty("java.io.tmpdir")) val modelFile = tmpdir + File.separator + id (new File(tmpdir)).mkdirs val fos = new FileOutputStream(modelFile) - fos.write(KryoInjection(model)) + val kryo = KryoInstantiator.newKryoInjection + fos.write(kryo(model)) fos.close } @@ -59,12 +58,10 @@ object Utils { * @param id Used as the filename of the file. */ def load(id: String): Any = { - val tmpdir = sys.env.get("PIO_FS_TMPDIR").getOrElse( - System.getProperty("java.io.tmpdir")) + val tmpdir = sys.env.getOrElse("PIO_FS_TMPDIR", System.getProperty("java.io.tmpdir")) val modelFile = tmpdir + File.separator + id val src = Source.fromFile(modelFile)(scala.io.Codec.ISO8859) - val kryoInstantiator = new KryoInstantiator(getClass.getClassLoader) - val kryo = KryoInjection.instance(kryoInstantiator) + val kryo = KryoInstantiator.newKryoInjection val m = kryo.invert(src.map(_.toByte).toArray).get src.close m diff --git a/core/src/main/scala/io/prediction/workflow/CoreWorkflow.scala b/core/src/main/scala/io/prediction/workflow/CoreWorkflow.scala index 40edd98838..d5d9f25d24 100644 --- a/core/src/main/scala/io/prediction/workflow/CoreWorkflow.scala +++ b/core/src/main/scala/io/prediction/workflow/CoreWorkflow.scala @@ -15,9 +15,6 @@ package io.prediction.workflow -import com.github.nscala_time.time.Imports.DateTime -import com.twitter.chill.KryoInjection -import grizzled.slf4j.Logger import io.prediction.controller.EngineParams import io.prediction.controller.Evaluation import io.prediction.controller.WorkflowParams @@ -29,6 +26,9 @@ import io.prediction.data.storage.EvaluationInstance import io.prediction.data.storage.Model import io.prediction.data.storage.Storage +import com.github.nscala_time.time.Imports.DateTime +import grizzled.slf4j.Logger + import scala.language.existentials /** CoreWorkflow handles PredictionIO metadata and environment variables of @@ -67,10 +67,12 @@ object CoreWorkflow { val instanceId = Storage.getMetaDataEngineInstances + val kryo = KryoInstantiator.newKryoInjection + logger.info("Inserting persistent model") Storage.getModelDataModels.insert(Model( id = engineInstance.id, - models = KryoInjection(models))) + models = kryo(models))) logger.info("Updating engine instance") val engineInstances = Storage.getMetaDataEngineInstances diff --git a/core/src/main/scala/io/prediction/workflow/CreateServer.scala b/core/src/main/scala/io/prediction/workflow/CreateServer.scala index 7745f30c7c..7269752f5a 100644 --- a/core/src/main/scala/io/prediction/workflow/CreateServer.scala +++ b/core/src/main/scala/io/prediction/workflow/CreateServer.scala @@ -15,17 +15,6 @@ package io.prediction.workflow -import akka.actor._ -import akka.event.Logging -import akka.io.IO -import akka.pattern.ask -import akka.util.Timeout -import com.github.nscala_time.time.Imports.DateTime -import com.google.gson.Gson -import com.twitter.chill.KryoBase -import com.twitter.chill.KryoInjection -import com.twitter.chill.ScalaKryoInstantiator -import grizzled.slf4j.Logging import io.prediction.controller.Engine import io.prediction.controller.Params import io.prediction.controller.Utils @@ -40,6 +29,20 @@ import io.prediction.core.Doer import io.prediction.data.storage.EngineInstance import io.prediction.data.storage.EngineManifest import io.prediction.data.storage.Storage + +import akka.actor._ +import akka.event.Logging +import akka.io.IO +import akka.pattern.ask +import akka.util.Timeout +import com.github.nscala_time.time.Imports.DateTime +import com.google.gson.Gson +import com.twitter.bijection.Injection +import com.twitter.chill.KryoBase +import com.twitter.chill.KryoInjection +import com.twitter.chill.ScalaKryoInstantiator +import de.javakaffee.kryoserializers.SynchronizedCollectionsSerializer +import grizzled.slf4j.Logging import org.json4s._ import org.json4s.native.JsonMethods._ import org.json4s.native.Serialization.write @@ -57,17 +60,24 @@ import scala.util.Failure import scala.util.Random import scala.util.Success -import java.io.PrintWriter -import java.io.StringWriter +import java.io.{Serializable, PrintWriter, StringWriter} class KryoInstantiator(classLoader: ClassLoader) extends ScalaKryoInstantiator { override def newKryo(): KryoBase = { val kryo = super.newKryo() kryo.setClassLoader(classLoader) + SynchronizedCollectionsSerializer.registerSerializers(kryo) kryo } } +object KryoInstantiator extends Serializable { + def newKryoInjection : Injection[Any, Array[Byte]] = { + val kryoInstantiator = new KryoInstantiator(getClass.getClassLoader) + KryoInjection.instance(kryoInstantiator) + } +} + case class ServerConfig( batch: String = "", engineInstanceId: String = "", @@ -190,8 +200,7 @@ object CreateServer extends Logging { val engineParams = engine.engineInstanceToEngineParams(engineInstance) - val kryoInstantiator = new KryoInstantiator(getClass.getClassLoader) - val kryo = KryoInjection.instance(kryoInstantiator) + val kryo = KryoInstantiator.newKryoInjection val modelsFromEngineInstance = kryo.invert(modeldata.get(engineInstance.id).get.models).get.