Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
gerashegalov committed Jul 30, 2019
1 parent 23b6e91 commit 6694415
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,25 +30,23 @@

package com.salesforce.op.stages.impl.insights

import com.salesforce.op.{FeatureInsights, UID}
import com.salesforce.op.UID
import com.salesforce.op.features.types._
import com.salesforce.op.stages.base.unary.UnaryTransformer
import com.salesforce.op.stages.impl.feature.{DateToUnitCircle, TimePeriod}
import com.salesforce.op.stages.impl.feature.TimePeriod
import com.salesforce.op.stages.impl.selector.SelectedModel
import com.salesforce.op.stages.sparkwrappers.specific.OpPredictorWrapperModel
import com.salesforce.op.stages.sparkwrappers.specific.SparkModelConverter._
import com.salesforce.op.utils.spark.{OpVectorColumnHistory, OpVectorMetadata}
import enumeratum.{Enum, EnumEntry}
import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.Model
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.linalg.{DenseVector, Vectors}
import org.apache.spark.ml.param.{IntParam, Param, Params}

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

import scala.reflect.runtime.universe._


trait RecordInsightsLOCOParams extends Params {

Expand Down Expand Up @@ -132,17 +130,15 @@ class RecordInsightsLOCO[T <: Model[T]]

private def computeDiffs
(
i: Int,
oldInd: Int,
oldVal: Double,
featureArray: Array[(Int, Double)],
featureSize: Int,
featureVec: DenseVector,
baseScore: Array[Double]
): Array[Double] = {
featureArray.update(i, (oldInd, 0.0))
val score = modelApply(labelDummy, Vectors.sparse(featureSize, featureArray).toOPVector).score
featureVec.values.update(oldInd, 0.0)
val score = modelApply(labelDummy, featureVec.toOPVector).score
val diffs = baseScore.zip(score).map { case (b, s) => b - s }
featureArray.update(i, (oldInd, oldVal))
featureVec.values.update(oldInd, oldVal)
diffs
}

Expand All @@ -165,38 +161,38 @@ class RecordInsightsLOCO[T <: Model[T]]

private def returnTopPosNeg
(
featureArray: Array[(Int, Double)],
featureSize: Int,
featureVec: DenseVector,
baseScore: Array[Double],
k: Int,
indexToExamine: Int
): Seq[LOCOValue] = {

val minMaxHeap = new MinMaxHeap(k)
val aggregationMap = mutable.Map.empty[String, (Array[Int], Array[Double])]
for {i <- featureArray.indices} {
val (oldInd, oldVal) = featureArray(i)
val diffToExamine = computeDiffs(i, oldInd, oldVal, featureArray, featureSize, baseScore)
val history = histories(oldInd)

history match {
// If indicator value and descriptor value of a derived text feature are empty, then it is likely
// to be a hashing tf output. We aggregate such features for each (rawFeatureName).
case h if h.indicatorValue.isEmpty && h.descriptorValue.isEmpty && textFeatureIndices.contains(oldInd) =>
for {name <- getRawFeatureName(h)} {
val (indices, array) = aggregationMap.getOrElse(name, (Array.empty[Int], Array.empty[Double]))
aggregationMap.update(name, (indices :+ i, sumArrays(array, diffToExamine)))
}
// If the descriptor value of a derived date feature exists, then it is likely to be
// from unit circle transformer. We aggregate such features for each (rawFeatureName, timePeriod).
case h if h.descriptorValue.isDefined && dateFeatureIndices.contains(oldInd) =>
for {name <- getRawFeatureName(h)} {
val key = name + h.descriptorValue.flatMap(convertToTimePeriod).map(p => "_" + p.entryName).getOrElse("")
val (indices, array) = aggregationMap.getOrElse(key, (Array.empty[Int], Array.empty[Double]))
aggregationMap.update(key, (indices :+ i, sumArrays(array, diffToExamine)))
}
case _ => minMaxHeap enqueue LOCOValue(i, diffToExamine(indexToExamine), diffToExamine)
}
featureVec.foreachActive {
case (oldInd, oldVal) if oldVal != 0.0 =>
val diffToExamine = computeDiffs(oldInd, oldVal, featureVec, baseScore)
val history = histories(oldInd)

history match {
// If indicator value and descriptor value of a derived text feature are empty, then it is likely
// to be a hashing tf output. We aggregate such features for each (rawFeatureName).
case h if h.indicatorValue.isEmpty && h.descriptorValue.isEmpty && textFeatureIndices.contains(oldInd) =>
for {name <- getRawFeatureName(h)} {
val (indices, array) = aggregationMap.getOrElse(name, (Array.empty[Int], Array.empty[Double]))
aggregationMap.update(name, (indices :+ oldInd, sumArrays(array, diffToExamine)))
}
// If the descriptor value of a derived date feature exists, then it is likely to be
// from unit circle transformer. We aggregate such features for each (rawFeatureName, timePeriod).
case h if h.descriptorValue.isDefined && dateFeatureIndices.contains(oldInd) =>
for {name <- getRawFeatureName(h)} {
val key = name + h.descriptorValue.flatMap(convertToTimePeriod).map(p => "_" + p.entryName).getOrElse("")
val (indices, array) = aggregationMap.getOrElse(key, (Array.empty[Int], Array.empty[Double]))
aggregationMap.update(key, (indices :+ oldInd, sumArrays(array, diffToExamine)))
}
case _ => minMaxHeap enqueue LOCOValue(oldInd, diffToExamine(indexToExamine), diffToExamine)
}
case _ => ()
}

// Adding LOCO results from aggregation map into heaps
Expand All @@ -222,6 +218,7 @@ class RecordInsightsLOCO[T <: Model[T]]
(textFeatureIndices ++ dateFeatureIndices).foreach(i => if (!featuresSparse.indices.contains(i)) res += i -> 0.0)
val featureArray = res.toArray
val featureSize = featuresSparse.size
val featureDense = Vectors.sparse(featureSize, featureArray).toDense

val k = $(topK)
// Index where to examine the difference in the prediction vector
Expand All @@ -232,7 +229,7 @@ class RecordInsightsLOCO[T <: Model[T]]
// For MultiClassification, the value is from the predicted class(i.e. the class having the highest probability)
case n if n > 2 => baseResult.prediction.toInt
}
val topPosNeg = returnTopPosNeg(featureArray, featureSize, baseScore, k, indexToExamine)
val topPosNeg = returnTopPosNeg(featureDense, baseScore, k, indexToExamine)
val top = getTopKStrategy match {
case TopKStrategy.Abs => topPosNeg.sortBy { case LOCOValue(_, v, _) => -math.abs(v) }.take(k)
// Take top K positive and top K negative LOCOs, hence 2 * K
Expand Down
2 changes: 1 addition & 1 deletion gradle.properties
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
version=0.6.1-SNAPSHOT
version=0.6.1-SNAPSHOT-gera-1
group=com.salesforce.transmogrifai
org.gradle.caching=true

0 comments on commit 6694415

Please sign in to comment.