diff --git a/utils/src/test/scala/com/salesforce/op/utils/spark/OpSparkListenerTest.scala b/utils/src/test/scala/com/salesforce/op/utils/spark/OpSparkListenerTest.scala
index d0504d5470..ca67ffb646 100644
--- a/utils/src/test/scala/com/salesforce/op/utils/spark/OpSparkListenerTest.scala
+++ b/utils/src/test/scala/com/salesforce/op/utils/spark/OpSparkListenerTest.scala
@@ -32,16 +32,31 @@ package com.salesforce.op.utils.spark
 
 import com.salesforce.op.test.TestSparkContext
 import com.salesforce.op.utils.date.DateTimeUtils
+import org.apache.log4j._
 import org.junit.runner.RunWith
 import org.scalatest.FlatSpec
 import org.scalatest.junit.JUnitRunner
+import org.scalatest.prop.TableDrivenPropertyChecks
+
+import scala.collection.mutable.ArrayBuffer
 
 @RunWith(classOf[JUnitRunner])
-class OpSparkListenerTest extends FlatSpec with TestSparkContext {
+class OpSparkListenerTest extends FlatSpec with TableDrivenPropertyChecks with TestSparkContext {
+  val sparkLogAppender: MemoryAppender = {
+    val sparkAppender = new MemoryAppender()
+    sparkAppender.setName("spark-appender")
+    sparkAppender.setThreshold(Level.INFO)
+    sparkAppender.setLayout(new org.apache.log4j.PatternLayout)
+    LogManager.getLogger("com.salesforce.op.utils.spark.OpSparkListener").setLevel(Level.INFO)
+    Logger.getRootLogger.addAppender(sparkAppender)
+    sparkAppender
+  }
+
   val start = DateTimeUtils.now().getMillis
   val listener = new OpSparkListener(sc.appName, sc.applicationId, "testRun", Some("tag"), Some("tagValue"), true, true)
   sc.addSparkListener(listener)
   val _ = spark.read.csv(s"$testDataDir/PassengerDataAll.csv").count()
+  spark.close()
 
   Spec[OpSparkListener] should "capture app metrics" in {
     val appMetrics: AppMetrics = listener.metrics
@@ -65,4 +80,53 @@ class OpSparkListenerTest extends FlatSpec with TestSparkContext {
     firstStage.numTasks shouldBe 1
     firstStage.status shouldBe "succeeded"
   }
+
+  it should "log messages for listener initialization, stage completion, app completion" in {
+    val firstStage = listener.metrics.stageMetrics.head
+    val logPrefix = listener.logPrefix
+    val logs = sparkLogAppender.logs
+    val messages = Table("Spark Log Messages",
+      "Instantiated spark listener: com.salesforce.op.utils.spark.OpSparkListener. Log Prefix %s".format(logPrefix),
+      "%s,APP_TIME_MS:%s".format(logPrefix, listener.metrics.appEndTime - listener.metrics.appStartTime),
+      "%s,STAGE:%s,MEMORY_SPILLED_BYTES:%s,GC_TIME_MS:%s,STAGE_TIME_MS:%s".format(
+        logPrefix, firstStage.name, firstStage.memoryBytesSpilled, firstStage.jvmGCTime, firstStage.executorRunTime
+      )
+    )
+
+    forAll(messages) { m =>
+      logs.map(x => x.getMessage.toString).contains(m) shouldBe true
+    }
+  }
+}
+
+/**
+ * Class to enable in memory logging for tests
+ */
+class MemoryAppender extends AppenderSkeleton {
+  private val logRecords = new ArrayBuffer[spi.LoggingEvent]
+
+  override def requiresLayout: Boolean = true
+
+  /**
+   * Clear out the logRecords in log collection
+   * @return Unit
+   */
+  override def close(): Unit = {
+    logRecords.clear
+  }
+
+  /**
+   * Add a log to the log collection
+   * @param event The log event
+   * @return Unit
+   */
+  override def append(event: spi.LoggingEvent): Unit = {
+    logRecords.append(event)
+  }
+
+  /**
+   * Log event collection
+   * @return A collection of log events
+   */
+  def logs: ArrayBuffer[spi.LoggingEvent] = logRecords
 }
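
Note (not part of the diff): the test above captures log output by attaching a custom log4j 1.x appender to the root logger and buffering events in memory, then asserting on the buffered messages. Below is a minimal standalone sketch of that same pattern; the names MemoryAppenderDemo and BufferAppender are illustrative, not code from this PR, and it assumes log4j 1.x on the classpath.

// Standalone sketch of the in-memory log4j 1.x capture pattern the test relies on.
// BufferAppender plays the role of the PR's MemoryAppender.
import org.apache.log4j.{AppenderSkeleton, Level, Logger, PatternLayout}
import org.apache.log4j.spi.LoggingEvent

import scala.collection.mutable.ArrayBuffer

object MemoryAppenderDemo extends App {

  // Buffer log events in memory instead of writing them anywhere.
  class BufferAppender extends AppenderSkeleton {
    private val events = new ArrayBuffer[LoggingEvent]
    override def requiresLayout: Boolean = true
    override def close(): Unit = events.clear()
    override def append(event: LoggingEvent): Unit = events += event
    // Mirrors the test's logs.map(_.getMessage.toString) assertion style.
    def messages: Seq[String] = events.map(_.getMessage.toString)
  }

  val appender = new BufferAppender
  appender.setThreshold(Level.INFO) // AppenderSkeleton filters events below this level
  appender.setLayout(new PatternLayout)

  // Registering on the root logger means events from every logger in the
  // hierarchy propagate into the buffer -- this is why the test can assert
  // on OpSparkListener's log lines without touching Spark's log config.
  Logger.getRootLogger.addAppender(appender)

  val log = Logger.getLogger("demo")
  log.info("hello")
  log.debug("filtered out by the INFO threshold")

  assert(appender.messages.contains("hello"))
  println(appender.messages.mkString("\n"))
}

The design choice worth noting is the combination of a root-logger appender with a per-appender threshold: the capture is global, so listener output is not missed, while the INFO threshold keeps lower-level chatter out of the buffer.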