Commit

Merge pull request apache#3 from detrevid/develop
Cleanup
k4hoo committed May 4, 2015
2 parents ac558ca + 415ad37 commit 44831f7
Showing 6 changed files with 9 additions and 13 deletions.
7 changes: 2 additions & 5 deletions src/main/scala/DataSource.scala
@@ -2,13 +2,10 @@ package org.template.classification

import io.prediction.controller.PDataSource
import io.prediction.controller.EmptyEvaluationInfo
- import io.prediction.controller.EmptyActualResult
import io.prediction.controller.Params
- import io.prediction.data.storage.Event
import io.prediction.data.store.PEventStore

import org.apache.spark.SparkContext
- import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.Vectors
@@ -60,7 +57,7 @@ class DataSource(val dsp: DataSourceParams)
override
def readEval(sc: SparkContext)
: Seq[(TrainingData, EmptyEvaluationInfo, RDD[(Query, ActualResult)])] = {
- require(!dsp.evalK.isEmpty, "DataSourceParams.evalK must not be None")
+ require(dsp.evalK.nonEmpty, "DataSourceParams.evalK must not be None")

// The following code reads the data from data store. It is equivalent to
// the readTraining method. We copy-and-paste the exact code here for
@@ -95,7 +92,7 @@ class DataSource(val dsp: DataSourceParams)

// K-fold splitting
val evalK = dsp.evalK.get
- val indexedPoints: RDD[(LabeledPoint, Long)] = labeledPoints.zipWithIndex
+ val indexedPoints: RDD[(LabeledPoint, Long)] = labeledPoints.zipWithIndex()

(0 until evalK).map { idx =>
val trainingPoints = indexedPoints.filter(_._2 % evalK != idx).map(_._1)
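The k-fold splitting above attaches a stable index to every point and routes it to a fold by index modulo evalK; the commit also adds the empty parens to zipWithIndex(), which Spark declares as a method with side effects since it can trigger a job to compute partition sizes. A minimal standalone sketch of the same splitting idea, assuming a local Spark master; the names `k` and `points` are illustrative:

```scala
// Minimal sketch of modulo-based k-fold splitting on an RDD.
// Assumes a local Spark master; `k` and `points` are illustrative names.
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.rdd.RDD

object KFoldSketch extends App {
  val sc = new SparkContext(new SparkConf().setAppName("kfold-sketch").setMaster("local[*]"))
  val points: RDD[Int] = sc.parallelize(1 to 9)
  val k = 3

  // Attach a stable index to every element, as the DataSource does.
  val indexed: RDD[(Int, Long)] = points.zipWithIndex()

  // Fold idx tests on elements whose index is congruent to idx mod k
  // and trains on the rest, so each element is tested exactly once.
  val folds = (0 until k).map { idx =>
    val training = indexed.filter(_._2 % k != idx).map(_._1)
    val testing  = indexed.filter(_._2 % k == idx).map(_._1)
    (training, testing)
  }

  folds.zipWithIndex.foreach { case ((tr, te), i) =>
    println(s"fold $i: train=${tr.collect().toList} test=${te.collect().toList}")
  }
  sc.stop()
}
```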
4 changes: 2 additions & 2 deletions src/main/scala/Engine.scala
@@ -1,6 +1,6 @@
package org.template.classification

- import io.prediction.controller.IEngineFactory
+ import io.prediction.controller.EngineFactory
import io.prediction.controller.Engine

class Query(
@@ -15,7 +15,7 @@ class ActualResult(
val label: Double
) extends Serializable

- object ClassificationEngine extends IEngineFactory {
+ object ClassificationEngine extends EngineFactory {
def apply() = {
new Engine(
classOf[DataSource],
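The apply() body is truncated above; it constructs the Engine that wires the pipeline together. For context, a hedged sketch of the usual wiring in PredictionIO classification templates — the Preparator and Serving classes and the "naive" algorithm key are assumptions, not confirmed by this diff:

```scala
// Hedged sketch of a typical PredictionIO engine factory; the component
// classes and the "naive" algorithm name are assumed, not shown in this diff.
import io.prediction.controller.{Engine, EngineFactory}

object ClassificationEngineSketch extends EngineFactory {
  def apply() = {
    new Engine(
      classOf[DataSource],                          // reads events into TrainingData
      classOf[Preparator],                          // turns TrainingData into PreparedData
      Map("naive" -> classOf[NaiveBayesAlgorithm]), // algorithm(s) keyed by name
      classOf[Serving])                             // combines predictions for a query
  }
}
```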
2 changes: 1 addition & 1 deletion src/main/scala/Evaluation.scala
@@ -6,7 +6,7 @@ import io.prediction.controller.EngineParams
import io.prediction.controller.EngineParamsGenerator
import io.prediction.controller.Evaluation

- case class Accuracy
+ case class Accuracy()
extends AverageMetric[EmptyEvaluationInfo, Query, PredictedResult, ActualResult] {
def calculate(query: Query, predicted: PredictedResult, actual: ActualResult)
: Double = (if (predicted.label == actual.label) 1.0 else 0.0)
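Accuracy extends AverageMetric, so calculate only scores a single (query, predicted, actual) triple with 1.0 or 0.0 and the framework averages those scores over the evaluation set. A plain-Scala sketch of that averaging, over made-up label pairs:

```scala
// Plain-Scala sketch of how an average metric reduces per-query 0/1 scores;
// the (predicted, actual) pairs are made up for illustration.
object AccuracySketch extends App {
  val results: Seq[(Double, Double)] =
    Seq((1.0, 1.0), (0.0, 1.0), (2.0, 2.0), (1.0, 1.0))

  // Score each case exactly as Accuracy.calculate does: 1.0 on a match, else 0.0.
  val scores = results.map { case (predicted, actual) =>
    if (predicted == actual) 1.0 else 0.0
  }

  val accuracy = scores.sum / scores.size  // 3 of 4 correct => 0.75
  println(s"accuracy = $accuracy")
}
```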
6 changes: 3 additions & 3 deletions src/main/scala/NaiveBayesAlgorithm.scala
@@ -22,10 +22,10 @@ class NaiveBayesAlgorithm(val ap: AlgorithmParams)

def train(sc: SparkContext, data: PreparedData): NaiveBayesModel = {
// MLLib NaiveBayes cannot handle empty training data.
- require(!data.labeledPoints.take(1).isEmpty,
-   s"RDD[labeldPoints] in PreparedData cannot be empty." +
+ require(data.labeledPoints.take(1).nonEmpty,
+   s"RDD[labeledPoints] in PreparedData cannot be empty." +
    " Please check if DataSource generates TrainingData" +
-   " and Preprator generates PreparedData correctly.")
+   " and Preparator generates PreparedData correctly.")

NaiveBayes.train(data.labeledPoints, ap.lambda)
}
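The guard checks take(1).nonEmpty rather than count() > 0: take(1) stops as soon as one element is found instead of scanning every partition, so the precondition stays cheap on large data. A small sketch of the pattern; requireNonEmpty is a hypothetical helper and the local SparkContext is an assumption:

```scala
// Sketch of the cheap RDD emptiness guard used above; requireNonEmpty is a
// hypothetical helper, and the local SparkContext is an assumption.
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.rdd.RDD

object EmptinessGuardSketch extends App {
  val sc = new SparkContext(new SparkConf().setAppName("guard").setMaster("local[*]"))

  // take(1) materializes at most one element, unlike count(), which scans everything.
  def requireNonEmpty[T](rdd: RDD[T], what: String): Unit =
    require(rdd.take(1).nonEmpty, s"$what cannot be empty.")

  requireNonEmpty(sc.parallelize(Seq(1, 2, 3)), "sample RDD")  // passes
  println("non-empty check passed")
  sc.stop()
}
```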
2 changes: 1 addition & 1 deletion src/main/scala/PrecisionEvaluation.scala
@@ -17,7 +17,7 @@ case class Precision(label: Double)
Some(0.0) // False positive
}
} else {
- None // Unrelated case for calcuating precision
+ None // Unrelated case for calculating precision
}
}
}
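Unlike Accuracy, Precision(label) returns an Option[Double]: Some(1.0) for a true positive, Some(0.0) for a false positive (as the comment above shows), and None when the prediction targets a different label, so only predicted-positive cases enter the average. A plain-Scala sketch with hypothetical pairs:

```scala
// Plain-Scala sketch of option-valued precision; the pairs are made up.
object PrecisionSketch extends App {
  val label = 1.0
  val results: Seq[(Double, Double)] =  // (predicted, actual)
    Seq((1.0, 1.0), (1.0, 0.0), (0.0, 1.0), (2.0, 2.0))

  val scores: Seq[Option[Double]] = results.map { case (predicted, actual) =>
    if (predicted == label) {
      Some(if (actual == label) 1.0 else 0.0)  // true vs. false positive
    } else {
      None  // prediction concerns another label; excluded from the average
    }
  }

  val kept = scores.flatten                // None cases drop out here
  val precision = kept.sum / kept.size     // one TP, one FP => 0.5
  println(s"precision($label) = $precision")
}
```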
1 change: 0 additions & 1 deletion src/main/scala/Preparator.scala
@@ -3,7 +3,6 @@ package org.template.classification
import io.prediction.controller.PPreparator

import org.apache.spark.SparkContext
- import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.regression.LabeledPoint

