Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Summary stat #803

Merged
merged 26 commits into from
Jun 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
7886cbc
First implementation of the summary stat.
ASmirnov-HORIS Jun 16, 2023
d95a062
Update API for the stat_summary() function.
ASmirnov-HORIS Jun 20, 2023
b28bf7f
Small refactor in SummaryStat.
ASmirnov-HORIS Jun 21, 2023
6ce963b
Remove extra enum class from the SummaryStatUtil.
ASmirnov-HORIS Jun 21, 2023
b7a8f33
Remove SummaryStatUtil.
ASmirnov-HORIS Jun 21, 2023
05345ee
Use references instead of lambdas for the SummaryCalculator.
ASmirnov-HORIS Jun 21, 2023
756b626
Replace SummaryCalculator by the SummaryStatUtil.
ASmirnov-HORIS Jun 21, 2023
3b2bca7
Refactor functions in SummaryStatUtil.
ASmirnov-HORIS Jun 22, 2023
1f03da4
Refactor summary stat options in StatProto.
ASmirnov-HORIS Jun 22, 2023
baccb49
Fix statData emptiness case in the SummaryStat.
ASmirnov-HORIS Jun 22, 2023
c9e154c
Further code refactoring.
ASmirnov-HORIS Jun 22, 2023
0082f74
Small fixes.
ASmirnov-HORIS Jun 22, 2023
cbabbe3
Add new stat variables and use them in the SummaryStat.
ASmirnov-HORIS Jun 23, 2023
7fea03b
Add prefix to min/max stats in stat_summary().
ASmirnov-HORIS Jun 23, 2023
0ecc6ec
Change API of the summary_stat() - add 'quantiles' parameter.
ASmirnov-HORIS Jun 26, 2023
722b43a
Tiny refactor in SummaryStat and AggregateFunctions.
ASmirnov-HORIS Jun 27, 2023
e30b93b
Use AggregateFunctions in the FiveNumberSummary.
ASmirnov-HORIS Jun 27, 2023
f076938
Add tests for AggregateFunctions.
ASmirnov-HORIS Jun 27, 2023
2313872
Replace parameter fun_map by usual aesthetics list for the stat_summa…
ASmirnov-HORIS Jun 29, 2023
4bba413
Small fixes in code for summary stat.
ASmirnov-HORIS Jun 29, 2023
9cfeea5
Merge branch 'master' into summary-stats
ASmirnov-HORIS Jun 29, 2023
4586451
Add getMapping() method to the Flipped stat context.
ASmirnov-HORIS Jun 29, 2023
9dcdd79
Add docstrings to the stat_summary() function.
ASmirnov-HORIS Jun 29, 2023
37526eb
Add demo notebook for stat_summary().
ASmirnov-HORIS Jun 29, 2023
0a1d636
Mention stat_summary() in the future_changes.
ASmirnov-HORIS Jun 29, 2023
a392ba1
Refactor StatContext.
ASmirnov-HORIS Jun 30, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
565 changes: 565 additions & 0 deletions docs/f-23c/stat_summary.ipynb

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions future_changes.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@

### Added

- New layer `stat_summary()`.

