Skip to content

Commit

Permalink
Jitter reproducibility in geom_jitter() (#920)
Browse files Browse the repository at this point in the history
* Add random seed into position_jitter()

* Adding note in future_changes.md

* Set null as default seed value

* Support seed in jitterdodge.

* Documented seed for geom_jitter() function

* Changing parameter name to align with the name in ggplot2.
  • Loading branch information
RYangazov committed Nov 9, 2023
1 parent 50d3936 commit 11e4da5
Show file tree
Hide file tree
Showing 10 changed files with 59 additions and 26 deletions.
4 changes: 2 additions & 2 deletions future_changes.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@
### Changed

### Fixed

- geom_livemap: fix missing styles (e.g. road outline on high zooms) [[#926](https://github.com/JetBrains/lets-plot/issues/926)].
- geom_livemap: fix missing styles (e.g. road outline on high zooms) [[#926](https://github.com/JetBrains/lets-plot/issues/926)].
- Jitter reproducibility in geom_jitter, position_jitter, position_jitterdodge [[#911](https://github.com/JetBrains/lets-plot/issues/911)].
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,20 @@ import org.jetbrains.letsPlot.core.plot.base.DataPointAesthetics
import org.jetbrains.letsPlot.core.plot.base.GeomContext
import org.jetbrains.letsPlot.core.plot.base.PositionAdjustment

class JitterDodgePos(aesthetics: Aesthetics, groupCount: Int, width: Double?, jitterWidth: Double?, jitterHeight: Double?) :
class JitterDodgePos(
aesthetics: Aesthetics,
groupCount: Int,
width: Double?,
jitterWidth: Double?,
jitterHeight: Double?,
seed: Long? = null
) :
PositionAdjustment {
private val myJitterPosHelper: PositionAdjustment
private val myDodgePosHelper: PositionAdjustment

init {
myJitterPosHelper = JitterPos(jitterWidth, jitterHeight)
myJitterPosHelper = JitterPos(jitterWidth, jitterHeight, seed)
myDodgePosHelper = DodgePos(aesthetics, groupCount, width)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,21 @@ import org.jetbrains.letsPlot.core.plot.base.GeomContext
import org.jetbrains.letsPlot.core.plot.base.PositionAdjustment
import kotlin.random.Random

internal class JitterPos(width: Double?, height: Double?) : PositionAdjustment {
internal class JitterPos(width: Double?, height: Double?, seed: Long? = null) : PositionAdjustment {

//uniform distribution
private val myWidth: Double
private val myHeight: Double
private val random: Random = seed?.let { Random(seed) } ?: Random.Default

init {
myWidth = width ?: DEF_JITTER_WIDTH
myHeight = height ?: DEF_JITTER_HEIGHT
}

override fun translate(v: DoubleVector, p: DataPointAesthetics, ctx: GeomContext): DoubleVector {
val x = (2 * Random.nextDouble() - 1) * myWidth * ctx.getResolution(Aes.X)
val y = (2 * Random.nextDouble() - 1) * myHeight * ctx.getResolution(Aes.Y)
val x = (2 * random.nextDouble() - 1) * myWidth * ctx.getResolution(Aes.X)
val y = (2 * random.nextDouble() - 1) * myHeight * ctx.getResolution(Aes.Y)
return v.add(DoubleVector(x, y))
}

Expand All @@ -34,7 +35,6 @@ internal class JitterPos(width: Double?, height: Double?) : PositionAdjustment {
}

companion object {

val DEF_JITTER_WIDTH = 0.4
val DEF_JITTER_HEIGHT = 0.4
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ object PositionAdjustments {
return FillPos(aesthetics, vjust, stackingMode)
}

fun jitter(width: Double?, height: Double?): PositionAdjustment {
return JitterPos(width, height)
fun jitter(width: Double?, height: Double?, seed: Long?): PositionAdjustment {
return JitterPos(width, height, seed)
}

fun nudge(width: Double?, height: Double?): PositionAdjustment {
Expand All @@ -64,9 +64,10 @@ object PositionAdjustments {
groupCount: Int,
width: Double?,
jitterWidth: Double?,
jitterHeight: Double?
jitterHeight: Double?,
seed: Long?
): PositionAdjustment {
return JitterDodgePos(aesthetics, groupCount, width, jitterWidth, jitterHeight)
return JitterDodgePos(aesthetics, groupCount, width, jitterWidth, jitterHeight, seed)
}

enum class Meta(private val handlesGroups: Boolean) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,10 @@ abstract class PosProvider {
}
}

fun jitter(width: Double?, height: Double?): PosProvider {
fun jitter(width: Double?, height: Double?, seed: Long?): PosProvider {
return object : PosProvider() {
override fun createPos(ctx: PosProviderContext): PositionAdjustment {
return PositionAdjustments.jitter(width, height)
return PositionAdjustments.jitter(width, height, seed)
}

override fun handlesGroups(): Boolean {
Expand All @@ -115,12 +115,19 @@ abstract class PosProvider {
}
*/

fun jitterDodge(width: Double?, jitterWidth: Double?, jitterHeight: Double?): PosProvider {
fun jitterDodge(width: Double?, jitterWidth: Double?, jitterHeight: Double?, seed: Long?): PosProvider {
return object : PosProvider() {
override fun createPos(ctx: PosProviderContext): PositionAdjustment {
val aesthetics = ctx.aesthetics
val groupCount = ctx.groupCount
return PositionAdjustments.jitterDodge(aesthetics, groupCount, width, jitterWidth, jitterHeight)
return PositionAdjustments.jitterDodge(
aesthetics,
groupCount,
width,
jitterWidth,
jitterHeight,
seed
)
}

override fun handlesGroups(): Boolean {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ class GeomProto(val geomKind: GeomKind) {
Meta.NAME to PosProto.JITTER,
Pos.Jitter.WIDTH to layerOptions.getDouble(Geom.Jitter.WIDTH),
Pos.Jitter.HEIGHT to layerOptions.getDouble(Geom.Jitter.HEIGHT),
Pos.Jitter.SEED to layerOptions.getLong(Geom.Jitter.SEED)
)

Y_DOT_PLOT -> if (layerOptions.hasOwn(Geom.YDotplot.STACKGROUPS) &&
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ object Option {
// Option-name constants (string keys) for jitter settings.
// NOTE(review): the enclosing object is outside this hunk — presumably Option.Geom.Jitter; confirm.
object Jitter {
const val WIDTH = "width"
const val HEIGHT = "height"
// Key of the optional random seed used to make the jitter reproducible.
const val SEED = "seed"
}

object Step {
Expand Down Expand Up @@ -467,6 +468,7 @@ object Option {
// Option-name constants (string keys) for the jitter position adjustment.
// NOTE(review): the enclosing object is outside this hunk — presumably Option.Pos.Jitter; confirm.
object Jitter {
const val WIDTH = "width"
const val HEIGHT = "height"
// Key of the optional random seed used to make the jitter reproducible.
const val SEED = "seed"
}

object Nudge {
Expand All @@ -478,6 +480,7 @@ object Option {
const val DODGE_WIDTH = "dodge_width"
const val JITTER_WIDTH = "jitter_width"
const val JITTER_HEIGHT = "jitter_height"
const val SEED = "seed"
}

object Stack {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ internal object PosProto {
FILL -> configureFillPosition(opts)
JITTER -> PosProvider.jitter(
opts.getDouble(Pos.Jitter.WIDTH),
opts.getDouble(Pos.Jitter.HEIGHT)
opts.getDouble(Pos.Jitter.HEIGHT),
opts.getLong(Pos.Jitter.SEED)
)

NUDGE -> PosProvider.nudge(
Expand All @@ -47,7 +48,8 @@ internal object PosProto {
JITTER_DODGE -> PosProvider.jitterDodge(
opts.getDouble(Pos.JitterDodge.DODGE_WIDTH),
opts.getDouble(Pos.JitterDodge.JITTER_WIDTH),
opts.getDouble(Pos.JitterDodge.JITTER_HEIGHT)
opts.getDouble(Pos.JitterDodge.JITTER_HEIGHT),
opts.getLong(Pos.JitterDodge.SEED)
)

else -> throw IllegalArgumentException("Unknown position adjustments name: '$posName'")
Expand Down
9 changes: 7 additions & 2 deletions python-package/lets_plot/plot/geom.py
Original file line number Diff line number Diff line change
Expand Up @@ -4595,9 +4595,9 @@ def geom_density2df(mapping=None, *, data=None, stat=None, position=None, show_l


def geom_jitter(mapping=None, *, data=None, stat=None, position=None, show_legend=None, sampling=None, tooltips=None,
width=None,
height=None,
width=None, height=None,
color_by=None, fill_by=None,
seed=None,
**other_args):
"""
Display jittered points, especially for discrete plots or dense plots.
Expand Down Expand Up @@ -4639,6 +4639,9 @@ def geom_jitter(mapping=None, *, data=None, stat=None, position=None, show_legen
Define the color aesthetic for the geometry.
fill_by : {'fill', 'color', 'paint_a', 'paint_b', 'paint_c'}, default='fill'
Define the fill aesthetic for the geometry.
seed : int
A random seed to make the jitter reproducible.
If None (the default value), the seed is initialised with a random value.
other_args
Other arguments passed on to the layer.
These are often aesthetics settings used to set an aesthetic to a fixed value,
Expand Down Expand Up @@ -4699,6 +4702,7 @@ def geom_jitter(mapping=None, *, data=None, stat=None, position=None, show_legen
ggplot({'x': x, 'y': y}, aes(x='x', y='y')) + \\
geom_jitter(aes(color='x', size='y'), \\
sampling=sampling_random(n=600, seed=60), \\
seed=37, \\
show_legend=False, width=.25) + \\
scale_color_grey(start=.75, end=0) + \\
scale_size(range=[1, 3])
Expand All @@ -4714,6 +4718,7 @@ def geom_jitter(mapping=None, *, data=None, stat=None, position=None, show_legen
tooltips=tooltips,
width=width, height=height,
color_by=color_by, fill_by=fill_by,
seed=seed,
**other_args)


Expand Down
19 changes: 13 additions & 6 deletions python-package/lets_plot/plot/pos.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def position_dodgev(height=None):
return _pos('dodgev', height=height)


def position_jitter(width=None, height=None):
def position_jitter(width=None, height=None, seed=None):
"""
Adjust position by assigning random noise to points. Better for discrete values.
Expand All @@ -110,6 +110,9 @@ def position_jitter(width=None, height=None):
Jittering height.
The value of height is relative and typically ranges between 0 and 0.5.
Values that are greater than 0.5 lead to overlapping of the points.
seed : int
A random seed to make the jitter reproducible.
If None (the default value), the seed is initialised with a random value.
Returns
-------
Expand Down Expand Up @@ -137,10 +140,10 @@ def position_jitter(width=None, height=None):
ggplot({'x': x, 'y': y, 'c': c}, aes('x', 'y')) + \\
geom_point(aes(fill='c'), show_legend=False, \\
size=8, alpha=.5, shape=21, color='black', \\
position=position_jitter(width=.2, height=.2))
position=position_jitter(width=.2, height=.2, seed=42))
"""
return _pos('jitter', width=width, height=height)
return _pos('jitter', width=width, height=height, seed=seed)


def position_nudge(x=None, y=None):
Expand Down Expand Up @@ -185,7 +188,7 @@ def position_nudge(x=None, y=None):
return _pos('nudge', x=x, y=y)


def position_jitterdodge(dodge_width=None, jitter_width=None, jitter_height=None):
def position_jitterdodge(dodge_width=None, jitter_width=None, jitter_height=None, seed=None):
"""
This is primarily used for aligning points generated through `geom_point()`
with dodged boxplots (e.g., a `geom_boxplot()` with a fill aesthetic supplied).
Expand All @@ -204,6 +207,9 @@ def position_jitterdodge(dodge_width=None, jitter_width=None, jitter_height=None
Jittering height.
The value of `jitter_height` is relative and typically ranges between 0 and 0.5.
Values that are greater than 0.5 lead to overlapping of the points.
seed : int
A random seed to make the jitter reproducible.
If None (the default value), the seed is initialised with a random value.
Returns
-------
Expand Down Expand Up @@ -232,10 +238,11 @@ def position_jitterdodge(dodge_width=None, jitter_width=None, jitter_height=None
stat='boxplot') + \\
geom_point(aes(x='c', y='x', color='c'), \\
size=4, shape=21, fill='white',
position=position_jitterdodge())
position=position_jitterdodge(seed=42))
"""
return _pos('jitterdodge', dodge_width=dodge_width, jitter_width=jitter_width, jitter_height=jitter_height)
return _pos('jitterdodge', dodge_width=dodge_width, jitter_width=jitter_width, jitter_height=jitter_height,
seed=seed)


def position_stack(vjust=None, mode=None):
Expand Down

0 comments on commit 11e4da5

Please sign in to comment.