Skip to content

Commit

Permalink
Add TLS and simple key-based auth to Dashboard and stop/restart
Browse files Browse the repository at this point in the history
endpoints
  • Loading branch information
Yevgeny committed Mar 2, 2016
1 parent acd9e3f commit 7126623
Show file tree
Hide file tree
Showing 10 changed files with 185 additions and 35 deletions.
5 changes: 5 additions & 0 deletions common/build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,8 @@
// limitations under the License.

name := "common"
libraryDependencies ++= Seq(
"io.spray" %% "spray-can" % "1.3.2",
"io.spray" %% "spray-routing" % "1.3.2",
"org.spark-project.akka" %% "akka-actor" % "2.3.4-spark"
)
11 changes: 11 additions & 0 deletions common/src/main/resources/application.conf
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
akka {
log-config-on-start = false
loggers = ["akka.event.slf4j.Slf4jLogger"]
loglevel = "INFO"
}

spray.can {
server {
verbose-error-messages = "on"
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package io.prediction.authentication

/**
* Created by ykhodorkovsky on 3/1/16.
*/
package io.prediction.configuration

/**
* This is a (very) simple authentication for the dashboard and engine servers
* It is highly recommended to implement a stonger authentication mechanism
*/

import java.io.File

import com.typesafe.config.ConfigFactory

import scala.concurrent.ExecutionContext.Implicits.global
import spray.http.HttpRequest
import spray.routing.{AuthenticationFailedRejection, RequestContext}
import spray.routing.authentication._
import spray.routing.directives.AuthMagnet
import scala.concurrent.Future


trait KeyAuthentication {

object ServerKey {
val serverConfig = ConfigFactory.parseFile(new File("conf/server.conf"))

val key = serverConfig.getString("server.accessKey")
def get: String = key
def param: String = "accessKey"
}

def withAccessKeyFromFile: RequestContext => Future[Authentication[HttpRequest]] = {
ctx: RequestContext =>
val accessKeyParamOpt = ctx.request.uri.query.get(ServerKey.param)
Future {

val passedKey = accessKeyParamOpt.getOrElse {
Left(AuthenticationFailedRejection(
AuthenticationFailedRejection.CredentialsRejected, List()))
}

if (passedKey.equals(ServerKey.get)) Right(ctx.request)
else Left(AuthenticationFailedRejection(AuthenticationFailedRejection.CredentialsRejected, List()))

}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package io.prediction.configuration

/**
* Created by ykhodorkovsky on 2/26/16.
*/

import java.io.FileInputStream
import java.io.File
import java.security.KeyStore
import javax.net.ssl.{KeyManagerFactory, SSLContext, TrustManagerFactory}

import com.typesafe.config.ConfigFactory
import spray.io.ServerSSLEngineProvider

trait SSLConfiguration {

private val serverConfig = ConfigFactory.parseFile(new File("conf/server.conf"))

private val keyStoreResource = serverConfig.getString("server.ssl-keystore-resource")
private val password = serverConfig.getString("server.ssl-keystore-pass")
private val keyAlias = serverConfig.getString("server.ssl-key-alias")

private val keyStore = {

//Loading keystore from specified file
val clientStore = KeyStore.getInstance("JKS")
val inputStream = new FileInputStream(keyStoreResource)
clientStore.load(inputStream, password.toCharArray)
inputStream.close()
clientStore
}

//Creating SSL context
implicit def sslContext: SSLContext = {
val context = SSLContext.getInstance("TLS")
val tmf: TrustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm)
val kmf: KeyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm)
kmf.init(keyStore, password.toCharArray)
tmf.init(keyStore)
context.init(kmf.getKeyManagers, tmf.getTrustManagers, null)
context
}

//provide implicit SSLEngine with some protocols
implicit def sslEngineProvider: ServerSSLEngineProvider = {
ServerSSLEngineProvider { engine =>
engine.setEnabledCipherSuites(Array("TLS_RSA_WITH_AES_256_CBC_SHA", "TLS_ECDH_ECDSA_WITH_RC4_128_SHA", "TLS_RSA_WITH_AES_128_CBC_SHA"))
engine.setEnabledProtocols(Array("TLSv1", "TLSv1.2", "TLSv1.1"))
engine
}
}
}
Binary file added conf/keystore.jks
Binary file not shown.
9 changes: 9 additions & 0 deletions conf/server.conf
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Engine and dashboard Server related configurations
server {

accessKey = ""

ssl-keystore-resource = "conf/keystore.jks"
ssl-keystore-pass = "pioserver"
ssl-key-alias = "selfsigned"
}
4 changes: 0 additions & 4 deletions core/src/main/resources/application.conf

This file was deleted.

49 changes: 32 additions & 17 deletions core/src/main/scala/io/prediction/workflow/CreateServer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,11 @@ import com.twitter.bijection.Injection
import com.twitter.chill.KryoBase
import com.twitter.chill.KryoInjection
import com.twitter.chill.ScalaKryoInstantiator
import com.typesafe.config.ConfigFactory
import de.javakaffee.kryoserializers.SynchronizedCollectionsSerializer
import grizzled.slf4j.Logging
import io.prediction.authentication.io.prediction.configuration.KeyAuthentication
import io.prediction.configuration.SSLConfiguration
import io.prediction.controller.Engine
import io.prediction.controller.Params
import io.prediction.controller.Utils
Expand All @@ -47,10 +50,12 @@ import org.json4s._
import org.json4s.native.JsonMethods._
import org.json4s.native.Serialization.write
import spray.can.Http
import spray.can.server.ServerSettings
import spray.http.MediaTypes._
import spray.http._
import spray.httpx.Json4sSupport
import spray.routing._
import spray.routing.authentication.{UserPass, BasicAuth}

import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.Future
Expand All @@ -60,6 +65,7 @@ import scala.language.existentials
import scala.util.Failure
import scala.util.Random
import scala.util.Success
import scalaj.http.HttpOptions

class KryoInstantiator(classLoader: ClassLoader) extends ScalaKryoInstantiator {
override def newKryo(): KryoBase = {
Expand Down Expand Up @@ -103,6 +109,7 @@ case class StopServer()
case class ReloadServer()
case class UpgradeCheck()


object CreateServer extends Logging {
val actorSystem = ActorSystem("pio-server")
val engineInstances = Storage.getMetaDataEngineInstances
Expand Down Expand Up @@ -274,23 +281,23 @@ class UpgradeActor(engineClass: String) extends Actor {
}
}

class MasterActor(
class MasterActor (
sc: ServerConfig,
engineInstance: EngineInstance,
engineFactoryName: String,
manifest: EngineManifest) extends Actor {
manifest: EngineManifest) extends Actor with SSLConfiguration with KeyAuthentication {
val log = Logging(context.system, this)
implicit val system = context.system
var sprayHttpListener: Option[ActorRef] = None
var currentServerActor: Option[ActorRef] = None
var retry = 3

def undeploy(ip: String, port: Int): Unit = {
val serverUrl = s"http:https://${ip}:${port}"
val serverUrl = s"https:https://${ip}:${port}"
log.info(
s"Undeploying any existing engine instance at $serverUrl")
try {
val code = scalaj.http.Http(s"$serverUrl/stop").asString.code
val code = scalaj.http.Http(s"$serverUrl/stop").option(HttpOptions.allowUnsafeSSL).param(ServerKey.param, ServerKey.get).method("POST").asString.code
code match {
case 200 => Unit
case 404 => log.error(
Expand Down Expand Up @@ -321,7 +328,8 @@ class MasterActor(
self ! BindServer()
case x: BindServer =>
currentServerActor map { actor =>
IO(Http) ! Http.Bind(actor, interface = sc.ip, port = sc.port)
val settings = ServerSettings(system)
IO(Http) ! Http.Bind(actor, interface = sc.ip, port = sc.port, settings = Some(settings.copy(sslEncryption = true)))
} getOrElse {
log.error("Cannot bind a non-existing server backend.")
}
Expand All @@ -345,7 +353,8 @@ class MasterActor(
val actor = createServerActor(sc, lr, engineFactoryName, manifest)
sprayHttpListener.map { l =>
l ! Http.Unbind(5.seconds)
IO(Http) ! Http.Bind(actor, interface = sc.ip, port = sc.port)
val settings = ServerSettings(system)
IO(Http) ! Http.Bind(actor, interface = sc.ip, port = sc.port, settings = Some(settings.copy(sslEncryption = true)))
currentServerActor.get ! Kill
currentServerActor = Some(actor)
} getOrElse {
Expand All @@ -357,7 +366,7 @@ class MasterActor(
s"${manifest.version}. Abort reloading.")
}
case x: Http.Bound =>
val serverUrl = s"http:https://${sc.ip}:${sc.port}"
val serverUrl = s"https:https://${sc.ip}:${sc.port}"
log.info(s"Engine is deployed and running. Engine API is live at ${serverUrl}.")
sprayHttpListener = Some(sender)
case x: Http.CommandFailed =>
Expand Down Expand Up @@ -411,7 +420,7 @@ class ServerActor[Q, P](
val algorithmsParams: Seq[Params],
val models: Seq[Any],
val serving: BaseServing[Q, P],
val servingParams: Params) extends Actor with HttpService {
val servingParams: Params) extends Actor with HttpService with KeyAuthentication {
val serverStartTime = DateTime.now
val log = Logging(context.system, this)

Expand Down Expand Up @@ -458,6 +467,8 @@ class ServerActor[Q, P](
writer.toString
}

private def getUser(up :UserPass) : String = up.user

val myRoute =
path("") {
get {
Expand Down Expand Up @@ -641,20 +652,24 @@ class ServerActor[Q, P](
}
} ~
path("reload") {
get {
complete {
context.actorSelection("/user/master") ! ReloadServer()
"Reloading..."
authenticate(withAccessKeyFromFile) { request =>
post {
complete {
context.actorSelection("/user/master") ! ReloadServer()
"Reloading..."
}
}
}
} ~
path("stop") {
get {
complete {
context.system.scheduler.scheduleOnce(1.seconds) {
context.actorSelection("/user/master") ! StopServer()
authenticate(withAccessKeyFromFile) { request =>
post {
complete {
context.system.scheduler.scheduleOnce(1.seconds) {
context.actorSelection("/user/master") ! StopServer()
}
"Shutting down..."
}
"Shutting down..."
}
}
} ~
Expand Down
3 changes: 2 additions & 1 deletion data/src/main/scala/io/prediction/data/api/EventServer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import io.prediction.data.storage.EventJson4sSupport
import io.prediction.data.storage.BatchEventsJson4sSupport
import io.prediction.data.storage.LEvents
import io.prediction.data.storage.Storage
import io.prediction.configuration.SSLConfiguration
import org.json4s.DefaultFormats
import org.json4s.Formats
import org.json4s.JObject
Expand Down Expand Up @@ -556,7 +557,7 @@ class EventServerActor(
val eventClient: LEvents,
val accessKeysClient: AccessKeys,
val channelsClient: Channels,
val config: EventServerConfig) extends Actor {
val config: EventServerConfig) extends Actor with SSLConfiguration {
val log = Logging(context.system, this)
val child = context.actorOf(
Props(classOf[EventServiceActor],
Expand Down
37 changes: 24 additions & 13 deletions tools/src/main/scala/io/prediction/tools/dashboard/Dashboard.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,13 @@

package io.prediction.tools.dashboard

import com.typesafe.config.ConfigFactory
import io.prediction.authentication.io.prediction.configuration.KeyAuthentication
import io.prediction.configuration.SSLConfiguration
import io.prediction.data.storage.Storage

import spray.can.server.ServerSettings
import spray.routing.directives.AuthMagnet
import scala.concurrent.{Future, ExecutionContext}
import akka.actor.{ActorContext, Actor, ActorSystem, Props}
import akka.io.IO
import akka.pattern.ask
Expand All @@ -27,14 +32,15 @@ import spray.can.Http
import spray.http._
import spray.http.MediaTypes._
import spray.routing._
import spray.routing.authentication.{Authentication, UserPass, BasicAuth}

import scala.concurrent.duration._

case class DashboardConfig(
ip: String = "localhost",
port: Int = 9000)

object Dashboard extends Logging {
object Dashboard extends Logging with SSLConfiguration{
def main(args: Array[String]): Unit = {
val parser = new scopt.OptionParser[DashboardConfig]("Dashboard") {
opt[String]("ip") action { (x, c) =>
Expand All @@ -55,7 +61,8 @@ object Dashboard extends Logging {
val service =
system.actorOf(Props(classOf[DashboardActor], dc), "dashboard")
implicit val timeout = Timeout(5.seconds)
IO(Http) ? Http.Bind(service, interface = dc.ip, port = dc.port)
val settings = ServerSettings(system)
IO(Http) ? Http.Bind(service, interface = dc.ip, port = dc.port, settings = Some(settings.copy(sslEncryption = true)))
system.awaitTermination
}
}
Expand All @@ -67,22 +74,26 @@ class DashboardActor(
def receive: Actor.Receive = runRoute(dashboardRoute)
}

trait DashboardService extends HttpService with CORSSupport {
trait DashboardService extends HttpService with KeyAuthentication with CORSSupport {

implicit def executionContext: ExecutionContext = actorRefFactory.dispatcher
val dc: DashboardConfig
val evaluationInstances = Storage.getMetaDataEvaluationInstances
val pioEnvVars = sys.env.filter(kv => kv._1.startsWith("PIO_"))
val serverStartTime = DateTime.now
val dashboardRoute =
path("") {
get {
respondWithMediaType(`text/html`) {
complete {
val completedInstances = evaluationInstances.getCompleted
html.index(
dc,
serverStartTime,
pioEnvVars,
completedInstances).toString
authenticate(withAccessKeyFromFile) { request =>
get {
respondWithMediaType(`text/html`) {
complete {
val completedInstances = evaluationInstances.getCompleted
html.index(
dc,
serverStartTime,
pioEnvVars,
completedInstances).toString
}
}
}
}
Expand Down

0 comments on commit 7126623

Please sign in to comment.