See: [example notebook](https://nbviewer.org/github/JetBrains/lets-plot/blob/master/docs/f-23c/stat_summary.ipynb).


- Tooltips for `geom_step()`.

See: [example notebook](https://nbviewer.org/github/JetBrains/lets-plot/blob/master/docs/f-23c/geom_step_tooltips.ipynb).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ interface StatContext {

fun overallYRange(): DoubleSpan?

fun mappedStatVariables(): List<DataFrame.Variable>

/** Returns a [Flipped] context that wraps this one. */
fun getFlipped(): StatContext = Flipped(this)
Expand All @@ -25,6 +27,10 @@ interface StatContext {
return orig.overallXRange()
}

/** Delegates to the wrapped context unchanged. */
override fun mappedStatVariables(): List<DataFrame.Variable> = orig.mappedStatVariables()

/** Flipping an already flipped context yields the wrapped original. */
override fun getFlipped(): StatContext = orig
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Copyright (c) 2023. JetBrains s.r.o.
* Use of this source code is governed by the MIT license that can be found in the LICENSE file.
*/

package jetbrains.datalore.plot.base.stat

import kotlin.math.ceil
import kotlin.math.floor
import kotlin.math.round

/**
 * Aggregate functions over lists of doubles.
 *
 * Functions whose parameter is named `sortedValues` expect the list to be already sorted
 * in ascending order; they do not sort it themselves.
 * Every function except [count] returns `NaN` for an empty list.
 */
object AggregateFunctions {
    /** Number of elements as a [Double] (`0.0` for an empty list). */
    fun count(values: List<Double>): Double = values.size.toDouble()

    /** Sum of the elements, or `NaN` when the list is empty. */
    fun sum(values: List<Double>): Double = if (values.isEmpty()) Double.NaN else values.sum()

    /** Arithmetic mean of the elements, or `NaN` when the list is empty. */
    fun mean(values: List<Double>): Double = if (values.isEmpty()) Double.NaN else sum(values) / count(values)

    /** The 0.5 quantile of [sortedValues]; see [quantile]. */
    fun median(sortedValues: List<Double>): Double = quantile(sortedValues, 0.5)

    /** First element of [sortedValues], or `NaN` when the list is empty. */
    fun min(sortedValues: List<Double>): Double = sortedValues.firstOrNull() ?: Double.NaN

    /** Last element of [sortedValues], or `NaN` when the list is empty. */
    fun max(sortedValues: List<Double>): Double = sortedValues.lastOrNull() ?: Double.NaN

    /**
     * Quantile of [sortedValues] for the probability [p].
     *
     * The position `p * (size - 1)` is used directly as an index when it is a whole number;
     * otherwise the result is the midpoint of the two neighbouring elements.
     *
     * @param sortedValues values sorted in ascending order.
     * @param p probability in the range `[0, 1]`.
     * @return the quantile, or `NaN` when [sortedValues] is empty.
     * @throws IllegalArgumentException if [p] is `NaN` or lies outside `[0, 1]`.
     */
    fun quantile(sortedValues: List<Double>, p: Double): Double {
        // Without this check p > 1 or p < 0 fails with an index error,
        // and a NaN probability silently yields the first element.
        require(p in 0.0..1.0) { "Quantile probability should be in range [0, 1] but was $p" }
        if (sortedValues.isEmpty()) {
            return Double.NaN
        }
        val place = p * (sortedValues.size - 1)
        return if (round(place) == place) {
            sortedValues[place.toInt()]
        } else {
            (sortedValues[ceil(place).toInt()] + sortedValues[floor(place).toInt()]) / 2.0
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@
package jetbrains.datalore.plot.base.stat

import jetbrains.datalore.base.gcommon.collect.Ordering
import kotlin.math.ceil
import kotlin.math.floor
import kotlin.math.round

/**
* For a set of data, the minimum, first quartile, median, third quartile, and maximum.
Expand All @@ -23,36 +20,13 @@ internal class FiveNumberSummary {
// 25 %
val thirdQuartile: Double // 75 %

private fun medianAtPointer(l: List<Double>, pointer: Double): Double {
val rint = round(pointer)
return if (pointer == rint) {
l[pointer.toInt()]
} else (l[ceil(pointer).toInt()] + l[floor(pointer).toInt()]) / 2.0
}

constructor(data: List<Double>) {
val sorted = Ordering.natural<Double>().sortedCopy(data)
if (sorted.isEmpty()) {
thirdQuartile = Double.NaN
firstQuartile = thirdQuartile
median = firstQuartile
max = median
min = max
} else if (sorted.size == 1) {
thirdQuartile = sorted.get(0)
firstQuartile = thirdQuartile
median = firstQuartile
max = median
min = max
} else {
val maxIndex = sorted.size - 1

min = sorted.get(0)
max = sorted.get(maxIndex)
median = medianAtPointer(sorted, maxIndex * .5)
firstQuartile = medianAtPointer(sorted, maxIndex * .25)
thirdQuartile = medianAtPointer(sorted, maxIndex * .75)
}
min = AggregateFunctions.min(sorted)
max = AggregateFunctions.max(sorted)
median = AggregateFunctions.median(sorted)
firstQuartile = AggregateFunctions.quantile(sorted, 0.25)
thirdQuartile = AggregateFunctions.quantile(sorted, 0.75)
}

constructor(min: Double, max: Double, median: Double, firstQuartile: Double, thirdQuartile: Double) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import jetbrains.datalore.plot.base.DataFrame
import jetbrains.datalore.plot.base.StatContext
import jetbrains.datalore.plot.base.data.TransformVar

class SimpleStatContext(private val myDataFrame: DataFrame) :
class SimpleStatContext(private val myDataFrame: DataFrame, private val mappedStatVariables: List<DataFrame.Variable>) :
StatContext {

override fun overallXRange(): DoubleSpan? {
Expand All @@ -20,4 +20,8 @@ class SimpleStatContext(private val myDataFrame: DataFrame) :
/** Range of the `TransformVar.Y` series in the backing data frame, as reported by `DataFrame.range`. */
override fun overallYRange(): DoubleSpan? = myDataFrame.range(TransformVar.Y)

/** Returns the list of stat variables that was passed to the constructor. */
override fun mappedStatVariables(): List<DataFrame.Variable> = mappedStatVariables
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,12 @@ object Stats {
val THEORETICAL = DataFrame.Variable("..theoretical..", STAT, "theoretical")
val SE = DataFrame.Variable("..se..", STAT, "standard error")
val LEVEL = DataFrame.Variable("..level..", STAT, "level")
val MEAN = DataFrame.Variable("..mean..", STAT, "mean")
val MEDIAN = DataFrame.Variable("..median..", STAT, "median")
val QUANTILE = DataFrame.Variable("..quantile..", STAT, "quantile")
val LOWER_QUANTILE = DataFrame.Variable("..lq..", STAT, "lower quantile")
val MIDDLE_QUANTILE = DataFrame.Variable("..mq..", STAT, "middle quantile")
val UPPER_QUANTILE = DataFrame.Variable("..uq..", STAT, "upper quantile")
val LOWER = DataFrame.Variable("..lower..", STAT, "lower")
val MIDDLE = DataFrame.Variable("..middle..", STAT, "middle")
val UPPER = DataFrame.Variable("..upper..", STAT, "upper")
Expand Down Expand Up @@ -53,7 +58,12 @@ object Stats {
THEORETICAL,
SE,
LEVEL,
MEAN,
MEDIAN,
QUANTILE,
LOWER_QUANTILE,
MIDDLE_QUANTILE,
UPPER_QUANTILE,
LOWER,
MIDDLE,
UPPER,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
/*
* Copyright (c) 2023. JetBrains s.r.o.
* Use of this source code is governed by the MIT license that can be found in the LICENSE file.
*/

package jetbrains.datalore.plot.base.stat

import jetbrains.datalore.base.gcommon.collect.Ordering
import jetbrains.datalore.plot.base.Aes
import jetbrains.datalore.plot.base.DataFrame
import jetbrains.datalore.plot.base.StatContext
import jetbrains.datalore.plot.base.data.TransformVar
import jetbrains.datalore.plot.common.data.SeriesUtil

/**
 * Stat that groups the finite (x, y) pairs by x and reduces the y values of every group
 * with the supplied aggregate functions.
 *
 * Each group is sorted in ascending order before any aggregate function is applied to it.
 *
 * @param yAggFunction produces the `Stats.Y` value of a group.
 * @param yMinAggFunction produces the `Stats.Y_MIN` value of a group.
 * @param yMaxAggFunction produces the `Stats.Y_MAX` value of a group.
 * @param sortedQuantiles probabilities for the `..lq..`, `..mq..` and `..uq..` stat variables,
 * read at indices 0, 1 and 2 respectively.
 */
class SummaryStat(
    private val yAggFunction: (List<Double>) -> Double,
    private val yMinAggFunction: (List<Double>) -> Double,
    private val yMaxAggFunction: (List<Double>) -> Double,
    private val sortedQuantiles: List<Double>
) : BaseStat(DEF_MAPPING) {

    override fun consumes(): List<Aes<*>> = listOf(Aes.X, Aes.Y)

    /**
     * Computes the summary data frame for [data].
     *
     * Returns the empty stat values when y is missing or when no finite (x, y) pair is left.
     */
    override fun apply(data: DataFrame, statCtx: StatContext, messageConsumer: (s: String) -> Unit): DataFrame {
        if (!hasRequiredValues(data, Aes.Y)) {
            return withEmptyStatValues()
        }

        val yValues = data.getNumeric(TransformVar.Y)
        // Without an x series every observation is placed at x = 0.0, i.e. into one group.
        val xValues = when {
            data.has(TransformVar.X) -> data.getNumeric(TransformVar.X)
            else -> List(yValues.size) { 0.0 }
        }

        val summary = buildStat(xValues, yValues, statCtx)
        if (summary.isEmpty()) {
            return withEmptyStatValues()
        }

        val result = DataFrame.Builder()
        for (entry in summary) {
            result.putNumeric(entry.key, entry.value)
        }
        return result.build()
    }

    /**
     * Builds one series per output variable: x, y, ymin, ymax and every stat variable
     * reported by [statCtx]. Returns an empty map when there is no finite data.
     */
    private fun buildStat(
        xs: List<Double?>,
        ys: List<Double?>,
        statCtx: StatContext
    ): Map<DataFrame.Variable, List<Double>> {
        val (finiteXs, finiteYs) = SeriesUtil.filterFinite(xs, ys)
        val ysByX = finiteXs.zip(finiteYs).groupBy({ it.first }, { it.second })
        if (ysByX.isEmpty()) {
            return emptyMap()
        }

        val xSeries = mutableListOf<Double>()
        val ySeries = mutableListOf<Double>()
        val yMinSeries = mutableListOf<Double>()
        val yMaxSeries = mutableListOf<Double>()
        // One output series for each stat variable reported by the context.
        val extraSeries: Map<DataFrame.Variable, MutableList<Double>> = statCtx.mappedStatVariables()
            .associateWith { mutableListOf<Double>() }

        for ((groupX, groupYs) in ysByX) {
            val sortedYs = Ordering.natural<Double>().sortedCopy(groupYs)
            xSeries.add(groupX)
            ySeries.add(yAggFunction(sortedYs))
            yMinSeries.add(yMinAggFunction(sortedYs))
            yMaxSeries.add(yMaxAggFunction(sortedYs))
            for ((statVar, series) in extraSeries) {
                series.add(aggFunctionByStat(statVar)(sortedYs))
            }
        }

        // Entries of extraSeries are added last, so they win over the default ones on a key clash.
        return mapOf(
            Stats.X to xSeries,
            Stats.Y to ySeries,
            Stats.Y_MIN to yMinSeries,
            Stats.Y_MAX to yMaxSeries,
        ) + extraSeries
    }

    /**
     * Resolves the aggregate function for [statVar].
     *
     * @throws IllegalStateException if [statVar] is not one of the supported stat variables.
     */
    private fun aggFunctionByStat(statVar: DataFrame.Variable): (List<Double>) -> Double {
        return when (statVar) {
            Stats.COUNT -> AggregateFunctions::count
            Stats.SUM -> AggregateFunctions::sum
            Stats.MEAN -> AggregateFunctions::mean
            Stats.MEDIAN -> AggregateFunctions::median
            Stats.Y_MIN -> AggregateFunctions::min
            Stats.Y_MAX -> AggregateFunctions::max
            Stats.LOWER_QUANTILE -> quantileFunction(0)
            Stats.MIDDLE_QUANTILE -> quantileFunction(1)
            Stats.UPPER_QUANTILE -> quantileFunction(2)
            else -> throw IllegalStateException(
                "Unsupported stat variable: '${statVar.name}'\n" +
                        "Use one of: ..count.., ..sum.., ..mean.., ..median.., ..ymin.., ..ymax.., ..lq.., ..mq.., ..uq.."
            )
        }
    }

    /** Quantile function for the probability at [index] of [sortedQuantiles]; the probability is read on each call. */
    private fun quantileFunction(index: Int): (List<Double>) -> Double =
        { values -> AggregateFunctions.quantile(values, sortedQuantiles[index]) }

    companion object {
        /** Default lower / middle / upper quantile probabilities. */
        val DEF_QUANTILES = listOf(0.25, 0.5, 0.75)

        private val DEF_MAPPING: Map<Aes<*>, DataFrame.Variable> = mapOf(
            Aes.X to Stats.X,
            Aes.Y to Stats.Y,
            Aes.YMIN to Stats.Y_MIN,
            Aes.YMAX to Stats.Y_MAX
        )
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* Copyright (c) 2023. JetBrains s.r.o.
* Use of this source code is governed by the MIT license that can be found in the LICENSE file.
*/

package jetbrains.datalore.plot.base.stat

import kotlin.test.Test
import kotlin.test.assertEquals

/** Unit tests for [AggregateFunctions]: empty input, single element, and a small sorted sample. */
class AggregateFunctionsTest {
    // Sample shared by most tests; it is sorted in ascending order.
    private val sample = listOf(-1.0, -1.0, 1.0, 3.0)

    // Every aggregate except count, in the order the assertions are run.
    private val valueAggregates: List<(List<Double>) -> Double> = listOf(
        AggregateFunctions::sum,
        AggregateFunctions::mean,
        AggregateFunctions::median,
        AggregateFunctions::min,
        AggregateFunctions::max,
        { values -> AggregateFunctions.quantile(values, 0.25) },
    )

    @Test
    fun emptyData() {
        val values: List<Double> = emptyList()
        assertEquals(0.0, AggregateFunctions.count(values))
        for (aggregate in valueAggregates) {
            assertEquals(Double.NaN, aggregate(values))
        }
    }

    @Test
    fun oneElementData() {
        val value = 1.0
        val values = listOf(value)
        assertEquals(1.0, AggregateFunctions.count(values))
        for (aggregate in valueAggregates) {
            assertEquals(value, aggregate(values))
        }
    }

    @Test
    fun checkCountFunction() {
        assertEquals(4.0, AggregateFunctions.count(sample))
    }

    @Test
    fun checkSumFunction() {
        assertEquals(2.0, AggregateFunctions.sum(sample))
    }

    @Test
    fun checkMeanFunction() {
        assertEquals(0.5, AggregateFunctions.mean(sample))
        assertEquals(2.0, AggregateFunctions.mean(listOf(-2.0, 3.0, 5.0)))
    }

    @Test
    fun checkMedianFunction() {
        assertEquals(0.0, AggregateFunctions.median(sample))
        assertEquals(3.0, AggregateFunctions.median(listOf(-2.0, 3.0, 5.0)))
    }

    @Test
    fun checkMinFunction() {
        assertEquals(-1.0, AggregateFunctions.min(sample))
    }

    @Test
    fun checkMaxFunction() {
        assertEquals(3.0, AggregateFunctions.max(sample))
    }

    @Test
    fun checkQuantileFunction() {
        // Probability -> expected quantile of the shared sample.
        val expectedByProbability = listOf(
            0.0 to -1.0,
            0.25 to -1.0,
            1.0 / 3.0 to -1.0,
            0.5 to 0.0,
            2.0 / 3.0 to 1.0,
            0.75 to 2.0,
            1.0 to 3.0,
        )
        for ((p, expected) in expectedByProbability) {
            assertEquals(expected, AggregateFunctions.quantile(sample, p))
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import kotlin.test.assertTrue

open class BaseStatTest {
protected fun statContext(d: DataFrame): StatContext {
return SimpleStatContext(d)
return SimpleStatContext(d, emptyList())
}

protected fun dataFrame(dataMap: Map<DataFrame.Variable, List<Double?>>): DataFrame {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import kotlin.test.assertTrue
class BoxplotStatTest {

private fun statContext(d: DataFrame): StatContext {
return SimpleStatContext(d)
return SimpleStatContext(d, emptyList())
}

private fun df(m: Map<DataFrame.Variable, List<Double>>): DataFrame {
Expand Down
Loading