Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Legend override_aes #1115

Draft
wants to merge 11 commits into
base: master
Choose a base branch
from
323 changes: 323 additions & 0 deletions docs/f-24e/legend_override_aes.ipynb

Large diffs are not rendered by default.

7 changes: 5 additions & 2 deletions future_changes.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
## [4.3.4] - 2024-mm-dd

### Added
- Legend title in guide_legend() and guide_colorbar().
- Legend title in `guide_legend()` and `guide_colorbar()`.
See [example notebook](https://nbviewer.org/github/JetBrains/lets-plot/blob/master/docs/f-24e/legend_title.ipynb).

- Parameter `override_aes` in `guide_legend()`.
See [example notebook](https://nbviewer.org/github/JetBrains/lets-plot/blob/master/docs/f-24e/legend_override_aes.ipynb).

### Changed
- [**breaking change**] guide_legend()/guide_colorbar() require keyword arguments for 'nrow'/'barwidth' other parameters except 'title'.
- [**breaking change**] `guide_legend()`/`guide_colorbar()` require keyword arguments for `nrow`/`barwidth` other parameters except `title`.

### Fixed
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class LegendAssembler(
fun addLayer(
keyFactory: LegendKeyElementFactory,
aesList: List<Aes<*>>,
overrideAesValues: Map<Aes<*>, Any>,
constantByAes: Map<Aes<*>, Any>,
aestheticsDefaults: AestheticsDefaults,
colorByAes: Aes<Color>,
Expand All @@ -50,6 +51,7 @@ class LegendAssembler(
LegendLayer(
keyFactory,
aesList,
overrideAesValues,
constantByAes,
aestheticsDefaults,
scaleMappers,
Expand All @@ -76,29 +78,16 @@ class LegendAssembler(
}
}

val legendBreaks = ArrayList<LegendBreak>()
for (legendBreak in legendBreaksByLabel.values) {
if (legendBreak.isEmpty) {
continue
}
legendBreaks.add(legendBreak)
}


val legendBreaks = legendBreaksByLabel.values.filterNot { it.isEmpty }
if (legendBreaks.isEmpty()) {
return LegendBoxInfo.EMPTY
}

// legend options
val legendOptionsList = ArrayList<LegendOptions>()
for (legendLayer in legendLayers) {
val aesList = legendLayer.aesList
for (aes in aesList) {
if (guideOptionsMap[aes] is LegendOptions) {
legendOptionsList.add(guideOptionsMap[aes] as LegendOptions)
}
}
}
val legendOptionsList =
legendLayers
.map(LegendLayer::aesList)
.flatten()
.mapNotNull { guideOptionsMap[it] as? LegendOptions }

val spec =
createLegendSpec(
Expand All @@ -121,6 +110,7 @@ class LegendAssembler(
private class LegendLayer(
val keyElementFactory: LegendKeyElementFactory,
val aesList: List<Aes<*>>,
overrideAesValues: Map<Aes<*>, Any>,
constantByAes: Map<Aes<*>, Any>,
aestheticsDefaults: AestheticsDefaults,
scaleMappers: Map<Aes<*>, ScaleMapper<*>>,
Expand All @@ -129,13 +119,13 @@ class LegendAssembler(
val isMarginal: Boolean,
ctx: PlotContext
) {

val keyAesthetics: Aesthetics
val keyLabels: List<String>

init {
val aesValuesByLabel =
LinkedHashMap<String, MutableMap<Aes<*>, Any>>()
val labelsValuesByAes: MutableMap<Aes<*>, Pair<List<String>, List<Any?>>> = mutableMapOf()
var maxLabelsSize = 0

for (aes in aesList) {
var scale = ctx.getScale(aes)
if (!scale.hasBreaks()) {
Expand All @@ -147,21 +137,74 @@ class LegendAssembler(
val aesValues = scaleBreaks.transformedValues.map {
scaleMappers.getValue(aes)(it) as Any // Don't expect nulls.
}

val labels = scaleBreaks.labels
for ((label, aesValue) in labels.zip(aesValues)) {
aesValuesByLabel.getOrPut(label) { HashMap() }[aes] = aesValue
}
labelsValuesByAes[aes] = labels to aesValues
maxLabelsSize = maxOf(maxLabelsSize, labels.size)
}

val overrideAesValueLists = createOverrideAesValueLists(overrideAesValues, maxLabelsSize)

val labelValues = applyOverrideAes(overrideAesValueLists, labelsValuesByAes)
keyLabels = labelValues.first

// build 'key' aesthetics
keyAesthetics = mapToAesthetics(
aesValuesByLabel.values,
labelValues.second,
constantByAes,
aestheticsDefaults,
colorByAes,
fillByAes
)
keyLabels = ArrayList(aesValuesByLabel.keys)
}

private fun applyOverrideAes(
overrideAesValueLists: Map<Aes<*>, List<Any?>>,
labelsValuesByAes: MutableMap<Aes<*>, Pair<List<String>, List<Any?>>>
): Pair<List<String>, List<Map<Aes<*>, Any>>> {
val labelsLists = labelsValuesByAes.values.map{ it.first }

labelsLists.forEach { labels ->
overrideAesValueLists.forEach { (aesToOverride, valueList) ->
val currentValues = labelsValuesByAes.getOrPut(aesToOverride) { labels to valueList }.second
val updatedValues = currentValues
.zip(valueList)
.map { (oldValue, newValue) -> newValue ?: oldValue }

labelsValuesByAes[aesToOverride] = labels to updatedValues
}
}

val keyLabels = labelsLists.flatten().distinct()

val mapsByLabel = keyLabels.map { label ->
labelsValuesByAes.mapNotNull { (aes, pair) ->
pair.first.zip(pair.second)
.lastOrNull { it.first == label }
?.second
?.let { aes to it }
}.toMap()
}

return keyLabels to mapsByLabel
}

private fun createOverrideAesValueLists(
overrideAesValues: Map<Aes<*>, Any>,
maxLabelsSize: Int
): Map<Aes<*>, List<Any?>> {
val overrideAesValueLists = overrideAesValues.mapValues { (_, value) ->
val valueList = when (value) {
is List<*> -> value.ifEmpty { listOf(null) }
else -> listOf(value)
}
if (maxLabelsSize <= valueList.size) {
valueList
} else {
valueList + List(maxLabelsSize - valueList.size) { valueList.last() }
}
}
return overrideAesValueLists
}
}

Expand Down Expand Up @@ -254,3 +297,4 @@ class LegendAssembler(
}
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@

package org.jetbrains.letsPlot.core.plot.builder.assemble

import org.jetbrains.letsPlot.core.plot.base.Aes

class LegendOptions constructor(
val colCount: Int? = null,
val rowCount: Int? = null,
val byRow: Boolean = false,
title: String? = null,
val overrideAesValues: Map<Aes<*>, Any>? = null,
isReverse: Boolean = false
) : GuideOptions(title, isReverse) {
init {
Expand All @@ -27,13 +30,13 @@ class LegendOptions constructor(

override fun withReverse(reverse: Boolean): LegendOptions {
return LegendOptions(
colCount, rowCount, byRow, title, isReverse = reverse
colCount, rowCount, byRow, title, overrideAesValues, isReverse = reverse
)
}

override fun withTitle(title: String?): LegendOptions {
return LegendOptions(
colCount, rowCount, byRow, title = title, isReverse
colCount, rowCount, byRow, title = title, overrideAesValues, isReverse
)
}

Expand All @@ -47,6 +50,7 @@ class LegendOptions constructor(
if (rowCount != other.rowCount) return false
if (byRow != other.byRow) return false
if (title != other.title) return false
if (overrideAesValues != other.overrideAesValues) return false

return true
}
Expand All @@ -56,6 +60,7 @@ class LegendOptions constructor(
result = 31 * result + (rowCount ?: 0)
result = 31 * result + byRow.hashCode()
result = 31 * result + (title?.hashCode() ?: 0)
result = 31 * result + (overrideAesValues?.hashCode() ?: 0)
return result
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,19 @@ internal object PlotAssemblerUtil {
val aesListForScaleName = aesListByScaleName.getValue(scaleName)
val legendKeyFactory = layerInfo.legendKeyElementFactory
val aestheticsDefaults = layerInfo.aestheticsDefaults

val allOverrideAesValues =
guideOptionsMap
.filter { (aes, _) -> aes in aesListForScaleName }
.values
.filterIsInstance<LegendOptions>()
.map{it.overrideAesValues.orEmpty()}
.fold(mapOf<Aes<*>, Any>(), { acc, overrideAesValues -> acc + overrideAesValues })

legendAssembler.addLayer(
legendKeyFactory,
aesListForScaleName,
allOverrideAesValues,
layerConstantByAes,
aestheticsDefaults,
layerInfo.colorByAes,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -681,6 +681,7 @@ object Option {
const val ROW_COUNT = "nrow"
const val COL_COUNT = "ncol"
const val BY_ROW = "byrow"
const val OVERRIDE_AES = "override_aes"
}

object ColorBar {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@

package org.jetbrains.letsPlot.core.spec.config

import org.jetbrains.letsPlot.core.plot.base.Aes
import org.jetbrains.letsPlot.core.plot.builder.assemble.ColorBarOptions
import org.jetbrains.letsPlot.core.plot.builder.assemble.GuideOptions
import org.jetbrains.letsPlot.core.plot.builder.assemble.LegendOptions
import org.jetbrains.letsPlot.core.spec.Option
import org.jetbrains.letsPlot.core.spec.Option.Guide.COLOR_BAR
import org.jetbrains.letsPlot.core.spec.Option.Guide.COLOR_BAR_GB
import org.jetbrains.letsPlot.core.spec.Option.Guide.ColorBar.BIN_COUNT
Expand All @@ -16,44 +18,70 @@ import org.jetbrains.letsPlot.core.spec.Option.Guide.ColorBar.WIDTH
import org.jetbrains.letsPlot.core.spec.Option.Guide.LEGEND
import org.jetbrains.letsPlot.core.spec.Option.Guide.Legend.BY_ROW
import org.jetbrains.letsPlot.core.spec.Option.Guide.Legend.COL_COUNT
import org.jetbrains.letsPlot.core.spec.Option.Guide.Legend.OVERRIDE_AES
import org.jetbrains.letsPlot.core.spec.Option.Guide.Legend.ROW_COUNT
import org.jetbrains.letsPlot.core.spec.Option.Guide.NONE
import org.jetbrains.letsPlot.core.spec.Option.Guide.REVERSE
import org.jetbrains.letsPlot.core.spec.Option.Guide.TITLE
import org.jetbrains.letsPlot.core.spec.conversion.AesOptionConversion
import kotlin.math.max

abstract class GuideConfig private constructor(opts: Map<String, Any>) : OptionsAccessor(opts) {

fun createGuideOptions(): GuideOptions {
val options = createGuideOptionsIntern()
fun createGuideOptions(aopConversion: AesOptionConversion): GuideOptions {
val options = createGuideOptionsIntern(aopConversion)
return options
.withTitle(getString(TITLE))
.withReverse(getBoolean(REVERSE))
}

protected abstract fun createGuideOptionsIntern(): GuideOptions
protected abstract fun createGuideOptionsIntern(aopConversion: AesOptionConversion): GuideOptions

private class GuideNoneConfig internal constructor() : GuideConfig(emptyMap()) {

override fun createGuideOptionsIntern(): GuideOptions {
override fun createGuideOptionsIntern(aopConversion: AesOptionConversion): GuideOptions {
return GuideOptions.NONE
}
}

private class LegendConfig internal constructor(opts: Map<String, Any>) : GuideConfig(opts) {

override fun createGuideOptionsIntern(): GuideOptions {
override fun createGuideOptionsIntern(aopConversion: AesOptionConversion): GuideOptions {
return LegendOptions(
colCount = getDouble(COL_COUNT)?.toInt()?.let { max(1, it) },
rowCount = getDouble(ROW_COUNT)?.toInt()?.let { max(1, it) },
byRow = getBoolean(BY_ROW)
byRow = getBoolean(BY_ROW),
overrideAesValues = initValues(
OptionsAccessor(getMap(OVERRIDE_AES)),
aopConversion
)
)
}

private fun initValues(
layerOptions: OptionsAccessor,
aopConversion: AesOptionConversion
): Map<Aes<*>, Any> {
val result = HashMap<Aes<*>, Any>()
Option.Mapping.REAL_AES_OPTION_NAMES
.filter(layerOptions::has)
.associateWith(Option.Mapping::toAes)
.forEach { (option, aes) ->
val optionValue = layerOptions.getSafe(option)
val value = if (optionValue is List<*>) {
optionValue.map { aopConversion.apply(aes, it) }
} else {
aopConversion.apply(aes, optionValue)
} ?: throw IllegalArgumentException("Can't convert to '$option' value: $optionValue")
result[aes] = value
}
return result
}
}

private class ColorBarConfig(opts: Map<String, Any>) : GuideConfig(opts) {

override fun createGuideOptionsIntern(): GuideOptions {
override fun createGuideOptionsIntern(aopConversion: AesOptionConversion): GuideOptions {
return ColorBarOptions(
width = getDouble(WIDTH),
height = getDouble(HEIGHT),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ class PlotConfigFrontend private constructor(
internal val yAxisPosition: AxisPosition

init {
guideOptionsMap = createGuideOptionsMap(this.scaleConfigs) + createGuideOptionsMap(getMap(GUIDES))
guideOptionsMap = createGuideOptionsMap(this.scaleConfigs, aopConversion) +
createGuideOptionsMap(getMap(GUIDES), aopConversion)

xAxisPosition = scaleProviderByAes.getValue(Aes.X).axisPosition
yAxisPosition = scaleProviderByAes.getValue(Aes.Y).axisPosition
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,25 +24,26 @@ import org.jetbrains.letsPlot.core.spec.config.GuideConfig
import org.jetbrains.letsPlot.core.spec.config.OptionsAccessor.Companion.over
import org.jetbrains.letsPlot.core.spec.config.PlotConfigTransforms
import org.jetbrains.letsPlot.core.spec.config.ScaleConfig
import org.jetbrains.letsPlot.core.spec.conversion.AesOptionConversion
import org.jetbrains.letsPlot.core.spec.front.tiles.PlotTilesConfig

object PlotConfigFrontendUtil {
internal fun createGuideOptionsMap(scaleConfigs: List<ScaleConfig<*>>): Map<Aes<*>, GuideOptions> {
internal fun createGuideOptionsMap(scaleConfigs: List<ScaleConfig<*>>, aopConversion: AesOptionConversion): Map<Aes<*>, GuideOptions> {
val guideOptionsByAes = HashMap<Aes<*>, GuideOptions>()
for (scaleConfig in scaleConfigs) {
if (scaleConfig.hasGuideOptions()) {
val guideOptions = scaleConfig.getGuideOptions().createGuideOptions()
val guideOptions = scaleConfig.getGuideOptions().createGuideOptions(aopConversion)
guideOptionsByAes[scaleConfig.aes] = guideOptions
}
}
return guideOptionsByAes
}

internal fun createGuideOptionsMap(guideOptionsList: Map<String, Any>): Map<Aes<*>, GuideOptions> {
internal fun createGuideOptionsMap(guideOptionsList: Map<String, Any>, aopConversion: AesOptionConversion): Map<Aes<*>, GuideOptions> {
val guideOptionsByAes = HashMap<Aes<*>, GuideOptions>()
for ((key, value) in guideOptionsList) {
val aes = Option.Mapping.toAes(key)
guideOptionsByAes[aes] = GuideConfig.create(value).createGuideOptions()
guideOptionsByAes[aes] = GuideConfig.create(value).createGuideOptions(aopConversion)
}
return guideOptionsByAes
}
Expand Down
Loading