From 69354930086788ab4611e1e5a567cad55e9d2099 Mon Sep 17 00:00:00 2001 From: Matthew Tovbin Date: Wed, 20 Jun 2018 23:45:51 -0700 Subject: [PATCH 01/19] Added facilities for pretty printing tables --- .../com/salesforce/op/utils/table/Table.scala | 150 ++++++++++++++++ .../salesforce/op/utils/table/TableTest.scala | 166 ++++++++++++++++++ 2 files changed, 316 insertions(+) create mode 100644 utils/src/main/scala/com/salesforce/op/utils/table/Table.scala create mode 100644 utils/src/test/scala/com/salesforce/op/utils/table/TableTest.scala diff --git a/utils/src/main/scala/com/salesforce/op/utils/table/Table.scala b/utils/src/main/scala/com/salesforce/op/utils/table/Table.scala new file mode 100644 index 0000000000..e8f9d72ca8 --- /dev/null +++ b/utils/src/main/scala/com/salesforce/op/utils/table/Table.scala @@ -0,0 +1,150 @@ +/* + * Copyright (c) 2017, Salesforce.com, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of Salesforce.com nor the names of its contributors may + * be used to endorse or promote products derived from this software without + * specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +package com.salesforce.op.utils.table + +import com.twitter.algebird.Operators._ +import com.twitter.algebird.{Monoid, Semigroup} +import enumeratum._ + + +/** + * Simple table representation consisting of rows, i.e: + * + * +----------------------------------------+ + * | Transactions | + * +----------------------------------------+ + * | date | amount | source | status | + * +------+--------+--------------+---------+ + * | 1 | 4.95 | Cafe Venetia | Success | + * | 2 | 12.65 | Sprout | Success | + * | 3 | 4.75 | Caltrain | Pending | + * +------+--------+--------------+---------+ + * + * @param columns non empty sequence of column names + * @param rows non empty sequence of rows + * @param name table name + * @tparam T row type + */ +/** + * Simple table representation consisting of rows, i.e: + * + * +----------------------------------------+ + * | Transactions | + * +----------------------------------------+ + * | date | amount | source | status | + * +------+--------+--------------+---------+ + * | 1 | 4.95 | Cafe Venetia | Success | + * | 2 | 12.65 | Sprout | Success | + * | 3 | 4.75 | Caltrain | Pending | + * +------+--------+--------------+---------+ + * + * @param columns non empty sequence of column names + * @param rows non empty sequence of rows + * @param name table name + * @param nameAlignment table name alignment when printing + * @param columnAlignments column name & values alignment when printing + * (if not set defaults to [[defaultColumnAlignment]]) + * @param defaultColumnAlignment default column name & values alignment when printing + * @tparam T row type + */ +case class Table[T <: Product]( + columns: Seq[String], + rows: Seq[T], + name: String = "", + nameAlignment: Alignment = Alignment.Center, + columnAlignments: Map[String, Alignment] = Map.empty, + defaultColumnAlignment: Alignment = Alignment.Left +) { + require(columns.nonEmpty, "columns cannot be empty") + require(rows.nonEmpty, "rows cannot be empty") + require(columns.length == rows.head.productArity, + s"columns length must match rows arity (${columns.length}!=${rows.head.productArity})") + + private implicit val max = Semigroup.from[Int](math.max) + private implicit val monoid: Monoid[Array[Int]] = Monoid.arrayMonoid[Int] + + private def formatCell(v: String, size: Int, sep: String, fill: String): PartialFunction[Alignment, String] = { + case Alignment.Left => v + fill * (size - v.length) + case Alignment.Right => fill * (size - v.length) + v + case Alignment.Center => + String.format("%-" + size + "s", String.format("%" + (v.length + (size - v.length) / 2) + "s", v)) + } + + private def formatRow( + values: Iterable[String], + cellSizes: Iterable[Int], + alignment: String => Alignment = columnAlignments.getOrElse(_, defaultColumnAlignment), + sep: String = "|", + fill: String = " " + ): String = { + val formatted = values.zipWithIndex.zip(cellSizes).map { case ((v, i), size) => + formatCell(v, size, sep, fill)(alignment(columns(i))) + } + formatted.mkString(s"$sep$fill", s"$fill$sep$fill", s"$fill$sep") + } + + /** + * Pretty print table + * + * @return pretty printed table + */ + def prettyString: String = { + val rowVals= rows.map(_.productIterator.map(v => Option(v).map(_.toString).getOrElse("")).toSeq) + val columnSizes = columns.map(c => math.max(1, c.length)).toArray + val cellSizes = rowVals.map(_.map(_.length).toArray).foldLeft(columnSizes)(_ + _) + val bracket = formatRow(Seq.fill(cellSizes.length)(""), cellSizes, _ => Alignment.Left, sep = "+", fill = "-") + val rowWidth = bracket.length - 4 // cellSizes.sum + 2 * cellSizes.length + 1 + val cleanBracket = formatRow(Seq(""), Seq(rowWidth), _ => Alignment.Left, sep = "+", fill = "-") + val maybeName = Option(name) match { + case Some(n) if n.nonEmpty => Seq(cleanBracket, formatRow(Seq(name), Seq(rowWidth), _ => nameAlignment)) + case _ => Seq.empty + } + val columnsHeader = formatRow(columns, cellSizes) + val formattedRows = rowVals.map(formatRow(_, cellSizes)) + + (maybeName ++ Seq(cleanBracket, columnsHeader, bracket) ++ formattedRows :+ bracket).mkString("\n") + } + + + + override def toString: String = prettyString + +} + +sealed trait Alignment extends EnumEntry +object Alignment extends Enum[Alignment] { + val values = findValues + case object Left extends Alignment + case object Right extends Alignment + case object Center extends Alignment +} + diff --git a/utils/src/test/scala/com/salesforce/op/utils/table/TableTest.scala b/utils/src/test/scala/com/salesforce/op/utils/table/TableTest.scala new file mode 100644 index 0000000000..e47af45639 --- /dev/null +++ b/utils/src/test/scala/com/salesforce/op/utils/table/TableTest.scala @@ -0,0 +1,166 @@ +/* + * Copyright (c) 2017, Salesforce.com, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of Salesforce.com nor the names of its contributors may + * be used to endorse or promote products derived from this software without + * specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +package com.salesforce.op.utils.table + +import com.salesforce.op.test.TestCommon +import com.salesforce.op.utils.table.Alignment._ +import org.junit.runner.RunWith +import org.scalatest.FlatSpec +import org.scalatest.junit.JUnitRunner + +case class Transaction(date: Long, amount: Double, source: String, status: String) + +@RunWith(classOf[JUnitRunner]) +class TableTest extends FlatSpec with TestCommon { + + val columns = Seq("date", "amount", "source", "status") + val transactions = Seq( + Transaction(1, 4.95, "Cafe Venetia", "Success"), + Transaction(2, 12.65, "Sprout", "Success"), + Transaction(3, 4.75, "Caltrain", "Pending") + ) + + Spec[Table[_]] should "error on missing columns" in { + intercept[IllegalArgumentException] { + Table(columns = Seq.empty, rows = transactions) + }.getMessage shouldBe "requirement failed: columns cannot be empty" + } + it should "error on empty rows" in { + intercept[IllegalArgumentException] { + Table(columns = columns, rows = Seq.empty[Transaction]) + }.getMessage shouldBe "requirement failed: rows cannot be empty" + } + it should "error on invalid arity" in { + intercept[IllegalArgumentException] { + Table(columns = Seq("a"), rows = transactions) + }.getMessage shouldBe "requirement failed: columns length must match rows arity (1!=4)" + } + it should "pretty print a table" in { + Table(columns = columns, rows = transactions).prettyString shouldBe + """|+----------------------------------------+ + || date | amount | source | status | + |+------+--------+--------------+---------+ + || 1 | 4.95 | Cafe Venetia | Success | + || 2 | 12.65 | Sprout | Success | + || 3 | 4.75 | Caltrain | Pending | + |+------+--------+--------------+---------+""".stripMargin + } + it should "have a pretty toString as well" in { + val table = Table(columns = columns, rows = transactions) + table.prettyString shouldBe table.toString + } + it should "pretty print a table with a name" in { + Table(columns = columns, rows = transactions, name = "Transactions").prettyString shouldBe + """|+----------------------------------------+ + || Transactions | + |+----------------------------------------+ + || date | amount | source | status | + |+------+--------+--------------+---------+ + || 1 | 4.95 | Cafe Venetia | Success | + || 2 | 12.65 | Sprout | Success | + || 3 | 4.75 | Caltrain | Pending | + |+------+--------+--------------+---------+""".stripMargin + } + it should "pretty print a table with a name aligned left" in { + Table(columns = columns, rows = transactions, name = "Transactions", nameAlignment = Left).prettyString shouldBe + """|+----------------------------------------+ + || Transactions | + |+----------------------------------------+ + || date | amount | source | status | + |+------+--------+--------------+---------+ + || 1 | 4.95 | Cafe Venetia | Success | + || 2 | 12.65 | Sprout | Success | + || 3 | 4.75 | Caltrain | Pending | + |+------+--------+--------------+---------+""".stripMargin + } + it should "pretty print a table with right column alignment" in { + Table(columns = columns, rows = transactions, defaultColumnAlignment = Right).prettyString shouldBe + """|+----------------------------------------+ + || date | amount | source | status | + |+------+--------+--------------+---------+ + || 1 | 4.95 | Cafe Venetia | Success | + || 2 | 12.65 | Sprout | Success | + || 3 | 4.75 | Caltrain | Pending | + |+------+--------+--------------+---------+""".stripMargin + } + it should "pretty print a table with center column alignment" in { + Table(columns = columns, rows = transactions, defaultColumnAlignment = Center).prettyString shouldBe + """|+----------------------------------------+ + || date | amount | source | status | + |+------+--------+--------------+---------+ + || 1 | 4.95 | Cafe Venetia | Success | + || 2 | 12.65 | Sprout | Success | + || 3 | 4.75 | Caltrain | Pending | + |+------+--------+--------------+---------+""".stripMargin + } + it should "pretty print a table with custom column alignment" in { + Table(columns = columns, rows = transactions, name = "Transactions", + nameAlignment = Center, defaultColumnAlignment = Right, + columnAlignments = Map("date" -> Right, "amount" -> Left, "status" -> Center) + ).prettyString shouldBe + """|+----------------------------------------+ + || Transactions | + |+----------------------------------------+ + || date | amount | source | status | + |+------+--------+--------------+---------+ + || 1 | 4.95 | Cafe Venetia | Success | + || 2 | 12.65 | Sprout | Success | + || 3 | 4.75 | Caltrain | Pending | + |+------+--------+--------------+---------+""".stripMargin + } + it should "pretty print a table even if data is bad" in { + val badData1 = Seq(Tuple2(null, "one"), "2" -> "", (null, null), "3" -> Transaction(1, 1.0, "?", "?")) + Table(columns = Seq("c1", "c2"), rows = badData1, name = "Bad Data").prettyString shouldBe + """|+-----------------------------+ + || Bad Data | + |+-----------------------------+ + || c1 | c2 | + |+----+------------------------+ + || | one | + || 2 | | + || | | + || 3 | Transaction(1,1.0,?,?) | + |+----+------------------------+""".stripMargin + } + it should "pretty print a table even if data is really bad" in { + val badData2 = Seq(null, "", 1).map(Tuple1(_)) + Table(columns = Seq(""), rows = badData2).prettyString shouldBe + """|+---+ + || | + |+---+ + || | + || | + || 1 | + |+---+""".stripMargin + } + +} From 8eaf97ab1eca74a9fa71405ad87d72cc4dd87acb Mon Sep 17 00:00:00 2001 From: Matthew Tovbin Date: Wed, 20 Jun 2018 23:46:35 -0700 Subject: [PATCH 02/19] cleanup --- .../src/main/scala/com/salesforce/op/utils/table/Table.scala | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/utils/src/main/scala/com/salesforce/op/utils/table/Table.scala b/utils/src/main/scala/com/salesforce/op/utils/table/Table.scala index e8f9d72ca8..1ab06cd8f5 100644 --- a/utils/src/main/scala/com/salesforce/op/utils/table/Table.scala +++ b/utils/src/main/scala/com/salesforce/op/utils/table/Table.scala @@ -122,7 +122,7 @@ case class Table[T <: Product]( val columnSizes = columns.map(c => math.max(1, c.length)).toArray val cellSizes = rowVals.map(_.map(_.length).toArray).foldLeft(columnSizes)(_ + _) val bracket = formatRow(Seq.fill(cellSizes.length)(""), cellSizes, _ => Alignment.Left, sep = "+", fill = "-") - val rowWidth = bracket.length - 4 // cellSizes.sum + 2 * cellSizes.length + 1 + val rowWidth = bracket.length - 4 val cleanBracket = formatRow(Seq(""), Seq(rowWidth), _ => Alignment.Left, sep = "+", fill = "-") val maybeName = Option(name) match { case Some(n) if n.nonEmpty => Seq(cleanBracket, formatRow(Seq(name), Seq(rowWidth), _ => nameAlignment)) @@ -134,8 +134,6 @@ case class Table[T <: Product]( (maybeName ++ Seq(cleanBracket, columnsHeader, bracket) ++ formattedRows :+ bracket).mkString("\n") } - - override def toString: String = prettyString } From 21e010bfbed4666b6deb1779c3457c839d45bd86 Mon Sep 17 00:00:00 2001 From: Matthew Tovbin Date: Wed, 20 Jun 2018 23:49:22 -0700 Subject: [PATCH 03/19] cleanup2 --- .../com/salesforce/op/utils/table/Table.scala | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/utils/src/main/scala/com/salesforce/op/utils/table/Table.scala b/utils/src/main/scala/com/salesforce/op/utils/table/Table.scala index 1ab06cd8f5..40dbffc5ee 100644 --- a/utils/src/main/scala/com/salesforce/op/utils/table/Table.scala +++ b/utils/src/main/scala/com/salesforce/op/utils/table/Table.scala @@ -35,25 +35,6 @@ import com.twitter.algebird.Operators._ import com.twitter.algebird.{Monoid, Semigroup} import enumeratum._ - -/** - * Simple table representation consisting of rows, i.e: - * - * +----------------------------------------+ - * | Transactions | - * +----------------------------------------+ - * | date | amount | source | status | - * +------+--------+--------------+---------+ - * | 1 | 4.95 | Cafe Venetia | Success | - * | 2 | 12.65 | Sprout | Success | - * | 3 | 4.75 | Caltrain | Pending | - * +------+--------+--------------+---------+ - * - * @param columns non empty sequence of column names - * @param rows non empty sequence of rows - * @param name table name - * @tparam T row type - */ /** * Simple table representation consisting of rows, i.e: * From ff75ae993a6a55752b2b7bc3439d39d012ca3bbe Mon Sep 17 00:00:00 2001 From: Matthew Tovbin Date: Wed, 20 Jun 2018 23:54:35 -0700 Subject: [PATCH 04/19] cleanup3 --- .../com/salesforce/op/utils/table/Table.scala | 36 +++++++++---------- .../salesforce/op/utils/table/TableTest.scala | 24 ++++++------- 2 files changed, 28 insertions(+), 32 deletions(-) diff --git a/utils/src/main/scala/com/salesforce/op/utils/table/Table.scala b/utils/src/main/scala/com/salesforce/op/utils/table/Table.scala index 40dbffc5ee..1726630a2a 100644 --- a/utils/src/main/scala/com/salesforce/op/utils/table/Table.scala +++ b/utils/src/main/scala/com/salesforce/op/utils/table/Table.scala @@ -49,22 +49,11 @@ import enumeratum._ * +------+--------+--------------+---------+ * * @param columns non empty sequence of column names - * @param rows non empty sequence of rows - * @param name table name - * @param nameAlignment table name alignment when printing - * @param columnAlignments column name & values alignment when printing - * (if not set defaults to [[defaultColumnAlignment]]) - * @param defaultColumnAlignment default column name & values alignment when printing + * @param rows non empty sequence of rows + * @param name table name * @tparam T row type */ -case class Table[T <: Product]( - columns: Seq[String], - rows: Seq[T], - name: String = "", - nameAlignment: Alignment = Alignment.Center, - columnAlignments: Map[String, Alignment] = Map.empty, - defaultColumnAlignment: Alignment = Alignment.Left -) { +case class Table[T <: Product](columns: Seq[String], rows: Seq[T], name: String = "") { require(columns.nonEmpty, "columns cannot be empty") require(rows.nonEmpty, "rows cannot be empty") require(columns.length == rows.head.productArity, @@ -83,7 +72,7 @@ case class Table[T <: Product]( private def formatRow( values: Iterable[String], cellSizes: Iterable[Int], - alignment: String => Alignment = columnAlignments.getOrElse(_, defaultColumnAlignment), + alignment: String => Alignment, sep: String = "|", fill: String = " " ): String = { @@ -96,10 +85,17 @@ case class Table[T <: Product]( /** * Pretty print table * + * @param nameAlignment table name alignment + * @param columnAlignments column name & values alignment + * @param defaultColumnAlignment default column name & values alignment * @return pretty printed table */ - def prettyString: String = { - val rowVals= rows.map(_.productIterator.map(v => Option(v).map(_.toString).getOrElse("")).toSeq) + def prettyString( + nameAlignment: Alignment = Alignment.Center, + columnAlignments: Map[String, Alignment] = Map.empty, + defaultColumnAlignment: Alignment = Alignment.Left + ): String = { + val rowVals = rows.map(_.productIterator.map(v => Option(v).map(_.toString).getOrElse("")).toSeq) val columnSizes = columns.map(c => math.max(1, c.length)).toArray val cellSizes = rowVals.map(_.map(_.length).toArray).foldLeft(columnSizes)(_ + _) val bracket = formatRow(Seq.fill(cellSizes.length)(""), cellSizes, _ => Alignment.Left, sep = "+", fill = "-") @@ -109,13 +105,13 @@ case class Table[T <: Product]( case Some(n) if n.nonEmpty => Seq(cleanBracket, formatRow(Seq(name), Seq(rowWidth), _ => nameAlignment)) case _ => Seq.empty } - val columnsHeader = formatRow(columns, cellSizes) - val formattedRows = rowVals.map(formatRow(_, cellSizes)) + val columnsHeader = formatRow(columns, cellSizes, columnAlignments.getOrElse(_, defaultColumnAlignment)) + val formattedRows = rowVals.map(formatRow(_, cellSizes, columnAlignments.getOrElse(_, defaultColumnAlignment))) (maybeName ++ Seq(cleanBracket, columnsHeader, bracket) ++ formattedRows :+ bracket).mkString("\n") } - override def toString: String = prettyString + override def toString: String = prettyString() } diff --git a/utils/src/test/scala/com/salesforce/op/utils/table/TableTest.scala b/utils/src/test/scala/com/salesforce/op/utils/table/TableTest.scala index e47af45639..83dc24d457 100644 --- a/utils/src/test/scala/com/salesforce/op/utils/table/TableTest.scala +++ b/utils/src/test/scala/com/salesforce/op/utils/table/TableTest.scala @@ -65,7 +65,7 @@ class TableTest extends FlatSpec with TestCommon { }.getMessage shouldBe "requirement failed: columns length must match rows arity (1!=4)" } it should "pretty print a table" in { - Table(columns = columns, rows = transactions).prettyString shouldBe + Table(columns = columns, rows = transactions).prettyString() shouldBe """|+----------------------------------------+ || date | amount | source | status | |+------+--------+--------------+---------+ @@ -76,10 +76,10 @@ class TableTest extends FlatSpec with TestCommon { } it should "have a pretty toString as well" in { val table = Table(columns = columns, rows = transactions) - table.prettyString shouldBe table.toString + table.prettyString() shouldBe table.toString } it should "pretty print a table with a name" in { - Table(columns = columns, rows = transactions, name = "Transactions").prettyString shouldBe + Table(columns = columns, rows = transactions, name = "Transactions").prettyString() shouldBe """|+----------------------------------------+ || Transactions | |+----------------------------------------+ @@ -91,7 +91,7 @@ class TableTest extends FlatSpec with TestCommon { |+------+--------+--------------+---------+""".stripMargin } it should "pretty print a table with a name aligned left" in { - Table(columns = columns, rows = transactions, name = "Transactions", nameAlignment = Left).prettyString shouldBe + Table(columns = columns, rows = transactions, name = "Transactions").prettyString(nameAlignment = Left) shouldBe """|+----------------------------------------+ || Transactions | |+----------------------------------------+ @@ -103,7 +103,7 @@ class TableTest extends FlatSpec with TestCommon { |+------+--------+--------------+---------+""".stripMargin } it should "pretty print a table with right column alignment" in { - Table(columns = columns, rows = transactions, defaultColumnAlignment = Right).prettyString shouldBe + Table(columns = columns, rows = transactions).prettyString(defaultColumnAlignment = Right) shouldBe """|+----------------------------------------+ || date | amount | source | status | |+------+--------+--------------+---------+ @@ -113,7 +113,7 @@ class TableTest extends FlatSpec with TestCommon { |+------+--------+--------------+---------+""".stripMargin } it should "pretty print a table with center column alignment" in { - Table(columns = columns, rows = transactions, defaultColumnAlignment = Center).prettyString shouldBe + Table(columns = columns, rows = transactions).prettyString(defaultColumnAlignment = Center) shouldBe """|+----------------------------------------+ || date | amount | source | status | |+------+--------+--------------+---------+ @@ -123,10 +123,10 @@ class TableTest extends FlatSpec with TestCommon { |+------+--------+--------------+---------+""".stripMargin } it should "pretty print a table with custom column alignment" in { - Table(columns = columns, rows = transactions, name = "Transactions", - nameAlignment = Center, defaultColumnAlignment = Right, - columnAlignments = Map("date" -> Right, "amount" -> Left, "status" -> Center) - ).prettyString shouldBe + Table(columns = columns, rows = transactions, name = "Transactions") + .prettyString( + nameAlignment = Center, defaultColumnAlignment = Right, + columnAlignments = Map("date" -> Right, "amount" -> Left, "status" -> Center)) shouldBe """|+----------------------------------------+ || Transactions | |+----------------------------------------+ @@ -139,7 +139,7 @@ class TableTest extends FlatSpec with TestCommon { } it should "pretty print a table even if data is bad" in { val badData1 = Seq(Tuple2(null, "one"), "2" -> "", (null, null), "3" -> Transaction(1, 1.0, "?", "?")) - Table(columns = Seq("c1", "c2"), rows = badData1, name = "Bad Data").prettyString shouldBe + Table(columns = Seq("c1", "c2"), rows = badData1, name = "Bad Data").prettyString() shouldBe """|+-----------------------------+ || Bad Data | |+-----------------------------+ @@ -153,7 +153,7 @@ class TableTest extends FlatSpec with TestCommon { } it should "pretty print a table even if data is really bad" in { val badData2 = Seq(null, "", 1).map(Tuple1(_)) - Table(columns = Seq(""), rows = badData2).prettyString shouldBe + Table(columns = Seq(""), rows = badData2).prettyString() shouldBe """|+---+ || | |+---+ From 7bb4f16744ff95797d50ecd7a709741972370698 Mon Sep 17 00:00:00 2001 From: Matthew Tovbin Date: Wed, 20 Jun 2018 23:57:55 -0700 Subject: [PATCH 05/19] cleanup4 --- .../src/main/scala/com/salesforce/op/utils/table/Table.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/utils/src/main/scala/com/salesforce/op/utils/table/Table.scala b/utils/src/main/scala/com/salesforce/op/utils/table/Table.scala index 1726630a2a..9cce4f3dbf 100644 --- a/utils/src/main/scala/com/salesforce/op/utils/table/Table.scala +++ b/utils/src/main/scala/com/salesforce/op/utils/table/Table.scala @@ -105,8 +105,9 @@ case class Table[T <: Product](columns: Seq[String], rows: Seq[T], name: String case Some(n) if n.nonEmpty => Seq(cleanBracket, formatRow(Seq(name), Seq(rowWidth), _ => nameAlignment)) case _ => Seq.empty } - val columnsHeader = formatRow(columns, cellSizes, columnAlignments.getOrElse(_, defaultColumnAlignment)) - val formattedRows = rowVals.map(formatRow(_, cellSizes, columnAlignments.getOrElse(_, defaultColumnAlignment))) + val alignment: String => Alignment = columnAlignments.getOrElse(_, defaultColumnAlignment) + val columnsHeader = formatRow(columns, cellSizes, alignment) + val formattedRows = rowVals.map(formatRow(_, cellSizes, alignment)) (maybeName ++ Seq(cleanBracket, columnsHeader, bracket) ++ formattedRows :+ bracket).mkString("\n") } From 153b3b2ed77c51ae0953280d6873aa8f12bf2daf Mon Sep 17 00:00:00 2001 From: Matthew Tovbin Date: Thu, 21 Jun 2018 00:22:34 -0700 Subject: [PATCH 06/19] make scalastyle happy --- .../test/scala/com/salesforce/op/utils/table/TableTest.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/utils/src/test/scala/com/salesforce/op/utils/table/TableTest.scala b/utils/src/test/scala/com/salesforce/op/utils/table/TableTest.scala index 83dc24d457..4faa78a903 100644 --- a/utils/src/test/scala/com/salesforce/op/utils/table/TableTest.scala +++ b/utils/src/test/scala/com/salesforce/op/utils/table/TableTest.scala @@ -42,6 +42,8 @@ case class Transaction(date: Long, amount: Double, source: String, status: Strin @RunWith(classOf[JUnitRunner]) class TableTest extends FlatSpec with TestCommon { + // scalastyle:off indentation + val columns = Seq("date", "amount", "source", "status") val transactions = Seq( Transaction(1, 4.95, "Cafe Venetia", "Success"), From dccea200a8cdf4cb8292f7cfbae9c37fb0c01733 Mon Sep 17 00:00:00 2001 From: Matthew Tovbin Date: Thu, 21 Jun 2018 07:45:31 -0700 Subject: [PATCH 07/19] allow sorting columns --- .../com/salesforce/op/utils/table/Table.scala | 60 ++++++++++++++----- .../salesforce/op/utils/table/TableTest.scala | 22 ++++++- 2 files changed, 67 insertions(+), 15 deletions(-) diff --git a/utils/src/main/scala/com/salesforce/op/utils/table/Table.scala b/utils/src/main/scala/com/salesforce/op/utils/table/Table.scala index 9cce4f3dbf..2d0fc817f3 100644 --- a/utils/src/main/scala/com/salesforce/op/utils/table/Table.scala +++ b/utils/src/main/scala/com/salesforce/op/utils/table/Table.scala @@ -31,10 +31,32 @@ package com.salesforce.op.utils.table -import com.twitter.algebird.Operators._ import com.twitter.algebird.{Monoid, Semigroup} import enumeratum._ + +object Table { + /** + * Simple factory for creating table instance with rows of [[Product]] types + * + * @param columns non empty sequence of column names + * @param rows non empty sequence of rows + * @param name table name + * @tparam T row type of [[Product]] + */ + def apply[T <: Product](columns: Seq[String], rows: Seq[T], name: String = ""): Table = { + require(columns.nonEmpty, "columns cannot be empty") + require(rows.nonEmpty, "rows cannot be empty") + require(columns.length == rows.head.productArity, + s"columns length must match rows arity (${columns.length}!=${rows.head.productArity})") + val rowVals = rows.map(_.productIterator.map(v => Option(v).map(_.toString).getOrElse("")).toSeq) + new Table(columns, rowVals, name) + } + + private implicit val max = Semigroup.from[Int](math.max) + private implicit val monoid: Monoid[Array[Int]] = Monoid.arrayMonoid[Int] +} + /** * Simple table representation consisting of rows, i.e: * @@ -51,17 +73,8 @@ import enumeratum._ * @param columns non empty sequence of column names * @param rows non empty sequence of rows * @param name table name - * @tparam T row type */ -case class Table[T <: Product](columns: Seq[String], rows: Seq[T], name: String = "") { - require(columns.nonEmpty, "columns cannot be empty") - require(rows.nonEmpty, "rows cannot be empty") - require(columns.length == rows.head.productArity, - s"columns length must match rows arity (${columns.length}!=${rows.head.productArity})") - - private implicit val max = Semigroup.from[Int](math.max) - private implicit val monoid: Monoid[Array[Int]] = Monoid.arrayMonoid[Int] - +class Table private(columns: Seq[String], rows: Seq[Seq[String]], name: String) { private def formatCell(v: String, size: Int, sep: String, fill: String): PartialFunction[Alignment, String] = { case Alignment.Left => v + fill * (size - v.length) case Alignment.Right => fill * (size - v.length) + v @@ -82,6 +95,26 @@ case class Table[T <: Product](columns: Seq[String], rows: Seq[T], name: String formatted.mkString(s"$sep$fill", s"$fill$sep$fill", s"$fill$sep") } + private def sortColumns(ascending: Boolean): Table = { + val (columnsSorted, indices) = columns.zipWithIndex.sortBy(_._1).unzip + val rowsSorted = rows.map(row => row.zip(indices).sortBy(_._2).unzip._1) + new Table( + columns = if (ascending) columnsSorted else columnsSorted.reverse, + rows = if (ascending) rowsSorted else rowsSorted.map(_.reverse), + name = name + ) + } + + /** + * Sort table columns in alphabetical order + */ + def sortColumnsAsc: Table = sortColumns(ascending = true) + + /** + * Sort table columns in inverse alphabetical order + */ + def sortColumnsDesc: Table = sortColumns(ascending = false) + /** * Pretty print table * @@ -95,9 +128,8 @@ case class Table[T <: Product](columns: Seq[String], rows: Seq[T], name: String columnAlignments: Map[String, Alignment] = Map.empty, defaultColumnAlignment: Alignment = Alignment.Left ): String = { - val rowVals = rows.map(_.productIterator.map(v => Option(v).map(_.toString).getOrElse("")).toSeq) val columnSizes = columns.map(c => math.max(1, c.length)).toArray - val cellSizes = rowVals.map(_.map(_.length).toArray).foldLeft(columnSizes)(_ + _) + val cellSizes = rows.map(_.map(_.length).toArray).foldLeft(columnSizes)(Table.monoid.plus) val bracket = formatRow(Seq.fill(cellSizes.length)(""), cellSizes, _ => Alignment.Left, sep = "+", fill = "-") val rowWidth = bracket.length - 4 val cleanBracket = formatRow(Seq(""), Seq(rowWidth), _ => Alignment.Left, sep = "+", fill = "-") @@ -107,7 +139,7 @@ case class Table[T <: Product](columns: Seq[String], rows: Seq[T], name: String } val alignment: String => Alignment = columnAlignments.getOrElse(_, defaultColumnAlignment) val columnsHeader = formatRow(columns, cellSizes, alignment) - val formattedRows = rowVals.map(formatRow(_, cellSizes, alignment)) + val formattedRows = rows.map(formatRow(_, cellSizes, alignment)) (maybeName ++ Seq(cleanBracket, columnsHeader, bracket) ++ formattedRows :+ bracket).mkString("\n") } diff --git a/utils/src/test/scala/com/salesforce/op/utils/table/TableTest.scala b/utils/src/test/scala/com/salesforce/op/utils/table/TableTest.scala index 4faa78a903..2b01126a70 100644 --- a/utils/src/test/scala/com/salesforce/op/utils/table/TableTest.scala +++ b/utils/src/test/scala/com/salesforce/op/utils/table/TableTest.scala @@ -51,7 +51,7 @@ class TableTest extends FlatSpec with TestCommon { Transaction(3, 4.75, "Caltrain", "Pending") ) - Spec[Table[_]] should "error on missing columns" in { + Spec[Table] should "error on missing columns" in { intercept[IllegalArgumentException] { Table(columns = Seq.empty, rows = transactions) }.getMessage shouldBe "requirement failed: columns cannot be empty" @@ -80,6 +80,26 @@ class TableTest extends FlatSpec with TestCommon { val table = Table(columns = columns, rows = transactions) table.prettyString() shouldBe table.toString } + it should "sort columns in ascending order" in { + Table(columns = columns, rows = transactions).sortColumnsAsc.prettyString() shouldBe + """|+----------------------------------------+ + || amount | date | source | status | + |+--------+------+--------------+---------+ + || 4.95 | 1 | Cafe Venetia | Success | + || 12.65 | 2 | Sprout | Success | + || 4.75 | 3 | Caltrain | Pending | + |+--------+------+--------------+---------+""".stripMargin + } + it should "sort columns in descending order" in { + Table(columns = columns, rows = transactions).sortColumnsDesc.prettyString() shouldBe + """|+----------------------------------------+ + || status | source | date | amount | + |+---------+--------------+------+--------+ + || Success | Cafe Venetia | 1 | 4.95 | + || Success | Sprout | 2 | 12.65 | + || Pending | Caltrain | 3 | 4.75 | + |+---------+--------------+------+--------+""".stripMargin + } it should "pretty print a table with a name" in { Table(columns = columns, rows = transactions, name = "Transactions").prettyString() shouldBe """|+----------------------------------------+ From 033cce9c0d47aa4ae210cdaf10f5ee3cc5b8379e Mon Sep 17 00:00:00 2001 From: Matthew Tovbin Date: Thu, 21 Jun 2018 07:53:57 -0700 Subject: [PATCH 08/19] Added a simple perf test --- .../scala/com/salesforce/op/utils/table/TableTest.scala | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/utils/src/test/scala/com/salesforce/op/utils/table/TableTest.scala b/utils/src/test/scala/com/salesforce/op/utils/table/TableTest.scala index 2b01126a70..fc037e6e43 100644 --- a/utils/src/test/scala/com/salesforce/op/utils/table/TableTest.scala +++ b/utils/src/test/scala/com/salesforce/op/utils/table/TableTest.scala @@ -184,5 +184,14 @@ class TableTest extends FlatSpec with TestCommon { || 1 | |+---+""".stripMargin } + it should "pretty print in timely fashion" in { + val columns = Seq("c1", "c2", "c3", "c4", "c5") + val rows = (0 until 100000).map(i => (i, i + 1, i - 1, i + i, i * i)) + Table(columns, rows).prettyString() // warmup + val start = System.currentTimeMillis() + Table(columns, rows).prettyString() + val elapsed = System.currentTimeMillis() - start + elapsed should be < 5000L + } } From 24d8a71b3f5fb646b2c5d62bf5a6a4e2789a7474 Mon Sep 17 00:00:00 2001 From: Matthew Tovbin Date: Thu, 21 Jun 2018 07:55:09 -0700 Subject: [PATCH 09/19] no need really --- .../scala/com/salesforce/op/utils/table/TableTest.scala | 9 --------- 1 file changed, 9 deletions(-) diff --git a/utils/src/test/scala/com/salesforce/op/utils/table/TableTest.scala b/utils/src/test/scala/com/salesforce/op/utils/table/TableTest.scala index fc037e6e43..2b01126a70 100644 --- a/utils/src/test/scala/com/salesforce/op/utils/table/TableTest.scala +++ b/utils/src/test/scala/com/salesforce/op/utils/table/TableTest.scala @@ -184,14 +184,5 @@ class TableTest extends FlatSpec with TestCommon { || 1 | |+---+""".stripMargin } - it should "pretty print in timely fashion" in { - val columns = Seq("c1", "c2", "c3", "c4", "c5") - val rows = (0 until 100000).map(i => (i, i + 1, i - 1, i + i, i * i)) - Table(columns, rows).prettyString() // warmup - val start = System.currentTimeMillis() - Table(columns, rows).prettyString() - val elapsed = System.currentTimeMillis() - start - elapsed should be < 5000L - } } From 35dcd14e5809a68fa43c7bd90492763d89f3846c Mon Sep 17 00:00:00 2001 From: Matthew Tovbin Date: Thu, 21 Jun 2018 17:39:09 -0700 Subject: [PATCH 10/19] A handful of model insight helper methods --- .../com/salesforce/op/ModelInsights.scala | 122 +++++++++++++++++- .../com/salesforce/op/OpWorkflowModel.scala | 22 +++- .../com/salesforce/op/ModelInsightsTest.scala | 37 +++++- 3 files changed, 170 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/com/salesforce/op/ModelInsights.scala b/core/src/main/scala/com/salesforce/op/ModelInsights.scala index a913e1a0c2..92fc3d1a44 100644 --- a/core/src/main/scala/com/salesforce/op/ModelInsights.scala +++ b/core/src/main/scala/com/salesforce/op/ModelInsights.scala @@ -31,23 +31,33 @@ package com.salesforce.op +import com.salesforce.op.evaluators._ import com.salesforce.op.features.FeatureLike import com.salesforce.op.features.types.{OPVector, RealNN} +import com.salesforce.op.stages.impl.classification.ClassificationModelsToTry +import com.salesforce.op.stages.impl.classification.ClassificationModelsToTry.{DecisionTree, LogisticRegression, NaiveBayes, RandomForest} import com.salesforce.op.stages.impl.preparators._ -import com.salesforce.op.stages.impl.selector.{ModelSelectorBase, SelectedModel} +import com.salesforce.op.stages.impl.regression.RegressionModelsToTry +import com.salesforce.op.stages.impl.regression.RegressionModelsToTry.{DecisionTreeRegression, GBTRegression, LinearRegression, RandomForestRegression} +import com.salesforce.op.stages.impl.selector.{ModelSelectorBase, ModelSelectorBaseNames, SelectedModel} +import com.salesforce.op.stages.impl.selector.ModelSelectorBaseNames._ import com.salesforce.op.stages.{OPStage, OpPipelineStageParams, OpPipelineStageParamsNames} +import com.salesforce.op.utils.json.JsonUtils import com.salesforce.op.utils.spark.OpVectorMetadata import com.salesforce.op.utils.spark.RichMetadata._ +import enumeratum.EnumEntry import org.apache.spark.ml.classification._ import org.apache.spark.ml.regression._ import org.apache.spark.ml.{Model, PipelineStage, Transformer} import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel +import org.apache.spark.sql.types.Metadata import org.json4s._ import org.json4s.jackson.JsonMethods._ import org.json4s.jackson.Serialization.{write, writePretty} import org.slf4j.LoggerFactory -import scala.util.Try +import scala.reflect.ClassTag +import scala.util.{Failure, Success, Try} /** * Summary of all model insights @@ -68,10 +78,118 @@ case class ModelInsights stageInfo: Map[String, Any] ) { + /** + * Best model UID + */ + def bestModelUid: String = selectedModelInfo(BestModelUid).toString + + /** + * Best model name + */ + def bestModelName: String = selectedModelInfo(BestModelName).toString + + /** + * Best model type, i.e. LogisticRegression, RandomForest etc. + */ + def bestModelType: EnumEntry = { + classificationModelTypeOfUID.orElse(regressionModelTypeOfUID).lift(bestModelUid).getOrElse( + throw new Exception(s"Unsupported model type for best model '$bestModelUid'")) + } + + /** + * Best model validation results computed during Cross Validation or Train Validation Split + */ + def bestModelValidationResults: Map[String, String] = validationResults(bestModelName) + + /** + * Validation results computed during Cross Validation or Train Validation Split + * + * @return validation results keyed by model name + */ + def validationResults: Map[String, Map[String, String]] = { + val res = for { + results <- getMap[String, Any](selectedModelInfo, TrainValSplitResults).recoverWith { + case e => getMap[String, Any](selectedModelInfo, CrossValResults) + } + } yield results.keys.map(k => k -> getMap[String, String](results, k).getOrElse(Map.empty)) + res match { + case Failure(e) => throw new Exception(s"Failed to extract validation results", e) + case Success(ok) => ok.toMap + } + } + + /** + * Train set evaluation metrics + */ + def trainEvaluationMetrics: EvaluationMetrics = evaluationMetrics(TrainingEval) + + /** + * Test set evaluation metrics (if any) + */ + def testEvaluationMetrics: Option[EvaluationMetrics] = { + selectedModelInfo.get(HoldOutEval).map(_ => evaluationMetrics(HoldOutEval)) + } + + /** + * Serialize to json string + * + * @param pretty should pretty format + * @return json string + */ def toJson(pretty: Boolean = true): String = { implicit val formats = DefaultFormats if (pretty) writePretty(this) else write(this) } + + private def classificationModelTypeOfUID: PartialFunction[String, ClassificationModelsToTry] = { + case uid if uid.startsWith("logreg") => LogisticRegression + case uid if uid.startsWith("rfc") => RandomForest + case uid if uid.startsWith("dtc") => DecisionTree + case uid if uid.startsWith("nb") => NaiveBayes + } + private def regressionModelTypeOfUID: PartialFunction[String, RegressionModelsToTry] = { + case uid if uid.startsWith("linReg") => LinearRegression + case uid if uid.startsWith("rfr") => RandomForestRegression + case uid if uid.startsWith("dtr") => DecisionTreeRegression + case uid if uid.startsWith("gbtr") => GBTRegression + } + private def evaluationMetrics(metricsName: String): EvaluationMetrics = { + val res = for { + metricsMap <- getMap[String, Double](selectedModelInfo, metricsName) + evalMetrics <- Try(toEvaluationMetrics(metricsMap)) + } yield evalMetrics + res match { + case Failure(e) => throw new Exception(s"Failed to extract '$metricsName' metrics", e) + case Success(ok) => ok + } + } + private def getMap[K, V](m: Map[String, Any], name: String): Try[Map[K, V]] = Try { + m(name) match { + case m: Map[String, Any]@unchecked => m("map").asInstanceOf[Map[K, V]] + case m: Metadata => m.underlyingMap.asInstanceOf[Map[K, V]] + } + } + + private val MetricName = "\\((.*)\\)\\_(.*)".r + + private def toEvaluationMetrics(metrics: Map[String, Double]): EvaluationMetrics = { + import OpEvaluatorNames._ + val metricsType = metrics.keys.headOption match { + case Some(MetricName(t, _)) if Set(binary, multi, regression).contains(t) => t + case v => throw new Exception(s"Invalid model metric '$v'") + } + def parse[T <: EvaluationMetrics : ClassTag] = { + val vals = metrics.map { case (MetricName(_, name), value) => name -> value } + val valsJson = JsonUtils.toJsonString(vals) + JsonUtils.fromString[T](valsJson).get + } + metricsType match { + case `binary` => parse[BinaryClassificationMetrics] + case `multi` => parse[MultiClassificationMetrics] + case `regression` => parse[RegressionMetrics] + case t => throw new Exception(s"Unsupported metrics type '$t'") + } + } } /** diff --git a/core/src/main/scala/com/salesforce/op/OpWorkflowModel.scala b/core/src/main/scala/com/salesforce/op/OpWorkflowModel.scala index a40f2ae8f8..67aec1d6b1 100644 --- a/core/src/main/scala/com/salesforce/op/OpWorkflowModel.scala +++ b/core/src/main/scala/com/salesforce/op/OpWorkflowModel.scala @@ -35,6 +35,7 @@ import com.salesforce.op.evaluators.{EvaluationMetrics, OpEvaluatorBase} import com.salesforce.op.features.types.FeatureType import com.salesforce.op.features.{FeatureLike, OPFeature} import com.salesforce.op.readers.DataFrameFieldNames._ +import com.salesforce.op.stages.impl.selector.StageParamNames import com.salesforce.op.stages.{OPStage, OpPipelineStage, OpTransformer} import com.salesforce.op.utils.spark.RichDataset._ import com.salesforce.op.utils.spark.RichMetadata._ @@ -165,7 +166,7 @@ class OpWorkflowModel(val uid: String = UID[OpWorkflowModel], val trainingParams } /** - * Pulls all summary metadata off of transformers + * Pulls all summary metadata of transformers and puts them in json * * @return json summary */ @@ -177,12 +178,27 @@ class OpWorkflowModel(val uid: String = UID[OpWorkflowModel], val trainingParams ) /** - * Pulls all summary metadata off of transformers and puts them in a pretty json string + * Pulls all summary metadata of transformers and puts them into json string * - * @return string summary + * @return json string summary */ def summary(): String = pretty(render(summaryJson())) + /** + * Pulls all summary metadata of transformers and puts them into compact print friendly string + * + * @return compact print friendly string + */ + def summaryPretty(): String = { + val prediction = resultFeatures.find(_.name == StageParamNames.outputParam1Name).orElse( + stages.map(_.getOutput()).find(_.name == StageParamNames.outputParam1Name) + ).getOrElse( + throw new Exception("No prediction feature is defined") + ) + val insights = modelInsights(prediction) + ??? + } + /** * Save this model to a path * diff --git a/core/src/test/scala/com/salesforce/op/ModelInsightsTest.scala b/core/src/test/scala/com/salesforce/op/ModelInsightsTest.scala index a523152f30..b93c4cba80 100644 --- a/core/src/test/scala/com/salesforce/op/ModelInsightsTest.scala +++ b/core/src/test/scala/com/salesforce/op/ModelInsightsTest.scala @@ -31,14 +31,16 @@ package com.salesforce.op +import com.salesforce.op.evaluators.BinaryClassificationMetrics import com.salesforce.op.features.Feature -import com.salesforce.op.features.types.{FeatureTypeDefaults, PickList, Real, RealNN} +import com.salesforce.op.features.types.{PickList, Real, RealNN} import com.salesforce.op.stages.impl.classification.BinaryClassificationModelSelector import com.salesforce.op.stages.impl.classification.ClassificationModelsToTry.LogisticRegression import com.salesforce.op.stages.impl.preparators._ import com.salesforce.op.stages.impl.regression.RegressionModelSelector import com.salesforce.op.stages.impl.regression.RegressionModelsToTry.LinearRegression import com.salesforce.op.stages.impl.selector.SelectedModel +import com.salesforce.op.stages.impl.tuning.DataSplitter import com.salesforce.op.test.PassengerSparkFixtureTest import com.salesforce.op.utils.spark.{OpVectorColumnMetadata, OpVectorMetadata} import org.junit.runner.RunWith @@ -55,11 +57,11 @@ class ModelInsightsTest extends FlatSpec with PassengerSparkFixtureTest { private val descrVec = description.vectorize(10, false, 1, true) private val features = Seq(density, age, generVec, weight, descrVec).transmogrify() private val label = survived.occurs() - private val checked = label.sanityCheck(features, removeBadFeatures = true, removeFeatureGroup = false, - checkSample = 1.0) + private val checked = + label.sanityCheck(features, removeBadFeatures = true, removeFeatureGroup = false, checkSample = 1.0) val (pred, rawPred, prob) = BinaryClassificationModelSelector - .withCrossValidation(seed = 42, splitter = None) + .withCrossValidation(seed = 42, splitter = Option(DataSplitter(seed = 42, reserveTestFraction = 0.1))) .setModelsToTry(LogisticRegression) .setLogisticRegressionRegParam(0.01, 0.1) .setInput(label, checked) @@ -119,7 +121,7 @@ class ModelInsightsTest extends FlatSpec with PassengerSparkFixtureTest { insights.label.rawFeatureName shouldBe Seq(survived.name) insights.label.rawFeatureType shouldBe Seq(survived.typeName) insights.label.stagesApplied.size shouldBe 1 - insights.label.sampleSize shouldBe Some(6.0) + insights.label.sampleSize shouldBe Some(4.0) insights.features.size shouldBe 5 insights.features.map(_.featureName).toSet shouldEqual rawNames ageInsights.derivedFeatures.size shouldBe 2 @@ -170,7 +172,7 @@ class ModelInsightsTest extends FlatSpec with PassengerSparkFixtureTest { insights.label.rawFeatureName shouldBe Seq(survived.name) insights.label.rawFeatureType shouldBe Seq(survived.typeName) insights.label.stagesApplied.size shouldBe 1 - insights.label.sampleSize shouldBe Some(6.0) + insights.label.sampleSize shouldBe Some(4.0) insights.features.size shouldBe 5 insights.features.map(_.featureName).toSet shouldEqual rawNames ageInsights.derivedFeatures.size shouldBe 2 @@ -237,6 +239,29 @@ class ModelInsightsTest extends FlatSpec with PassengerSparkFixtureTest { lin.head.size shouldBe OpVectorMetadata("", checked.originStage.getMetadata()).columns.length } + it should "return best model information" in { + val insights = workflowModel.modelInsights(prob) + insights.bestModelUid should startWith("logreg_") + insights.bestModelName should startWith("logreg_") + insights.bestModelType shouldBe LogisticRegression + val bestModelValidationResults = insights.bestModelValidationResults + bestModelValidationResults.size shouldBe 15 + println(bestModelValidationResults) + bestModelValidationResults.get("area under PR") shouldBe Some("0.0") + val validationResults = insights.validationResults + validationResults.size shouldBe 2 + validationResults.get(insights.bestModelName) shouldBe Some(bestModelValidationResults) + } + + it should "return test/train evaluation metrics" in { + val insights = workflowModel.modelInsights(prob) + insights.trainEvaluationMetrics shouldBe + BinaryClassificationMetrics(1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 5.0, 0.0, 0.0) + insights.testEvaluationMetrics shouldBe Some( + BinaryClassificationMetrics(0.0, 0.0, 0.0, 0.5, 0.75, 0.5, 0.0, 1.0, 0.0, 1.0) + ) + } + it should "correctly serialize and deserialize from json" in { val insights = workflowModel.modelInsights(prob) ModelInsights.fromJson(insights.toJson()) match { From 28bffca702320f6c3d9e98786570ee8cc39003d2 Mon Sep 17 00:00:00 2001 From: Matthew Tovbin Date: Thu, 21 Jun 2018 17:49:40 -0700 Subject: [PATCH 11/19] cleanup --- core/src/main/scala/com/salesforce/op/OpWorkflowModel.scala | 4 +++- core/src/test/scala/com/salesforce/op/ModelInsightsTest.scala | 3 +-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/com/salesforce/op/OpWorkflowModel.scala b/core/src/main/scala/com/salesforce/op/OpWorkflowModel.scala index 67aec1d6b1..d47ccce06b 100644 --- a/core/src/main/scala/com/salesforce/op/OpWorkflowModel.scala +++ b/core/src/main/scala/com/salesforce/op/OpWorkflowModel.scala @@ -196,7 +196,9 @@ class OpWorkflowModel(val uid: String = UID[OpWorkflowModel], val trainingParams throw new Exception("No prediction feature is defined") ) val insights = modelInsights(prediction) - ??? + + // TODO + throw new NotImplementedError } /** diff --git a/core/src/test/scala/com/salesforce/op/ModelInsightsTest.scala b/core/src/test/scala/com/salesforce/op/ModelInsightsTest.scala index b93c4cba80..c8f4c20ad1 100644 --- a/core/src/test/scala/com/salesforce/op/ModelInsightsTest.scala +++ b/core/src/test/scala/com/salesforce/op/ModelInsightsTest.scala @@ -246,7 +246,6 @@ class ModelInsightsTest extends FlatSpec with PassengerSparkFixtureTest { insights.bestModelType shouldBe LogisticRegression val bestModelValidationResults = insights.bestModelValidationResults bestModelValidationResults.size shouldBe 15 - println(bestModelValidationResults) bestModelValidationResults.get("area under PR") shouldBe Some("0.0") val validationResults = insights.validationResults validationResults.size shouldBe 2 @@ -256,7 +255,7 @@ class ModelInsightsTest extends FlatSpec with PassengerSparkFixtureTest { it should "return test/train evaluation metrics" in { val insights = workflowModel.modelInsights(prob) insights.trainEvaluationMetrics shouldBe - BinaryClassificationMetrics(1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 5.0, 0.0, 0.0) + BinaryClassificationMetrics(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0) insights.testEvaluationMetrics shouldBe Some( BinaryClassificationMetrics(0.0, 0.0, 0.0, 0.5, 0.75, 0.5, 0.0, 1.0, 0.0, 1.0) ) From 64c91c53fe1dbc0f258b9012962ed85934e657ea Mon Sep 17 00:00:00 2001 From: Matthew Tovbin Date: Fri, 22 Jun 2018 08:53:25 -0700 Subject: [PATCH 12/19] some renames --- .../com/salesforce/op/ModelInsights.scala | 61 ++++++++++++------- .../com/salesforce/op/OpWorkflowModel.scala | 53 ++++++++++++---- .../com/salesforce/op/ModelInsightsTest.scala | 15 ++--- 3 files changed, 88 insertions(+), 41 deletions(-) diff --git a/core/src/main/scala/com/salesforce/op/ModelInsights.scala b/core/src/main/scala/com/salesforce/op/ModelInsights.scala index 92fc3d1a44..463041f2ff 100644 --- a/core/src/main/scala/com/salesforce/op/ModelInsights.scala +++ b/core/src/main/scala/com/salesforce/op/ModelInsights.scala @@ -39,13 +39,13 @@ import com.salesforce.op.stages.impl.classification.ClassificationModelsToTry.{D import com.salesforce.op.stages.impl.preparators._ import com.salesforce.op.stages.impl.regression.RegressionModelsToTry import com.salesforce.op.stages.impl.regression.RegressionModelsToTry.{DecisionTreeRegression, GBTRegression, LinearRegression, RandomForestRegression} -import com.salesforce.op.stages.impl.selector.{ModelSelectorBase, ModelSelectorBaseNames, SelectedModel} import com.salesforce.op.stages.impl.selector.ModelSelectorBaseNames._ +import com.salesforce.op.stages.impl.selector.{ModelSelectorBase, SelectedModel} import com.salesforce.op.stages.{OPStage, OpPipelineStageParams, OpPipelineStageParamsNames} import com.salesforce.op.utils.json.JsonUtils import com.salesforce.op.utils.spark.OpVectorMetadata import com.salesforce.op.utils.spark.RichMetadata._ -import enumeratum.EnumEntry +import enumeratum._ import org.apache.spark.ml.classification._ import org.apache.spark.ml.regression._ import org.apache.spark.ml.{Model, PipelineStage, Transformer} @@ -79,30 +79,42 @@ case class ModelInsights ) { /** - * Best model UID + * Selected model UID */ - def bestModelUid: String = selectedModelInfo(BestModelUid).toString + def selectedModelUID: String = selectedModelInfo(BestModelUid).toString /** - * Best model name + * Selected model name */ - def bestModelName: String = selectedModelInfo(BestModelName).toString + def selectedModelName: String = selectedModelInfo(BestModelName).toString /** - * Best model type, i.e. LogisticRegression, RandomForest etc. + * Selected model type, i.e. LogisticRegression, RandomForest etc. */ - def bestModelType: EnumEntry = { - classificationModelTypeOfUID.orElse(regressionModelTypeOfUID).lift(bestModelUid).getOrElse( - throw new Exception(s"Unsupported model type for best model '$bestModelUid'")) + def selectedModelType: EnumEntry = { + classificationModelTypeOfUID.orElse(regressionModelTypeOfUID).lift(selectedModelUID).getOrElse( + throw new Exception(s"Unsupported model type for best model '$selectedModelUID'")) } /** - * Best model validation results computed during Cross Validation or Train Validation Split + * Selected model validation results computed during Cross Validation or Train Validation Split */ - def bestModelValidationResults: Map[String, String] = validationResults(bestModelName) + def selectedModelValidationResults: Map[String, String] = validationResults(selectedModelName) /** - * Validation results computed during Cross Validation or Train Validation Split + * Train set evaluation metrics for selected model + */ + def selectedModelTrainEvalMetrics: EvaluationMetrics = evaluationMetrics(TrainingEval) + + /** + * Test set evaluation metrics (if any) for selected model + */ + def selectedModelTestEvalMetrics: Option[EvaluationMetrics] = { + selectedModelInfo.get(HoldOutEval).map(_ => evaluationMetrics(HoldOutEval)) + } + + /** + * Validation results for all models computed during Cross Validation or Train Validation Split * * @return validation results keyed by model name */ @@ -119,15 +131,13 @@ case class ModelInsights } /** - * Train set evaluation metrics - */ - def trainEvaluationMetrics: EvaluationMetrics = evaluationMetrics(TrainingEval) - - /** - * Test set evaluation metrics (if any) + * Problem type, i.e. Binary Classification, Multi Classification or Regression */ - def testEvaluationMetrics: Option[EvaluationMetrics] = { - selectedModelInfo.get(HoldOutEval).map(_ => evaluationMetrics(HoldOutEval)) + def problemType: ProblemType = selectedModelTrainEvalMetrics match { + case _: BinaryClassificationMetrics => ProblemType.BinaryClassification + case _: MultiClassificationMetrics => ProblemType.MultiClassification + case _: RegressionMetrics => ProblemType.Regression + case _ => ProblemType.Unknown } /** @@ -192,6 +202,15 @@ case class ModelInsights } } +sealed trait ProblemType extends EnumEntry with Serializable + object ProblemType extends Enum[ProblemType] { + val values = findValues + case object BinaryClassification extends ProblemType + case object MultiClassification extends ProblemType + case object Regression extends ProblemType + case object Unknown extends ProblemType +} + /** * Summary information about label used in model creation (all fields will be empty if no label is found) * diff --git a/core/src/main/scala/com/salesforce/op/OpWorkflowModel.scala b/core/src/main/scala/com/salesforce/op/OpWorkflowModel.scala index d47ccce06b..00f263a84f 100644 --- a/core/src/main/scala/com/salesforce/op/OpWorkflowModel.scala +++ b/core/src/main/scala/com/salesforce/op/OpWorkflowModel.scala @@ -35,14 +35,13 @@ import com.salesforce.op.evaluators.{EvaluationMetrics, OpEvaluatorBase} import com.salesforce.op.features.types.FeatureType import com.salesforce.op.features.{FeatureLike, OPFeature} import com.salesforce.op.readers.DataFrameFieldNames._ -import com.salesforce.op.stages.impl.selector.StageParamNames import com.salesforce.op.stages.{OPStage, OpPipelineStage, OpTransformer} import com.salesforce.op.utils.spark.RichDataset._ import com.salesforce.op.utils.spark.RichMetadata._ -import org.apache.spark.ml._ -import org.apache.spark.rdd.RDD +import com.salesforce.op.utils.table.Alignment._ +import com.salesforce.op.utils.table.Table import org.apache.spark.sql.types.Metadata -import org.apache.spark.sql.{DataFrame, Row, SparkSession} +import org.apache.spark.sql.{DataFrame, SparkSession} import org.json4s.JValue import org.json4s.JsonAST.{JField, JObject} import org.json4s.jackson.JsonMethods.{pretty, render} @@ -190,15 +189,43 @@ class OpWorkflowModel(val uid: String = UID[OpWorkflowModel], val trainingParams * @return compact print friendly string */ def summaryPretty(): String = { - val prediction = resultFeatures.find(_.name == StageParamNames.outputParam1Name).orElse( - stages.map(_.getOutput()).find(_.name == StageParamNames.outputParam1Name) - ).getOrElse( - throw new Exception("No prediction feature is defined") - ) - val insights = modelInsights(prediction) - - // TODO - throw new NotImplementedError + val response = resultFeatures.find(_.isResponse).getOrElse(throw new Exception("No response feature is defined")) + val insights = modelInsights(response) + val summary = new ArrayBuffer[String]() + + // Selected model information + summary += { + val bestModelType = insights.selectedModelType + val name = s"Selected model - $bestModelType" + val validationResults = insights.selectedModelValidationResults.toSeq ++ Seq( + "name" -> insights.selectedModelName, + "uid" -> insights.selectedModelUID, + "modelType" -> insights.selectedModelType + ) + val table = Table(name = name, columns = Seq("Model Param", "Value"), rows = validationResults.sortBy(_._1)) + table.prettyString() + } + + // Model evaluation metrics + summary += { + val name = "Model Evaluation Metrics" + val trainEvaluationMetrics = insights.selectedModelTrainEvalMetrics + val testEvaluationMetrics = insights.selectedModelTestEvalMetrics + val (metricNameCol, holdOutCol, trainingCol) = ("Metric Name", "Hold Out Set Value", "Training Set Value") + val trainMetrics = trainEvaluationMetrics.toMap.map { case (k, v) => k -> v.toString }.toSeq.sortBy(_._1) + val table = testEvaluationMetrics match { + case Some(testMetrics) => + val testMetricsMap = testMetrics.toMap + val rows = trainMetrics.map { case (k, v) => (k, v.toString, testMetricsMap(k).toString) } + Table(name = name, columns = Seq(metricNameCol, trainingCol, holdOutCol), rows = rows) + case None => + Table(name = name, columns = Seq(metricNameCol, trainingCol), rows = trainMetrics) + } + table.prettyString(columnAlignments = Map(holdOutCol -> Right, trainingCol -> Right)) + } + + + summary.mkString("\n\n") } /** diff --git a/core/src/test/scala/com/salesforce/op/ModelInsightsTest.scala b/core/src/test/scala/com/salesforce/op/ModelInsightsTest.scala index c8f4c20ad1..d7a36f6754 100644 --- a/core/src/test/scala/com/salesforce/op/ModelInsightsTest.scala +++ b/core/src/test/scala/com/salesforce/op/ModelInsightsTest.scala @@ -241,22 +241,23 @@ class ModelInsightsTest extends FlatSpec with PassengerSparkFixtureTest { it should "return best model information" in { val insights = workflowModel.modelInsights(prob) - insights.bestModelUid should startWith("logreg_") - insights.bestModelName should startWith("logreg_") - insights.bestModelType shouldBe LogisticRegression - val bestModelValidationResults = insights.bestModelValidationResults + insights.selectedModelUID should startWith("logreg_") + insights.selectedModelName should startWith("logreg_") + insights.selectedModelType shouldBe LogisticRegression + val bestModelValidationResults = insights.selectedModelValidationResults bestModelValidationResults.size shouldBe 15 bestModelValidationResults.get("area under PR") shouldBe Some("0.0") val validationResults = insights.validationResults validationResults.size shouldBe 2 - validationResults.get(insights.bestModelName) shouldBe Some(bestModelValidationResults) + validationResults.get(insights.selectedModelName) shouldBe Some(bestModelValidationResults) } it should "return test/train evaluation metrics" in { val insights = workflowModel.modelInsights(prob) - insights.trainEvaluationMetrics shouldBe + insights.problemType shouldBe ProblemType.BinaryClassification + insights.selectedModelTrainEvalMetrics shouldBe BinaryClassificationMetrics(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0) - insights.testEvaluationMetrics shouldBe Some( + insights.selectedModelTestEvalMetrics shouldBe Some( BinaryClassificationMetrics(0.0, 0.0, 0.0, 0.5, 0.75, 0.5, 0.0, 1.0, 0.0, 1.0) ) } From a3680c8874209d40351c1b855c8a0c2c12796ec2 Mon Sep 17 00:00:00 2001 From: Matthew Tovbin Date: Fri, 22 Jun 2018 11:17:26 -0700 Subject: [PATCH 13/19] some mroe stuff --- .../com/salesforce/op/ModelInsights.scala | 79 +++++++++++++++---- .../com/salesforce/op/OpWorkflowModel.scala | 42 +++++++++- .../salesforce/op/evaluators/Evaluators.scala | 35 ++++---- .../op/evaluators/OpEvaluatorBase.scala | 68 +++++++++------- .../op/evaluators/OpRegressionEvaluator.scala | 10 --- .../classification/SelectorClassifiers.scala | 3 +- .../salesforce/op/stages/impl/package.scala | 43 ++++++++++ .../impl/regression/SelectorRegressors.scala | 3 +- .../com/salesforce/op/ModelInsightsTest.scala | 12 ++- .../com/salesforce/op/OpWorkflowTest.scala | 5 ++ 10 files changed, 227 insertions(+), 73 deletions(-) create mode 100644 core/src/main/scala/com/salesforce/op/stages/impl/package.scala diff --git a/core/src/main/scala/com/salesforce/op/ModelInsights.scala b/core/src/main/scala/com/salesforce/op/ModelInsights.scala index 463041f2ff..726fe5e8b0 100644 --- a/core/src/main/scala/com/salesforce/op/ModelInsights.scala +++ b/core/src/main/scala/com/salesforce/op/ModelInsights.scala @@ -34,6 +34,7 @@ package com.salesforce.op import com.salesforce.op.evaluators._ import com.salesforce.op.features.FeatureLike import com.salesforce.op.features.types.{OPVector, RealNN} +import com.salesforce.op.stages.impl.ModelsToTry import com.salesforce.op.stages.impl.classification.ClassificationModelsToTry import com.salesforce.op.stages.impl.classification.ClassificationModelsToTry.{DecisionTree, LogisticRegression, NaiveBayes, RandomForest} import com.salesforce.op.stages.impl.preparators._ @@ -91,10 +92,7 @@ case class ModelInsights /** * Selected model type, i.e. LogisticRegression, RandomForest etc. */ - def selectedModelType: EnumEntry = { - classificationModelTypeOfUID.orElse(regressionModelTypeOfUID).lift(selectedModelUID).getOrElse( - throw new Exception(s"Unsupported model type for best model '$selectedModelUID'")) - } + def selectedModelType: ModelsToTry = modelType(selectedModelName).get /** * Selected model validation results computed during Cross Validation or Train Validation Split @@ -130,6 +128,47 @@ case class ModelInsights } } + /** + * Validation results for a specified model type computed during Cross Validation or Train Validation Split + * + * @return validation results keyed by model name + */ + def validationResults(mType: ModelsToTry): Map[String, Map[String, String]] = { + validationResults.filter { case (modelName, _) => modelType(modelName).toOption.contains(mType) } + } + + /** + * All validated model types + */ + def validatedModelTypes: Set[ModelsToTry] = + validationResults.keys.flatMap(modelName => modelType(modelName).toOption).toSet + + /** + * Validation type, i.e TrainValidationSplit, CrossValidation + */ + def validationType: ValidationType = { + if (getMap[String, Any](selectedModelInfo, TrainValSplitResults).isSuccess) ValidationType.TrainValidationSplit + else if (getMap[String, Any](selectedModelInfo, CrossValResults).isSuccess) ValidationType.CrossValidation + else throw new Exception(s"Failed to determine validation type") + } + + /** + * Evaluation metric type, i.e. AuPR, AuROC, F1 etc. + */ + def evaluationMetricType: EnumEntry with EvalMetric = { + val knownEvalMetrics = { + (BinaryClassEvalMetrics.values ++ MultiClassEvalMetrics.values ++ RegressionEvalMetrics.values) + .map(m => m.humanFriendlyName -> m).toMap + } + val evalMetrics = validationResults.flatMap(_._2.keys).flatMap(knownEvalMetrics.get).toSet.toList + evalMetrics match { + case evalMetric :: Nil => evalMetric + case Nil => throw new Exception("Unable to determine evaluation metric type: no metrics were found") + case metrics => throw new Exception( + s"Unable to determine evaluation metric type since: multiple metrics were found - " + metrics.mkString(",")) + } + } + /** * Problem type, i.e. Binary Classification, Multi Classification or Regression */ @@ -151,17 +190,22 @@ case class ModelInsights if (pretty) writePretty(this) else write(this) } - private def classificationModelTypeOfUID: PartialFunction[String, ClassificationModelsToTry] = { - case uid if uid.startsWith("logreg") => LogisticRegression - case uid if uid.startsWith("rfc") => RandomForest - case uid if uid.startsWith("dtc") => DecisionTree - case uid if uid.startsWith("nb") => NaiveBayes + private def modelType(modelName: String): Try[ModelsToTry] = Try { + classificationModelType.orElse(regressionModelType).lift(modelName).getOrElse( + throw new Exception(s"Unsupported model type for best model '$modelName'")) + } + + private def classificationModelType: PartialFunction[String, ClassificationModelsToTry] = { + case v if v.startsWith("logreg") => LogisticRegression + case v if v.startsWith("rfc") => RandomForest + case v if v.startsWith("dtc") => DecisionTree + case v if v.startsWith("nb") => NaiveBayes } - private def regressionModelTypeOfUID: PartialFunction[String, RegressionModelsToTry] = { - case uid if uid.startsWith("linReg") => LinearRegression - case uid if uid.startsWith("rfr") => RandomForestRegression - case uid if uid.startsWith("dtr") => DecisionTreeRegression - case uid if uid.startsWith("gbtr") => GBTRegression + private def regressionModelType: PartialFunction[String, RegressionModelsToTry] = { + case v if v.startsWith("linReg") => LinearRegression + case v if v.startsWith("rfr") => RandomForestRegression + case v if v.startsWith("dtr") => DecisionTreeRegression + case v if v.startsWith("gbtr") => GBTRegression } private def evaluationMetrics(metricsName: String): EvaluationMetrics = { val res = for { @@ -211,6 +255,13 @@ sealed trait ProblemType extends EnumEntry with Serializable case object Unknown extends ProblemType } +sealed abstract class ValidationType(val humanFriendlyName: String) extends EnumEntry with Serializable +object ValidationType extends Enum[ValidationType] { + val values = findValues + case object CrossValidation extends ValidationType("Cross Validation") + case object TrainValidationSplit extends ValidationType("Train Validation Split") +} + /** * Summary information about label used in model creation (all fields will be empty if no label is found) * diff --git a/core/src/main/scala/com/salesforce/op/OpWorkflowModel.scala b/core/src/main/scala/com/salesforce/op/OpWorkflowModel.scala index 00f263a84f..5bdfb16784 100644 --- a/core/src/main/scala/com/salesforce/op/OpWorkflowModel.scala +++ b/core/src/main/scala/com/salesforce/op/OpWorkflowModel.scala @@ -48,6 +48,7 @@ import org.json4s.jackson.JsonMethods.{pretty, render} import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag +import scala.util.Try /** @@ -186,13 +187,49 @@ class OpWorkflowModel(val uid: String = UID[OpWorkflowModel], val trainingParams /** * Pulls all summary metadata of transformers and puts them into compact print friendly string * - * @return compact print friendly string + * @return a compact print friendly string */ def summaryPretty(): String = { val response = resultFeatures.find(_.isResponse).getOrElse(throw new Exception("No response feature is defined")) val insights = modelInsights(response) val summary = new ArrayBuffer[String]() + // Validation results + summary += { + val validatedModelTypes = insights.validatedModelTypes + val validationType = insights.validationType.humanFriendlyName + val evalMetric = insights.evaluationMetricType.humanFriendlyName + "Evaluated %s model%s using %s and %s metric.".format( + validatedModelTypes.mkString(", "), + if (validatedModelTypes.size > 1) "s" else "", + validationType, // TODO add number of folds or train/split ratio if possible + evalMetric + ) + } + summary += { + val modelEvalRes = for { + modelType <- insights.validatedModelTypes + modelValidationResults = insights.validationResults(modelType) + evalMetric = insights.evaluationMetricType.humanFriendlyName + } yield { + val evalMetricValues = modelValidationResults.flatMap { case (_, metrics) => + metrics.get(evalMetric).flatMap(v => Try(v.toDouble).toOption) + } + val minMetricValue = evalMetricValues.reduceOption[Double](math.min).getOrElse(Double.NaN) + val maxMetricValue = evalMetricValues.reduceOption[Double](math.max).getOrElse(Double.NaN) + + "Evaluated %d %s model%s with %s metric between [%s, %s].".format( + modelValidationResults.size, + modelType, + if (modelValidationResults.size > 1) "s" else "", + evalMetric, + minMetricValue, + maxMetricValue + ) + } + modelEvalRes.mkString("\n") + } + // Selected model information summary += { val bestModelType = insights.selectedModelType @@ -224,8 +261,7 @@ class OpWorkflowModel(val uid: String = UID[OpWorkflowModel], val trainingParams table.prettyString(columnAlignments = Map(holdOutCol -> Right, trainingCol -> Right)) } - - summary.mkString("\n\n") + summary.mkString("\n") } /** diff --git a/core/src/main/scala/com/salesforce/op/evaluators/Evaluators.scala b/core/src/main/scala/com/salesforce/op/evaluators/Evaluators.scala index 3645b45480..2e6cbed237 100644 --- a/core/src/main/scala/com/salesforce/op/evaluators/Evaluators.scala +++ b/core/src/main/scala/com/salesforce/op/evaluators/Evaluators.scala @@ -57,7 +57,8 @@ object Evaluators { * Area under ROC */ def auROC(): OpBinaryClassificationEvaluator = - new OpBinaryClassificationEvaluator(name = OpMetricsNames.auROC, isLargerBetter = true) { + new OpBinaryClassificationEvaluator( + name = BinaryClassEvalMetrics.AuROC.humanFriendlyName, isLargerBetter = true) { override def evaluate(dataset: Dataset[_]): Double = getBinaryEvaluatorMetric(BinaryClassEvalMetrics.AuROC, dataset) } @@ -66,7 +67,7 @@ object Evaluators { * Area under Precision/Recall curve */ def auPR(): OpBinaryClassificationEvaluator = - new OpBinaryClassificationEvaluator(name = OpMetricsNames.auPR, isLargerBetter = true) { + new OpBinaryClassificationEvaluator(name = BinaryClassEvalMetrics.AuPR.humanFriendlyName, isLargerBetter = true) { override def evaluate(dataset: Dataset[_]): Double = getBinaryEvaluatorMetric(BinaryClassEvalMetrics.AuPR, dataset) } @@ -75,7 +76,8 @@ object Evaluators { * Precision */ def precision(): OpBinaryClassificationEvaluator = - new OpBinaryClassificationEvaluator(name = OpMetricsNames.precision, isLargerBetter = true) { + new OpBinaryClassificationEvaluator( + name = MultiClassEvalMetrics.Precision.humanFriendlyName, isLargerBetter = true) { override def evaluate(dataset: Dataset[_]): Double = { import dataset.sparkSession.implicits._ new MulticlassMetrics(dataset.select(getPredictionCol, getLabelCol).as[(Double, Double)].rdd).precision(1.0) @@ -86,7 +88,7 @@ object Evaluators { * Recall */ def recall(): OpBinaryClassificationEvaluator = - new OpBinaryClassificationEvaluator(name = OpMetricsNames.recall, isLargerBetter = true) { + new OpBinaryClassificationEvaluator(name = MultiClassEvalMetrics.Recall.humanFriendlyName, isLargerBetter = true) { override def evaluate(dataset: Dataset[_]): Double = { import dataset.sparkSession.implicits._ new MulticlassMetrics(dataset.select(getPredictionCol, getLabelCol).as[(Double, Double)].rdd).recall(1.0) @@ -97,7 +99,7 @@ object Evaluators { * F1 score */ def f1(): OpBinaryClassificationEvaluator = - new OpBinaryClassificationEvaluator(name = OpMetricsNames.f1, isLargerBetter = true) { + new OpBinaryClassificationEvaluator(name = MultiClassEvalMetrics.F1.humanFriendlyName, isLargerBetter = true) { override def evaluate(dataset: Dataset[_]): Double = { import dataset.sparkSession.implicits._ new MulticlassMetrics( @@ -109,7 +111,8 @@ object Evaluators { * Prediction error */ def error(): OpBinaryClassificationEvaluator = - new OpBinaryClassificationEvaluator(name = OpMetricsNames.error, isLargerBetter = false) { + new OpBinaryClassificationEvaluator( + name = MultiClassEvalMetrics.Error.humanFriendlyName, isLargerBetter = false) { override def evaluate(dataset: Dataset[_]): Double = 1.0 - getMultiEvaluatorMetric(MultiClassEvalMetrics.Error, dataset) } @@ -162,7 +165,8 @@ object Evaluators { * Weighted Precision */ def precision(): OpMultiClassificationEvaluator = - new OpMultiClassificationEvaluator(name = OpMetricsNames.precision, isLargerBetter = true) { + new OpMultiClassificationEvaluator( + name = MultiClassEvalMetrics.Precision.humanFriendlyName, isLargerBetter = true) { override def evaluate(dataset: Dataset[_]): Double = getMultiEvaluatorMetric(MultiClassEvalMetrics.Precision, dataset) } @@ -171,7 +175,7 @@ object Evaluators { * Weighted Recall */ def recall(): OpMultiClassificationEvaluator = - new OpMultiClassificationEvaluator(name = OpMetricsNames.recall, isLargerBetter = true) { + new OpMultiClassificationEvaluator(name = MultiClassEvalMetrics.Recall.humanFriendlyName, isLargerBetter = true) { override def evaluate(dataset: Dataset[_]): Double = getMultiEvaluatorMetric(MultiClassEvalMetrics.Recall, dataset) } @@ -180,7 +184,7 @@ object Evaluators { * F1 Score */ def f1(): OpMultiClassificationEvaluator = - new OpMultiClassificationEvaluator(name = OpMetricsNames.f1, isLargerBetter = true) { + new OpMultiClassificationEvaluator(name = MultiClassEvalMetrics.F1.humanFriendlyName, isLargerBetter = true) { override def evaluate(dataset: Dataset[_]): Double = getMultiEvaluatorMetric(MultiClassEvalMetrics.F1, dataset) } @@ -189,7 +193,7 @@ object Evaluators { * Prediction Error */ def error(): OpMultiClassificationEvaluator = - new OpMultiClassificationEvaluator(name = OpMetricsNames.error, isLargerBetter = false) { + new OpMultiClassificationEvaluator(name = MultiClassEvalMetrics.Error.humanFriendlyName, isLargerBetter = false) { override def evaluate(dataset: Dataset[_]): Double = 1.0 - getMultiEvaluatorMetric(MultiClassEvalMetrics.Error, dataset) } @@ -252,7 +256,8 @@ object Evaluators { * Mean Squared Error */ def mse(): OpRegressionEvaluator = - new OpRegressionEvaluator(name = OpMetricsNames.meanSquaredError, isLargerBetter = false) { + new OpRegressionEvaluator( + name = RegressionEvalMetrics.MeanSquaredError.humanFriendlyName, isLargerBetter = false) { override def evaluate(dataset: Dataset[_]): Double = getRegEvaluatorMetric(RegressionEvalMetrics.MeanSquaredError, dataset) } @@ -261,7 +266,8 @@ object Evaluators { * Mean Absolute Error */ def mae(): OpRegressionEvaluator = - new OpRegressionEvaluator(name = OpMetricsNames.meanAbsoluteError, isLargerBetter = false) { + new OpRegressionEvaluator( + name = RegressionEvalMetrics.MeanAbsoluteError.humanFriendlyName, isLargerBetter = false) { override def evaluate(dataset: Dataset[_]): Double = getRegEvaluatorMetric(RegressionEvalMetrics.MeanAbsoluteError, dataset) } @@ -270,7 +276,7 @@ object Evaluators { * R2 */ def r2(): OpRegressionEvaluator = - new OpRegressionEvaluator(name = OpMetricsNames.r2, isLargerBetter = true) { + new OpRegressionEvaluator(name = RegressionEvalMetrics.R2.humanFriendlyName, isLargerBetter = true) { override def evaluate(dataset: Dataset[_]): Double = getRegEvaluatorMetric(RegressionEvalMetrics.R2, dataset) } @@ -279,7 +285,8 @@ object Evaluators { * Root Mean Squared Error */ def rmse(): OpRegressionEvaluator = - new OpRegressionEvaluator(name = OpMetricsNames.rootMeanSquaredError, isLargerBetter = false) { + new OpRegressionEvaluator( + name = RegressionEvalMetrics.RootMeanSquaredError.humanFriendlyName, isLargerBetter = false) { override def evaluate(dataset: Dataset[_]): Double = getRegEvaluatorMetric(RegressionEvalMetrics.RootMeanSquaredError, dataset) } diff --git a/core/src/main/scala/com/salesforce/op/evaluators/OpEvaluatorBase.scala b/core/src/main/scala/com/salesforce/op/evaluators/OpEvaluatorBase.scala index 1ce904bccd..74d3ef85b7 100644 --- a/core/src/main/scala/com/salesforce/op/evaluators/OpEvaluatorBase.scala +++ b/core/src/main/scala/com/salesforce/op/evaluators/OpEvaluatorBase.scala @@ -196,17 +196,37 @@ abstract class OpRegressionEvaluatorBase[T <: EvaluationMetrics] with OpHasLabelCol[RealNN] with OpHasPredictionCol[RealNN] +/** + * Eval metric + */ +trait EvalMetric extends Serializable { + /** + * Spark metric name + */ + def sparkEntryName: String + /** + * Human friendly metric name + */ + def humanFriendlyName: String +} -sealed abstract class ClassificationEvalMetric(val sparkEntryName: String) extends EnumEntry with Serializable +/** + * Classification Metrics + */ +sealed abstract class ClassificationEvalMetric +( + val sparkEntryName: String, + val humanFriendlyName: String +) extends EnumEntry with EvalMetric /** * Binary Classification Metrics */ object BinaryClassEvalMetrics extends Enum[ClassificationEvalMetric] { val values = findValues - case object AuROC extends ClassificationEvalMetric("areaUnderROC") - case object AuPR extends ClassificationEvalMetric("areaUnderPR") + case object AuROC extends ClassificationEvalMetric("areaUnderROC", "area under ROC") + case object AuPR extends ClassificationEvalMetric("areaUnderPR", "area under PR") } /** @@ -214,31 +234,27 @@ object BinaryClassEvalMetrics extends Enum[ClassificationEvalMetric] { */ object MultiClassEvalMetrics extends Enum[ClassificationEvalMetric] { val values = findValues - case object Precision extends ClassificationEvalMetric("weightedPrecision") - case object Recall extends ClassificationEvalMetric("weightedRecall") - case object F1 extends ClassificationEvalMetric("f1") - case object Error extends ClassificationEvalMetric("accuracy") + case object Precision extends ClassificationEvalMetric("weightedPrecision", "precision") + case object Recall extends ClassificationEvalMetric("weightedRecall", "recall") + case object F1 extends ClassificationEvalMetric("f1", "f1") + case object Error extends ClassificationEvalMetric("accuracy", "error") } /** - * Contains the names of metrics used in logging - */ -private[op] case object OpMetricsNames { - val rootMeanSquaredError = "root mean square error" - val meanSquaredError = "mean square error" - val meanAbsoluteError = "mean absolute error" - val r2 = "r2" - val auROC = "area under ROC" - val auPR = "area under PR" - val precision = "precision" - val recall = "recall" - val f1 = "f1" - val accuracy = "accuracy" - val error = "error" - val tp = "true positive" - val tn = "true negative" - val fp = "false positive" - val fn = "false negative" + * Regression Metrics + */ +sealed abstract class RegressionEvalMetric +( + val sparkEntryName: String, + val humanFriendlyName: String +) extends EnumEntry with EvalMetric + +object RegressionEvalMetrics extends Enum[RegressionEvalMetric] { + val values: Seq[RegressionEvalMetric] = findValues + case object RootMeanSquaredError extends RegressionEvalMetric("rmse", "root mean square error") + case object MeanSquaredError extends RegressionEvalMetric("mse", "mean square error") + case object R2 extends RegressionEvalMetric("r2", "r2") + case object MeanAbsoluteError extends RegressionEvalMetric("mae", "mean absolute error") } /** @@ -249,5 +265,3 @@ case object OpEvaluatorNames { val multi = "multiEval" val regression = "regEval" } - - diff --git a/core/src/main/scala/com/salesforce/op/evaluators/OpRegressionEvaluator.scala b/core/src/main/scala/com/salesforce/op/evaluators/OpRegressionEvaluator.scala index 5ed64f3a18..5a4e475e34 100644 --- a/core/src/main/scala/com/salesforce/op/evaluators/OpRegressionEvaluator.scala +++ b/core/src/main/scala/com/salesforce/op/evaluators/OpRegressionEvaluator.scala @@ -101,13 +101,3 @@ case class RegressionMetrics R2: Double, MeanAbsoluteError: Double ) extends EvaluationMetrics - -/* Regression Metrics */ -sealed abstract class RegressionEvalMetric(val sparkEntryName: String) extends EnumEntry with Serializable -object RegressionEvalMetrics extends Enum[RegressionEvalMetric] { - val values: Seq[RegressionEvalMetric] = findValues - case object RootMeanSquaredError extends RegressionEvalMetric("rmse") - case object MeanSquaredError extends RegressionEvalMetric("mse") - case object R2 extends RegressionEvalMetric("r2") - case object MeanAbsoluteError extends RegressionEvalMetric("mae") -} diff --git a/core/src/main/scala/com/salesforce/op/stages/impl/classification/SelectorClassifiers.scala b/core/src/main/scala/com/salesforce/op/stages/impl/classification/SelectorClassifiers.scala index 08d62be8f8..6485bdc95e 100644 --- a/core/src/main/scala/com/salesforce/op/stages/impl/classification/SelectorClassifiers.scala +++ b/core/src/main/scala/com/salesforce/op/stages/impl/classification/SelectorClassifiers.scala @@ -31,6 +31,7 @@ package com.salesforce.op.stages.impl.classification +import com.salesforce.op.stages.impl.ModelsToTry import com.salesforce.op.stages.impl.classification.ProbabilisticClassifierType.ProbClassifier import com.salesforce.op.stages.impl.selector._ import org.apache.spark.ml.classification._ @@ -43,7 +44,7 @@ import scala.reflect.ClassTag /** * Enumeration of possible classification models in Model Selector */ -sealed trait ClassificationModelsToTry extends EnumEntry with Serializable +sealed trait ClassificationModelsToTry extends ModelsToTry object ClassificationModelsToTry extends Enum[ClassificationModelsToTry] { val values = findValues diff --git a/core/src/main/scala/com/salesforce/op/stages/impl/package.scala b/core/src/main/scala/com/salesforce/op/stages/impl/package.scala new file mode 100644 index 0000000000..11db127f50 --- /dev/null +++ b/core/src/main/scala/com/salesforce/op/stages/impl/package.scala @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2017, Salesforce.com, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of Salesforce.com nor the names of its contributors may + * be used to endorse or promote products derived from this software without + * specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ + +package com.salesforce.op.stages + +import enumeratum.EnumEntry + +package object impl { + + /** + * Enumeration of possible models in Model Selectors + */ + trait ModelsToTry extends EnumEntry with Serializable + +} diff --git a/core/src/main/scala/com/salesforce/op/stages/impl/regression/SelectorRegressors.scala b/core/src/main/scala/com/salesforce/op/stages/impl/regression/SelectorRegressors.scala index 300aa121a9..84080e8dcc 100644 --- a/core/src/main/scala/com/salesforce/op/stages/impl/regression/SelectorRegressors.scala +++ b/core/src/main/scala/com/salesforce/op/stages/impl/regression/SelectorRegressors.scala @@ -31,6 +31,7 @@ package com.salesforce.op.stages.impl.regression +import com.salesforce.op.stages.impl.ModelsToTry import com.salesforce.op.stages.impl.regression.RegressorType._ import com.salesforce.op.stages.impl.selector._ import org.apache.spark.ml.param.{BooleanParam, Param, Params} @@ -44,7 +45,7 @@ import scala.reflect.ClassTag /** * Enumeration of possible regression models in Model Selector */ -sealed trait RegressionModelsToTry extends EnumEntry with Serializable +sealed trait RegressionModelsToTry extends ModelsToTry object RegressionModelsToTry extends Enum[RegressionModelsToTry] { val values = findValues diff --git a/core/src/test/scala/com/salesforce/op/ModelInsightsTest.scala b/core/src/test/scala/com/salesforce/op/ModelInsightsTest.scala index d7a36f6754..7c61dbbf4d 100644 --- a/core/src/test/scala/com/salesforce/op/ModelInsightsTest.scala +++ b/core/src/test/scala/com/salesforce/op/ModelInsightsTest.scala @@ -31,11 +31,11 @@ package com.salesforce.op -import com.salesforce.op.evaluators.BinaryClassificationMetrics +import com.salesforce.op.evaluators.{BinaryClassEvalMetrics, BinaryClassificationMetrics} import com.salesforce.op.features.Feature import com.salesforce.op.features.types.{PickList, Real, RealNN} import com.salesforce.op.stages.impl.classification.BinaryClassificationModelSelector -import com.salesforce.op.stages.impl.classification.ClassificationModelsToTry.LogisticRegression +import com.salesforce.op.stages.impl.classification.ClassificationModelsToTry.{LogisticRegression, NaiveBayes} import com.salesforce.op.stages.impl.preparators._ import com.salesforce.op.stages.impl.regression.RegressionModelSelector import com.salesforce.op.stages.impl.regression.RegressionModelsToTry.LinearRegression @@ -246,14 +246,20 @@ class ModelInsightsTest extends FlatSpec with PassengerSparkFixtureTest { insights.selectedModelType shouldBe LogisticRegression val bestModelValidationResults = insights.selectedModelValidationResults bestModelValidationResults.size shouldBe 15 - bestModelValidationResults.get("area under PR") shouldBe Some("0.0") + bestModelValidationResults.get(BinaryClassEvalMetrics.AuPR.humanFriendlyName) shouldBe Some("0.0") val validationResults = insights.validationResults validationResults.size shouldBe 2 validationResults.get(insights.selectedModelName) shouldBe Some(bestModelValidationResults) + insights.validationResults(LogisticRegression) shouldBe validationResults + insights.validationResults(NaiveBayes) shouldBe Map.empty } it should "return test/train evaluation metrics" in { val insights = workflowModel.modelInsights(prob) + insights.evaluationMetricType shouldBe BinaryClassEvalMetrics.AuPR + insights.validationType shouldBe ValidationType.CrossValidation + insights.validatedModelTypes shouldBe Set(LogisticRegression) + insights.problemType shouldBe ProblemType.BinaryClassification insights.selectedModelTrainEvalMetrics shouldBe BinaryClassificationMetrics(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0) diff --git a/core/src/test/scala/com/salesforce/op/OpWorkflowTest.scala b/core/src/test/scala/com/salesforce/op/OpWorkflowTest.scala index f889f4c35c..729807cd28 100644 --- a/core/src/test/scala/com/salesforce/op/OpWorkflowTest.scala +++ b/core/src/test/scala/com/salesforce/op/OpWorkflowTest.scala @@ -391,6 +391,11 @@ class OpWorkflowTest extends FlatSpec with PassengerSparkFixtureTest { summary.contains(""""regParam" : "0.01"""") shouldBe true summary.contains(ModelSelectorBaseNames.HoldOutEval) shouldBe true summary.contains(ModelSelectorBaseNames.TrainingEval) shouldBe true + + val prettySummary = fittedWorkflow.summaryPretty() + log.info(prettySummary) + prettySummary.contains("Selected model - LogisticRegression") shouldBe true + prettySummary.contains("Model Evaluation Metrics") shouldBe true } it should "be able to refit a workflow with calibrated probability" in { From ad6719ba600d9e793443b100f1361014511a3713 Mon Sep 17 00:00:00 2001 From: Matthew Tovbin Date: Fri, 22 Jun 2018 11:18:21 -0700 Subject: [PATCH 14/19] todo --- core/src/main/scala/com/salesforce/op/OpWorkflowModel.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/core/src/main/scala/com/salesforce/op/OpWorkflowModel.scala b/core/src/main/scala/com/salesforce/op/OpWorkflowModel.scala index 5bdfb16784..bd77f11c50 100644 --- a/core/src/main/scala/com/salesforce/op/OpWorkflowModel.scala +++ b/core/src/main/scala/com/salesforce/op/OpWorkflowModel.scala @@ -261,6 +261,8 @@ class OpWorkflowModel(val uid: String = UID[OpWorkflowModel], val trainingParams table.prettyString(columnAlignments = Map(holdOutCol -> Right, trainingCol -> Right)) } + // TODO: Sanity checker results (if any) + summary.mkString("\n") } From f38d95068109d97e48bcabe153b1077dbcb31eec Mon Sep 17 00:00:00 2001 From: Matthew Tovbin Date: Fri, 22 Jun 2018 11:26:24 -0700 Subject: [PATCH 15/19] cleanup --- core/src/main/scala/com/salesforce/op/OpWorkflowModel.scala | 2 +- core/src/test/scala/com/salesforce/op/OpWorkflowTest.scala | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/com/salesforce/op/OpWorkflowModel.scala b/core/src/main/scala/com/salesforce/op/OpWorkflowModel.scala index bd77f11c50..44aca2dc55 100644 --- a/core/src/main/scala/com/salesforce/op/OpWorkflowModel.scala +++ b/core/src/main/scala/com/salesforce/op/OpWorkflowModel.scala @@ -233,7 +233,7 @@ class OpWorkflowModel(val uid: String = UID[OpWorkflowModel], val trainingParams // Selected model information summary += { val bestModelType = insights.selectedModelType - val name = s"Selected model - $bestModelType" + val name = s"Selected Model - $bestModelType" val validationResults = insights.selectedModelValidationResults.toSeq ++ Seq( "name" -> insights.selectedModelName, "uid" -> insights.selectedModelUID, diff --git a/core/src/test/scala/com/salesforce/op/OpWorkflowTest.scala b/core/src/test/scala/com/salesforce/op/OpWorkflowTest.scala index 729807cd28..79a04eaf9a 100644 --- a/core/src/test/scala/com/salesforce/op/OpWorkflowTest.scala +++ b/core/src/test/scala/com/salesforce/op/OpWorkflowTest.scala @@ -394,7 +394,8 @@ class OpWorkflowTest extends FlatSpec with PassengerSparkFixtureTest { val prettySummary = fittedWorkflow.summaryPretty() log.info(prettySummary) - prettySummary.contains("Selected model - LogisticRegression") shouldBe true + prettySummary.contains(s"Selected Model - $LogisticRegression") shouldBe true + prettySummary.contains("| area under PR | 0.25") shouldBe true prettySummary.contains("Model Evaluation Metrics") shouldBe true } From 6795a7632b03293d2a91dd2580fe73289358e452 Mon Sep 17 00:00:00 2001 From: Matthew Tovbin Date: Fri, 22 Jun 2018 11:33:22 -0700 Subject: [PATCH 16/19] scalastyle --- .../main/scala/com/salesforce/op/evaluators/Evaluators.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/com/salesforce/op/evaluators/Evaluators.scala b/core/src/main/scala/com/salesforce/op/evaluators/Evaluators.scala index 2e6cbed237..9d9b839055 100644 --- a/core/src/main/scala/com/salesforce/op/evaluators/Evaluators.scala +++ b/core/src/main/scala/com/salesforce/op/evaluators/Evaluators.scala @@ -88,7 +88,8 @@ object Evaluators { * Recall */ def recall(): OpBinaryClassificationEvaluator = - new OpBinaryClassificationEvaluator(name = MultiClassEvalMetrics.Recall.humanFriendlyName, isLargerBetter = true) { + new OpBinaryClassificationEvaluator( + name = MultiClassEvalMetrics.Recall.humanFriendlyName, isLargerBetter = true) { override def evaluate(dataset: Dataset[_]): Double = { import dataset.sparkSession.implicits._ new MulticlassMetrics(dataset.select(getPredictionCol, getLabelCol).as[(Double, Double)].rdd).recall(1.0) From eb9f19d339028938530127e5a784b66918fd1d33 Mon Sep 17 00:00:00 2001 From: Matthew Tovbin Date: Mon, 25 Jun 2018 19:33:20 -0700 Subject: [PATCH 17/19] added sanity checker results --- .../com/salesforce/op/OpWorkflowModel.scala | 193 ++++++++++++------ 1 file changed, 133 insertions(+), 60 deletions(-) diff --git a/core/src/main/scala/com/salesforce/op/OpWorkflowModel.scala b/core/src/main/scala/com/salesforce/op/OpWorkflowModel.scala index 88d6d05966..31521f8630 100644 --- a/core/src/main/scala/com/salesforce/op/OpWorkflowModel.scala +++ b/core/src/main/scala/com/salesforce/op/OpWorkflowModel.scala @@ -35,20 +35,21 @@ import com.salesforce.op.evaluators.{EvaluationMetrics, OpEvaluatorBase} import com.salesforce.op.features.types.FeatureType import com.salesforce.op.features.{FeatureLike, OPFeature} import com.salesforce.op.readers.DataFrameFieldNames._ +import com.salesforce.op.stages.impl.feature.TransmogrifierDefaults import com.salesforce.op.stages.{OPStage, OpPipelineStage, OpTransformer} +import com.salesforce.op.utils.spark.OpVectorColumnMetadata import com.salesforce.op.utils.spark.RichDataset._ import com.salesforce.op.utils.spark.RichMetadata._ -import com.salesforce.op.utils.table.Alignment._ -import com.salesforce.op.utils.table.Table import com.salesforce.op.utils.stages.FitStagesUtil -import org.apache.spark.ml.Estimator +import com.salesforce.op.utils.table.Alignment._ +import com.salesforce.op.utils.table._ import org.apache.spark.sql.types.Metadata import org.apache.spark.sql.{DataFrame, SparkSession} import org.json4s.JValue import org.json4s.JsonAST.{JField, JObject} import org.json4s.jackson.JsonMethods.{pretty, render} -import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag import scala.util.Try @@ -193,17 +194,27 @@ class OpWorkflowModel(val uid: String = UID[OpWorkflowModel], val trainingParams def summary(): String = pretty(render(summaryJson())) /** - * Pulls all summary metadata of transformers and puts them into compact print friendly string + * High level model summary in a compact print friendly format containing: + * selected model info, model evaluation results and feature correlations/contributions/cramersV values. * - * @return a compact print friendly string + * @param topK top K of feature correlations/contributions/cramersV values + * @return high level model summary in a compact print friendly format */ - def summaryPretty(): String = { + def summaryPretty(topK: Int = 15): String = { val response = resultFeatures.find(_.isResponse).getOrElse(throw new Exception("No response feature is defined")) val insights = modelInsights(response) val summary = new ArrayBuffer[String]() + summary ++= validationResults(insights) + summary += selectedModelInfo(insights) + summary += modelEvaluationMetrics(insights) + summary ++= topKCorrelations(insights, topK) + summary += topKContributions(insights, topK) + summary ++= topKCramersV(insights, topK) + summary.mkString("\n") + } - // Validation results - summary += { + private def validationResults(insights: ModelInsights): Seq[String] = { + val evalSummary = { val validatedModelTypes = insights.validatedModelTypes val validationType = insights.validationType.humanFriendlyName val evalMetric = insights.evaluationMetricType.humanFriendlyName @@ -214,64 +225,126 @@ class OpWorkflowModel(val uid: String = UID[OpWorkflowModel], val trainingParams evalMetric ) } - summary += { - val modelEvalRes = for { - modelType <- insights.validatedModelTypes - modelValidationResults = insights.validationResults(modelType) - evalMetric = insights.evaluationMetricType.humanFriendlyName - } yield { - val evalMetricValues = modelValidationResults.flatMap { case (_, metrics) => - metrics.get(evalMetric).flatMap(v => Try(v.toDouble).toOption) - } - val minMetricValue = evalMetricValues.reduceOption[Double](math.min).getOrElse(Double.NaN) - val maxMetricValue = evalMetricValues.reduceOption[Double](math.max).getOrElse(Double.NaN) - - "Evaluated %d %s model%s with %s metric between [%s, %s].".format( - modelValidationResults.size, - modelType, - if (modelValidationResults.size > 1) "s" else "", - evalMetric, - minMetricValue, - maxMetricValue - ) + val modelEvalRes = for { + modelType <- insights.validatedModelTypes + modelValidationResults = insights.validationResults(modelType) + evalMetric = insights.evaluationMetricType.humanFriendlyName + } yield { + val evalMetricValues = modelValidationResults.flatMap { case (_, metrics) => + metrics.get(evalMetric).flatMap(v => Try(v.toDouble).toOption) } - modelEvalRes.mkString("\n") - } - - // Selected model information - summary += { - val bestModelType = insights.selectedModelType - val name = s"Selected Model - $bestModelType" - val validationResults = insights.selectedModelValidationResults.toSeq ++ Seq( - "name" -> insights.selectedModelName, - "uid" -> insights.selectedModelUID, - "modelType" -> insights.selectedModelType + val minMetricValue = evalMetricValues.reduceOption[Double](math.min).getOrElse(Double.NaN) + val maxMetricValue = evalMetricValues.reduceOption[Double](math.max).getOrElse(Double.NaN) + + "Evaluated %d %s model%s with %s metric between [%s, %s].".format( + modelValidationResults.size, + modelType, + if (modelValidationResults.size > 1) "s" else "", + evalMetric, + minMetricValue, + maxMetricValue ) - val table = Table(name = name, columns = Seq("Model Param", "Value"), rows = validationResults.sortBy(_._1)) - table.prettyString() } + Seq(evalSummary, modelEvalRes.mkString("\n")) + } - // Model evaluation metrics - summary += { - val name = "Model Evaluation Metrics" - val trainEvaluationMetrics = insights.selectedModelTrainEvalMetrics - val testEvaluationMetrics = insights.selectedModelTestEvalMetrics - val (metricNameCol, holdOutCol, trainingCol) = ("Metric Name", "Hold Out Set Value", "Training Set Value") - val trainMetrics = trainEvaluationMetrics.toMap.map { case (k, v) => k -> v.toString }.toSeq.sortBy(_._1) - val table = testEvaluationMetrics match { - case Some(testMetrics) => - val testMetricsMap = testMetrics.toMap - val rows = trainMetrics.map { case (k, v) => (k, v.toString, testMetricsMap(k).toString) } - Table(name = name, columns = Seq(metricNameCol, trainingCol, holdOutCol), rows = rows) - case None => - Table(name = name, columns = Seq(metricNameCol, trainingCol), rows = trainMetrics) - } - table.prettyString(columnAlignments = Map(holdOutCol -> Right, trainingCol -> Right)) + private def selectedModelInfo(insights: ModelInsights): String = { + val bestModelType = insights.selectedModelType + val name = s"Selected Model - $bestModelType" + val validationResults = insights.selectedModelValidationResults.toSeq ++ Seq( + "name" -> insights.selectedModelName, + "uid" -> insights.selectedModelUID, + "modelType" -> insights.selectedModelType + ) + val table = Table(name = name, columns = Seq("Model Param", "Value"), rows = validationResults.sortBy(_._1)) + table.prettyString() + } + + private def modelEvaluationMetrics(insights: ModelInsights): String = { + def stringOf: PartialFunction[Any, String] = { + case s: Traversable[_] => s.map(_.toString).mkString("[",",","]") + case v: Any => v.toString } + val name = "Model Evaluation Metrics" + val trainEvaluationMetrics = insights.selectedModelTrainEvalMetrics + val testEvaluationMetrics = insights.selectedModelTestEvalMetrics + val (metricNameCol, holdOutCol, trainingCol) = ("Metric Name", "Hold Out Set Value", "Training Set Value") + val trainMetrics = trainEvaluationMetrics.toMap.map { case (k, v) => k -> stringOf(v) }.toSeq.sortBy(_._1) + val table = testEvaluationMetrics match { + case Some(testMetrics) => + val testMetricsMap = testMetrics.toMap + val rows = trainMetrics.map { case (k, v) => (k, stringOf(v), stringOf(testMetricsMap(k))) } + Table(name = name, columns = Seq(metricNameCol, trainingCol, holdOutCol), rows = rows) + case None => + Table(name = name, columns = Seq(metricNameCol, trainingCol), rows = trainMetrics) + } + table.prettyString() + } - // TODO: Sanity checker results (if any) + private def topKInsights(s: Seq[(FeatureInsights, Insights, Double)], topK: Int): Seq[(String, Double)] = { + s.foldLeft(Seq.empty[(String, Double)]) { + case (acc, (feature, derived, corr)) => + val insightValue = derived.derivedFeatureGroup -> derived.derivedFeatureValue match { + case (Some(group), Some(OpVectorColumnMetadata.NullString)) => s"${feature.featureName}($group = null)" + case (Some(group), Some(TransmogrifierDefaults.OtherString)) => s"${feature.featureName}($group = other)" + case (Some(group), Some(value)) => s"${feature.featureName}($group = $value)" + case (Some(group), None) => s"${feature.featureName}(group = $group)" // should not happen + case (None, Some(value)) => s"${feature.featureName}(value = $value)" // should not happen + case (None, None) => feature.featureName + } + if (acc.exists(_._1 == insightValue)) acc else acc :+ (insightValue, corr) + } take topK + } - summary.mkString("\n") + private def topKCorrelations(insights: ModelInsights, topK: Int): Seq[String] = { + val maxCorrs = insights.features + .flatMap(f => f.derivedFeatures.map(d => (f, d, d.corr.getOrElse(Double.MinValue)))).sortBy(-_._3) + val minCorrs = insights.features + .flatMap(f => f.derivedFeatures.map(d => (f, d, d.corr.getOrElse(Double.MaxValue)))).sortBy(_._3) + val topPositiveInsights = topKInsights(maxCorrs, topK) + val topNegativeInsights = topKInsights(minCorrs, topK).filterNot(topPositiveInsights.contains) + + val correlationCol = "Correlation Value" + + lazy val topPositive = Table( + name = "Top Model Insights", + columns = Seq("Top Positive Correlations", correlationCol), + rows = topPositiveInsights + ).prettyString(columnAlignments = Map(correlationCol -> Right)) + + lazy val topNegative = Table( + columns = Seq("Top Negative Correlations", correlationCol), + rows = topNegativeInsights + ).prettyString(columnAlignments = Map(correlationCol -> Right)) + + if (topNegativeInsights.isEmpty) Seq(topPositive) else Seq(topPositive, topNegative) + } + + private def topKContributions(insights: ModelInsights, topK: Int): String = { + val maxContribFeatures = insights.features + .flatMap(f => f.derivedFeatures.map(d => + (f, d, d.contribution.reduceOption[Double](math.max).getOrElse(Double.MinValue)))) + .sortBy(v => -1 * math.abs(v._3)) + val rows = topKInsights(maxContribFeatures, topK) + val contributionCol = "Contribution Value" + + Table(columns = Seq("Top Contributions", contributionCol), rows = rows) + .prettyString(columnAlignments = Map(contributionCol -> Right)) + } + + private def topKCramersV(insights: ModelInsights, topK: Int): Option[String] = { + val allCramersV = for { + feature <- insights.features + derived <- feature.derivedFeatures + group <- derived.derivedFeatureGroup + cramersV <- derived.cramersV + } yield group -> cramersV + + val rows = allCramersV.sortBy(-_._2).take(topK) + val cramersVCol = "CramersV" + if (rows.isEmpty) None + else Some(Table(columns = Seq("Top CramersV", cramersVCol), rows = rows) + .prettyString(columnAlignments = Map(cramersVCol -> Right))) } /** From 7642841f65634e934d3acd6e5c208024c1408e35 Mon Sep 17 00:00:00 2001 From: Matthew Tovbin Date: Mon, 25 Jun 2018 20:55:25 -0700 Subject: [PATCH 18/19] cleanup --- .../com/salesforce/op/OpWorkflowModel.scala | 32 +++++++------------ .../com/salesforce/op/OpWorkflowTest.scala | 29 +++++++++-------- 2 files changed, 28 insertions(+), 33 deletions(-) diff --git a/core/src/main/scala/com/salesforce/op/OpWorkflowModel.scala b/core/src/main/scala/com/salesforce/op/OpWorkflowModel.scala index 31521f8630..08327712f7 100644 --- a/core/src/main/scala/com/salesforce/op/OpWorkflowModel.scala +++ b/core/src/main/scala/com/salesforce/op/OpWorkflowModel.scala @@ -208,7 +208,7 @@ class OpWorkflowModel(val uid: String = UID[OpWorkflowModel], val trainingParams summary += selectedModelInfo(insights) summary += modelEvaluationMetrics(insights) summary ++= topKCorrelations(insights, topK) - summary += topKContributions(insights, topK) + summary ++= topKContributions(insights, topK) summary ++= topKCramersV(insights, topK) summary.mkString("\n") } @@ -261,19 +261,15 @@ class OpWorkflowModel(val uid: String = UID[OpWorkflowModel], val trainingParams } private def modelEvaluationMetrics(insights: ModelInsights): String = { - def stringOf: PartialFunction[Any, String] = { - case s: Traversable[_] => s.map(_.toString).mkString("[",",","]") - case v: Any => v.toString - } val name = "Model Evaluation Metrics" - val trainEvaluationMetrics = insights.selectedModelTrainEvalMetrics - val testEvaluationMetrics = insights.selectedModelTestEvalMetrics + val trainEvalMetrics = insights.selectedModelTrainEvalMetrics + val testEvalMetrics = insights.selectedModelTestEvalMetrics val (metricNameCol, holdOutCol, trainingCol) = ("Metric Name", "Hold Out Set Value", "Training Set Value") - val trainMetrics = trainEvaluationMetrics.toMap.map { case (k, v) => k -> stringOf(v) }.toSeq.sortBy(_._1) - val table = testEvaluationMetrics match { + val trainMetrics = trainEvalMetrics.toMap.collect { case (k, v: Double) => k -> v.toString }.toSeq.sortBy(_._1) + val table = testEvalMetrics match { case Some(testMetrics) => val testMetricsMap = testMetrics.toMap - val rows = trainMetrics.map { case (k, v) => (k, stringOf(v), stringOf(testMetricsMap(k))) } + val rows = trainMetrics.map { case (k, v) => (k, v, testMetricsMap(k).toString) } Table(name = name, columns = Seq(metricNameCol, trainingCol, holdOutCol), rows = rows) case None => Table(name = name, columns = Seq(metricNameCol, trainingCol), rows = trainMetrics) @@ -320,16 +316,13 @@ class OpWorkflowModel(val uid: String = UID[OpWorkflowModel], val trainingParams if (topNegativeInsights.isEmpty) Seq(topPositive) else Seq(topPositive, topNegative) } - private def topKContributions(insights: ModelInsights, topK: Int): String = { + private def topKContributions(insights: ModelInsights, topK: Int): Option[String] = { val maxContribFeatures = insights.features .flatMap(f => f.derivedFeatures.map(d => (f, d, d.contribution.reduceOption[Double](math.max).getOrElse(Double.MinValue)))) .sortBy(v => -1 * math.abs(v._3)) val rows = topKInsights(maxContribFeatures, topK) - val contributionCol = "Contribution Value" - - Table(columns = Seq("Top Contributions", contributionCol), rows = rows) - .prettyString(columnAlignments = Map(contributionCol -> Right)) + numericalTable(columns = Seq("Top Contributions", "Contribution Value"), rows) } private def topKCramersV(insights: ModelInsights, topK: Int): Option[String] = { @@ -339,14 +332,13 @@ class OpWorkflowModel(val uid: String = UID[OpWorkflowModel], val trainingParams group <- derived.derivedFeatureGroup cramersV <- derived.cramersV } yield group -> cramersV - val rows = allCramersV.sortBy(-_._2).take(topK) - val cramersVCol = "CramersV" - if (rows.isEmpty) None - else Some(Table(columns = Seq("Top CramersV", cramersVCol), rows = rows) - .prettyString(columnAlignments = Map(cramersVCol -> Right))) + numericalTable(columns = Seq("Top CramersV", "CramersV"), rows) } + private def numericalTable(columns: Seq[String], rows: Seq[(String, Double)]): Option[String] = + if (rows.isEmpty) None else Some(Table(columns, rows).prettyString(columnAlignments = Map(columns.last -> Right))) + /** * Save this model to a path * diff --git a/core/src/test/scala/com/salesforce/op/OpWorkflowTest.scala b/core/src/test/scala/com/salesforce/op/OpWorkflowTest.scala index fe488f7fb3..c1a9c273e5 100644 --- a/core/src/test/scala/com/salesforce/op/OpWorkflowTest.scala +++ b/core/src/test/scala/com/salesforce/op/OpWorkflowTest.scala @@ -41,17 +41,16 @@ import com.salesforce.op.stages.base.unary._ import com.salesforce.op.stages.impl.classification.ClassificationModelsToTry._ import com.salesforce.op.stages.impl.classification._ import com.salesforce.op.stages.impl.preparators.SanityChecker -import com.salesforce.op.stages.impl.regression.{LossType, RegressionModelSelector, RegressionModelsToTry} -import com.salesforce.op.stages.impl.selector.{ModelSelectorBaseNames, SelectedModel} +import com.salesforce.op.stages.impl.selector.ModelSelectorBaseNames import com.salesforce.op.stages.impl.tuning._ -import com.salesforce.op.test.{Passenger, PassengerCSV, PassengerSparkFixtureTest, TestFeatureBuilder} +import com.salesforce.op.test.{Passenger, PassengerSparkFixtureTest, TestFeatureBuilder} import com.salesforce.op.utils.spark.RichDataset._ import com.salesforce.op.utils.spark.{OpVectorColumnMetadata, OpVectorMetadata} import org.apache.spark.ml.param.BooleanParam import org.apache.spark.rdd.RDD import org.apache.spark.sql.types.{DoubleType, StringType} import org.apache.spark.sql.{Dataset, SparkSession} -import org.joda.time.{DateTime, Duration} +import org.joda.time.DateTime import org.junit.runner.RunWith import org.scalatest.FlatSpec import org.scalatest.junit.JUnitRunner @@ -391,18 +390,22 @@ class OpWorkflowTest extends FlatSpec with PassengerSparkFixtureTest { val summary = fittedWorkflow.summary() log.info(summary) - summary.contains(classOf[SanityChecker].getSimpleName) shouldBe true - summary.contains("logreg") shouldBe true - summary.contains(""""regParam" : "0.1"""") shouldBe true - summary.contains(""""regParam" : "0.01"""") shouldBe true - summary.contains(ModelSelectorBaseNames.HoldOutEval) shouldBe true - summary.contains(ModelSelectorBaseNames.TrainingEval) shouldBe true + summary should include(classOf[SanityChecker].getSimpleName) + summary should include("logreg") + summary should include(""""regParam" : "0.1"""") + summary should include(""""regParam" : "0.01"""") + summary should include(ModelSelectorBaseNames.HoldOutEval) + summary should include(ModelSelectorBaseNames.TrainingEval) val prettySummary = fittedWorkflow.summaryPretty() log.info(prettySummary) - prettySummary.contains(s"Selected Model - $LogisticRegression") shouldBe true - prettySummary.contains("| area under PR | 0.25") shouldBe true - prettySummary.contains("Model Evaluation Metrics") shouldBe true + prettySummary should include(s"Selected Model - $LogisticRegression") + prettySummary should include("| area under PR | 0.25") + prettySummary should include("Model Evaluation Metrics") + prettySummary should include("Top Model Insights") + prettySummary should include("Top Model Insights") + prettySummary should include("Top Positive Correlations") + prettySummary should include("Top Contributions") } it should "be able to refit a workflow with calibrated probability" in { From 0177aea0c8394aa62ca51aead7e5656d4bb3caf7 Mon Sep 17 00:00:00 2001 From: Matthew Tovbin Date: Fri, 29 Jun 2018 11:22:17 -0700 Subject: [PATCH 19/19] move code to model insights --- .../com/salesforce/op/ModelInsights.scala | 160 +++++++++++++++++- .../com/salesforce/op/OpWorkflowModel.scala | 154 +---------------- .../com/salesforce/op/ModelInsightsTest.scala | 11 ++ .../com/salesforce/op/OpWorkflowTest.scala | 1 - 4 files changed, 179 insertions(+), 147 deletions(-) diff --git a/core/src/main/scala/com/salesforce/op/ModelInsights.scala b/core/src/main/scala/com/salesforce/op/ModelInsights.scala index d0ba39f0e2..1caac95574 100644 --- a/core/src/main/scala/com/salesforce/op/ModelInsights.scala +++ b/core/src/main/scala/com/salesforce/op/ModelInsights.scala @@ -37,6 +37,7 @@ import com.salesforce.op.features.types.{OPVector, RealNN} import com.salesforce.op.stages.impl.ModelsToTry import com.salesforce.op.stages.impl.classification.ClassificationModelsToTry import com.salesforce.op.stages.impl.classification.ClassificationModelsToTry.{DecisionTree, LogisticRegression, NaiveBayes, RandomForest} +import com.salesforce.op.stages.impl.feature.TransmogrifierDefaults import com.salesforce.op.stages.impl.preparators._ import com.salesforce.op.stages.impl.regression.RegressionModelsToTry import com.salesforce.op.stages.impl.regression.RegressionModelsToTry.{DecisionTreeRegression, GBTRegression, LinearRegression, RandomForestRegression} @@ -44,8 +45,9 @@ import com.salesforce.op.stages.impl.selector.ModelSelectorBaseNames._ import com.salesforce.op.stages.impl.selector.{ModelSelectorBase, SelectedModel} import com.salesforce.op.stages.{OPStage, OpPipelineStageParams, OpPipelineStageParamsNames} import com.salesforce.op.utils.json.JsonUtils -import com.salesforce.op.utils.spark.OpVectorMetadata +import com.salesforce.op.utils.spark.{OpVectorColumnMetadata, OpVectorMetadata} import com.salesforce.op.utils.spark.RichMetadata._ +import com.salesforce.op.utils.table.Table import enumeratum._ import org.apache.spark.ml.classification._ import org.apache.spark.ml.regression._ @@ -56,7 +58,9 @@ import org.json4s._ import org.json4s.jackson.JsonMethods._ import org.json4s.jackson.Serialization.{write, writePretty} import org.slf4j.LoggerFactory +import com.salesforce.op.utils.table.Alignment._ +import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag import scala.util.{Failure, Success, Try} @@ -190,6 +194,160 @@ case class ModelInsights if (pretty) writePretty(this) else write(this) } + /** + * High level model summary in a compact print friendly format containing: + * selected model info, model evaluation results and feature correlations/contributions/cramersV values. + * + * @param topK top K of feature correlations/contributions/cramersV values + * @return high level model summary in a compact print friendly format + */ + def prettyPrint(topK: Int = 15): String = { + val res = new ArrayBuffer[String]() + res ++= prettyValidationResults + res += prettySelectedModelInfo + res += modelEvaluationMetrics + res ++= topKCorrelations(topK) + res ++= topKContributions(topK) + res ++= topKCramersV(topK) + res.mkString("\n") + } + + private def prettyValidationResults: Seq[String] = { + val evalSummary = { + val vModelTypes = validatedModelTypes + "Evaluated %s model%s using %s and %s metric.".format( + vModelTypes.mkString(", "), + if (vModelTypes.size > 1) "s" else "", + validationType.humanFriendlyName, // TODO add number of folds or train/split ratio if possible + evaluationMetricType.humanFriendlyName + ) + } + val modelEvalRes = for { + modelType <- validatedModelTypes + modelValidationResults = validationResults(modelType) + evalMetric = evaluationMetricType.humanFriendlyName + } yield { + val evalMetricValues = modelValidationResults.flatMap { case (_, metrics) => + metrics.get(evalMetric).flatMap(v => Try(v.toDouble).toOption) + } + val minMetricValue = evalMetricValues.reduceOption[Double](math.min).getOrElse(Double.NaN) + val maxMetricValue = evalMetricValues.reduceOption[Double](math.max).getOrElse(Double.NaN) + + "Evaluated %d %s model%s with %s metric between [%s, %s].".format( + modelValidationResults.size, + modelType, + if (modelValidationResults.size > 1) "s" else "", + evalMetric, + minMetricValue, + maxMetricValue + ) + } + Seq(evalSummary, modelEvalRes.mkString("\n")) + } + + private def prettySelectedModelInfo: String = { + val bestModelType = selectedModelType + val name = s"Selected Model - $bestModelType" + val validationResults = selectedModelValidationResults.toSeq ++ Seq( + "name" -> selectedModelName, + "uid" -> selectedModelUID, + "modelType" -> selectedModelType + ) + val table = Table(name = name, columns = Seq("Model Param", "Value"), rows = validationResults.sortBy(_._1)) + table.prettyString() + } + + private def modelEvaluationMetrics: String = { + val name = "Model Evaluation Metrics" + val trainEvalMetrics = selectedModelTrainEvalMetrics + val testEvalMetrics = selectedModelTestEvalMetrics + val (metricNameCol, holdOutCol, trainingCol) = ("Metric Name", "Hold Out Set Value", "Training Set Value") + val trainMetrics = trainEvalMetrics.toMap.collect { case (k, v: Double) => k -> v.toString }.toSeq.sortBy(_._1) + val table = testEvalMetrics match { + case Some(testMetrics) => + val testMetricsMap = testMetrics.toMap + val rows = trainMetrics.map { case (k, v) => (k, v, testMetricsMap(k).toString) } + Table(name = name, columns = Seq(metricNameCol, trainingCol, holdOutCol), rows = rows) + case None => + Table(name = name, columns = Seq(metricNameCol, trainingCol), rows = trainMetrics) + } + table.prettyString() + } + + private def topKInsights(s: Seq[(FeatureInsights, Insights, Double)], topK: Int): Seq[(String, Double)] = { + s.foldLeft(Seq.empty[(String, Double)]) { + case (acc, (feature, derived, corr)) => + val insightValue = derived.derivedFeatureGroup -> derived.derivedFeatureValue match { + case (Some(group), Some(OpVectorColumnMetadata.NullString)) => s"${feature.featureName}($group = null)" + case (Some(group), Some(TransmogrifierDefaults.OtherString)) => s"${feature.featureName}($group = other)" + case (Some(group), Some(value)) => s"${feature.featureName}($group = $value)" + case (Some(group), None) => s"${feature.featureName}(group = $group)" // should not happen + case (None, Some(value)) => s"${feature.featureName}(value = $value)" // should not happen + case (None, None) => feature.featureName + } + if (acc.exists(_._1 == insightValue)) acc else acc :+ (insightValue, corr) + } take topK + } + + private def topKCorrelations(topK: Int): Seq[String] = { + val corrs = for { + (feature, derived) <- derivedNonExcludedFeatures + } yield (feature, derived, derived.corr.collect { case v if !v.isNaN => v }) + + val corrDsc = corrs.map { case (f, d, corr) => (f, d, corr.getOrElse(Double.MinValue)) }.sortBy(_._3).reverse + val corrAsc = corrs.map { case (f, d, corr) => (f, d, corr.getOrElse(Double.MaxValue)) }.sortBy(_._3) + val topPositiveCorrs = topKInsights(corrDsc, topK) + val topNegativeCorrs = topKInsights(corrAsc, topK).filterNot(topPositiveCorrs.contains) + + val correlationCol = "Correlation Value" + + lazy val topPositive = Table( + name = "Top Model Insights", + columns = Seq("Top Positive Correlations", correlationCol), + rows = topPositiveCorrs + ).prettyString(columnAlignments = Map(correlationCol -> Right)) + + lazy val topNegative = Table( + columns = Seq("Top Negative Correlations", correlationCol), + rows = topNegativeCorrs + ).prettyString(columnAlignments = Map(correlationCol -> Right)) + + if (topNegativeCorrs.isEmpty) Seq(topPositive) else Seq(topPositive, topNegative) + } + + private def topKContributions(topK: Int): Option[String] = { + val contribs = for { + (feature, derived) <- derivedNonExcludedFeatures + contrib = math.abs(derived.contribution.reduceOption[Double](math.max).getOrElse(0.0)) + } yield (feature, derived, contrib) + + val contribDesc = contribs.sortBy(_._3).reverse + val rows = topKInsights(contribDesc, topK) + numericalTable(columns = Seq("Top Contributions", "Contribution Value"), rows) + } + + private def topKCramersV(topK: Int): Option[String] = { + val cramersV = for { + (feature, derived) <- derivedNonExcludedFeatures + group <- derived.derivedFeatureGroup + cramersV <- derived.cramersV + } yield group -> cramersV + + val topCramersV = cramersV.distinct.sortBy(_._2).reverse.take(topK) + numericalTable(columns = Seq("Top CramersV", "CramersV"), rows = topCramersV) + } + + private def derivedNonExcludedFeatures: Seq[(FeatureInsights, Insights)] = { + for { + feature <- features + derived <- feature.derivedFeatures + if !derived.excluded.contains(true) + } yield feature -> derived + } + + private def numericalTable(columns: Seq[String], rows: Seq[(String, Double)]): Option[String] = + if (rows.isEmpty) None else Some(Table(columns, rows).prettyString(columnAlignments = Map(columns.last -> Right))) + private def modelType(modelName: String): Try[ModelsToTry] = Try { classificationModelType.orElse(regressionModelType).lift(modelName).getOrElse( throw new Exception(s"Unsupported model type for best model '$modelName'")) diff --git a/core/src/main/scala/com/salesforce/op/OpWorkflowModel.scala b/core/src/main/scala/com/salesforce/op/OpWorkflowModel.scala index 08327712f7..6c23e0acf3 100644 --- a/core/src/main/scala/com/salesforce/op/OpWorkflowModel.scala +++ b/core/src/main/scala/com/salesforce/op/OpWorkflowModel.scala @@ -35,23 +35,17 @@ import com.salesforce.op.evaluators.{EvaluationMetrics, OpEvaluatorBase} import com.salesforce.op.features.types.FeatureType import com.salesforce.op.features.{FeatureLike, OPFeature} import com.salesforce.op.readers.DataFrameFieldNames._ -import com.salesforce.op.stages.impl.feature.TransmogrifierDefaults import com.salesforce.op.stages.{OPStage, OpPipelineStage, OpTransformer} -import com.salesforce.op.utils.spark.OpVectorColumnMetadata import com.salesforce.op.utils.spark.RichDataset._ import com.salesforce.op.utils.spark.RichMetadata._ import com.salesforce.op.utils.stages.FitStagesUtil -import com.salesforce.op.utils.table.Alignment._ -import com.salesforce.op.utils.table._ import org.apache.spark.sql.types.Metadata import org.apache.spark.sql.{DataFrame, SparkSession} import org.json4s.JValue import org.json4s.JsonAST.{JField, JObject} import org.json4s.jackson.JsonMethods.{pretty, render} -import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag -import scala.util.Try /** @@ -197,147 +191,17 @@ class OpWorkflowModel(val uid: String = UID[OpWorkflowModel], val trainingParams * High level model summary in a compact print friendly format containing: * selected model info, model evaluation results and feature correlations/contributions/cramersV values. * - * @param topK top K of feature correlations/contributions/cramersV values + * @param insights model insights to compute the summary against + * @param topK top K of feature correlations/contributions/cramersV values to print * @return high level model summary in a compact print friendly format */ - def summaryPretty(topK: Int = 15): String = { - val response = resultFeatures.find(_.isResponse).getOrElse(throw new Exception("No response feature is defined")) - val insights = modelInsights(response) - val summary = new ArrayBuffer[String]() - summary ++= validationResults(insights) - summary += selectedModelInfo(insights) - summary += modelEvaluationMetrics(insights) - summary ++= topKCorrelations(insights, topK) - summary ++= topKContributions(insights, topK) - summary ++= topKCramersV(insights, topK) - summary.mkString("\n") - } - - private def validationResults(insights: ModelInsights): Seq[String] = { - val evalSummary = { - val validatedModelTypes = insights.validatedModelTypes - val validationType = insights.validationType.humanFriendlyName - val evalMetric = insights.evaluationMetricType.humanFriendlyName - "Evaluated %s model%s using %s and %s metric.".format( - validatedModelTypes.mkString(", "), - if (validatedModelTypes.size > 1) "s" else "", - validationType, // TODO add number of folds or train/split ratio if possible - evalMetric - ) - } - val modelEvalRes = for { - modelType <- insights.validatedModelTypes - modelValidationResults = insights.validationResults(modelType) - evalMetric = insights.evaluationMetricType.humanFriendlyName - } yield { - val evalMetricValues = modelValidationResults.flatMap { case (_, metrics) => - metrics.get(evalMetric).flatMap(v => Try(v.toDouble).toOption) - } - val minMetricValue = evalMetricValues.reduceOption[Double](math.min).getOrElse(Double.NaN) - val maxMetricValue = evalMetricValues.reduceOption[Double](math.max).getOrElse(Double.NaN) - - "Evaluated %d %s model%s with %s metric between [%s, %s].".format( - modelValidationResults.size, - modelType, - if (modelValidationResults.size > 1) "s" else "", - evalMetric, - minMetricValue, - maxMetricValue - ) - } - Seq(evalSummary, modelEvalRes.mkString("\n")) - } - - private def selectedModelInfo(insights: ModelInsights): String = { - val bestModelType = insights.selectedModelType - val name = s"Selected Model - $bestModelType" - val validationResults = insights.selectedModelValidationResults.toSeq ++ Seq( - "name" -> insights.selectedModelName, - "uid" -> insights.selectedModelUID, - "modelType" -> insights.selectedModelType - ) - val table = Table(name = name, columns = Seq("Model Param", "Value"), rows = validationResults.sortBy(_._1)) - table.prettyString() - } - - private def modelEvaluationMetrics(insights: ModelInsights): String = { - val name = "Model Evaluation Metrics" - val trainEvalMetrics = insights.selectedModelTrainEvalMetrics - val testEvalMetrics = insights.selectedModelTestEvalMetrics - val (metricNameCol, holdOutCol, trainingCol) = ("Metric Name", "Hold Out Set Value", "Training Set Value") - val trainMetrics = trainEvalMetrics.toMap.collect { case (k, v: Double) => k -> v.toString }.toSeq.sortBy(_._1) - val table = testEvalMetrics match { - case Some(testMetrics) => - val testMetricsMap = testMetrics.toMap - val rows = trainMetrics.map { case (k, v) => (k, v, testMetricsMap(k).toString) } - Table(name = name, columns = Seq(metricNameCol, trainingCol, holdOutCol), rows = rows) - case None => - Table(name = name, columns = Seq(metricNameCol, trainingCol), rows = trainMetrics) - } - table.prettyString() - } - - private def topKInsights(s: Seq[(FeatureInsights, Insights, Double)], topK: Int): Seq[(String, Double)] = { - s.foldLeft(Seq.empty[(String, Double)]) { - case (acc, (feature, derived, corr)) => - val insightValue = derived.derivedFeatureGroup -> derived.derivedFeatureValue match { - case (Some(group), Some(OpVectorColumnMetadata.NullString)) => s"${feature.featureName}($group = null)" - case (Some(group), Some(TransmogrifierDefaults.OtherString)) => s"${feature.featureName}($group = other)" - case (Some(group), Some(value)) => s"${feature.featureName}($group = $value)" - case (Some(group), None) => s"${feature.featureName}(group = $group)" // should not happen - case (None, Some(value)) => s"${feature.featureName}(value = $value)" // should not happen - case (None, None) => feature.featureName - } - if (acc.exists(_._1 == insightValue)) acc else acc :+ (insightValue, corr) - } take topK - } - - private def topKCorrelations(insights: ModelInsights, topK: Int): Seq[String] = { - val maxCorrs = insights.features - .flatMap(f => f.derivedFeatures.map(d => (f, d, d.corr.getOrElse(Double.MinValue)))).sortBy(-_._3) - val minCorrs = insights.features - .flatMap(f => f.derivedFeatures.map(d => (f, d, d.corr.getOrElse(Double.MaxValue)))).sortBy(_._3) - val topPositiveInsights = topKInsights(maxCorrs, topK) - val topNegativeInsights = topKInsights(minCorrs, topK).filterNot(topPositiveInsights.contains) - - val correlationCol = "Correlation Value" - - lazy val topPositive = Table( - name = "Top Model Insights", - columns = Seq("Top Positive Correlations", correlationCol), - rows = topPositiveInsights - ).prettyString(columnAlignments = Map(correlationCol -> Right)) - - lazy val topNegative = Table( - columns = Seq("Top Negative Correlations", correlationCol), - rows = topNegativeInsights - ).prettyString(columnAlignments = Map(correlationCol -> Right)) - - if (topNegativeInsights.isEmpty) Seq(topPositive) else Seq(topPositive, topNegative) - } - - private def topKContributions(insights: ModelInsights, topK: Int): Option[String] = { - val maxContribFeatures = insights.features - .flatMap(f => f.derivedFeatures.map(d => - (f, d, d.contribution.reduceOption[Double](math.max).getOrElse(Double.MinValue)))) - .sortBy(v => -1 * math.abs(v._3)) - val rows = topKInsights(maxContribFeatures, topK) - numericalTable(columns = Seq("Top Contributions", "Contribution Value"), rows) - } - - private def topKCramersV(insights: ModelInsights, topK: Int): Option[String] = { - val allCramersV = for { - feature <- insights.features - derived <- feature.derivedFeatures - group <- derived.derivedFeatureGroup - cramersV <- derived.cramersV - } yield group -> cramersV - val rows = allCramersV.sortBy(-_._2).take(topK) - numericalTable(columns = Seq("Top CramersV", "CramersV"), rows) - } - - private def numericalTable(columns: Seq[String], rows: Seq[(String, Double)]): Option[String] = - if (rows.isEmpty) None else Some(Table(columns, rows).prettyString(columnAlignments = Map(columns.last -> Right))) + def summaryPretty( + insights: ModelInsights = modelInsights( + resultFeatures.find(f => f.isResponse && !f.isRaw).getOrElse( + throw new IllegalArgumentException("No response feature is defined to compute model insights")) + ), + topK: Int = 15 + ): String = insights.prettyPrint(topK) /** * Save this model to a path diff --git a/core/src/test/scala/com/salesforce/op/ModelInsightsTest.scala b/core/src/test/scala/com/salesforce/op/ModelInsightsTest.scala index 8f0a47c49c..8069c87d0d 100644 --- a/core/src/test/scala/com/salesforce/op/ModelInsightsTest.scala +++ b/core/src/test/scala/com/salesforce/op/ModelInsightsTest.scala @@ -270,6 +270,17 @@ class ModelInsightsTest extends FlatSpec with PassengerSparkFixtureTest { ) } + it should "pretty print" in { + val insights = workflowModel.modelInsights(prob) + val pretty = insights.prettyPrint() + pretty should include(s"Selected Model - $LogisticRegression") + pretty should include("| area under PR | 0.0") + pretty should include("Model Evaluation Metrics") + pretty should include("Top Model Insights") + pretty should include("Top Positive Correlations") + pretty should include("Top Contributions") + } + it should "correctly serialize and deserialize from json" in { val insights = workflowModel.modelInsights(prob) ModelInsights.fromJson(insights.toJson()) match { diff --git a/core/src/test/scala/com/salesforce/op/OpWorkflowTest.scala b/core/src/test/scala/com/salesforce/op/OpWorkflowTest.scala index c1a9c273e5..40e7f74ca0 100644 --- a/core/src/test/scala/com/salesforce/op/OpWorkflowTest.scala +++ b/core/src/test/scala/com/salesforce/op/OpWorkflowTest.scala @@ -403,7 +403,6 @@ class OpWorkflowTest extends FlatSpec with PassengerSparkFixtureTest { prettySummary should include("| area under PR | 0.25") prettySummary should include("Model Evaluation Metrics") prettySummary should include("Top Model Insights") - prettySummary should include("Top Model Insights") prettySummary should include("Top Positive Correlations") prettySummary should include("Top Contributions") }