Skip to content

Commit

Permalink
Add JointGrid and FacetGrid reference lines (mwaskom#2620)
Browse files Browse the repository at this point in the history
* Add method to JointGrid for plotting reference lines.

* Tweak docstring

* Require one of joint/marginal to be True.

* Switch from orient to x/y.

* Add additional input validation.

* Allow both x and y; remove ValueErrors.

* Add color and linestyle params.

* Add JointGrid.refline() tests.

* Add JointGrid.refline() example.

* Update JointGrid.refline() docstring

Co-authored-by: Michael Waskom <[email protected]>

* Add FacetGrid.refline()

* Switch examples to use FacetGrid.refline()

* Use :meth: for method reference.

* Add FacetGrid.refline() test and example.

* Add refline to release notes.

* Address PR comments

Co-authored-by: Michael Waskom <[email protected]>
  • Loading branch information
stefmolin and mwaskom committed Jul 24, 2021
1 parent ad0240a commit 70cb3d9
Show file tree
Hide file tree
Showing 7 changed files with 185 additions and 3 deletions.
18 changes: 18 additions & 0 deletions doc/docstrings/FacetGrid.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,24 @@
"g.map(sns.histplot, \"total_bill\")"
]
},
{
"cell_type": "raw",
"metadata": {},
"source": [
"To add horizontal or vertical reference lines on every facet, use :meth:`FacetGrid.refline`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"g = sns.FacetGrid(tips, col=\"time\", margin_titles=True)\n",
"g.map_dataframe(sns.scatterplot, x=\"total_bill\", y=\"tip\")\n",
"g.refline(y=tips[\"tip\"].median())"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
18 changes: 18 additions & 0 deletions doc/docstrings/JointGrid.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,24 @@
"g.plot(sns.scatterplot, sns.histplot)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Horizontal and/or vertical reference lines can be added to the joint and/or marginal axes using :meth:`JointGrid.refline`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"g = sns.JointGrid(data=penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\")\n",
"g.plot(sns.scatterplot, sns.histplot)\n",
"g.refline(x=45, y=16)"
]
},
{
"cell_type": "raw",
"metadata": {},
Expand Down
2 changes: 2 additions & 0 deletions doc/releases/v0.12.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ A paper describing seaborn was published in the `Journal of Open Source Software

- |Enhancement| |Fix| Improved integration with the matplotlib color cycle in most axes-level functions (:pr:`2449`).

- |Feature| Added a ``refline`` method to :class:`FacetGrid` and :class:`JointGrid` for including horizontal and/or vertical reference lines using :meth:`matplotlib.axes.Axes.axhline`/:meth:`matplotlib.axes.Axes.axvline` (:pr:`2620`).

- |API| In :func:`lmplot`, the `sharex`, `sharey`, and `legend_out` parameters have been deprecated from the function signature, but they can be passed using the new `facet_kws` parameter (:pr:`2576`).

- |Fix| In :func:`lineplot, allowed the `dashes` keyword to set the style of a line without mapping a `style` variable (:pr:`2449`).
Expand Down
6 changes: 4 additions & 2 deletions examples/kde_ridgeplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
bw_adjust=.5, clip_on=False,
fill=True, alpha=1, linewidth=1.5)
g.map(sns.kdeplot, "x", clip_on=False, color="w", lw=2, bw_adjust=.5)
g.map(plt.axhline, y=0, lw=2, clip_on=False)

# passing color=None to refline() uses the hue mapping
g.refline(y=0, linewidth=2, linestyle="-", color=None, clip_on=False)


# Define and use a simple function to label the plot in axes coordinates
Expand All @@ -44,5 +46,5 @@ def label(x, color, label):

# Remove axes details that don't play well with overlap
g.set_titles("")
g.set(yticks=[])
g.set(yticks=[], ylabel="")
g.despine(bottom=True, left=True)
2 changes: 1 addition & 1 deletion examples/many_facets.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
col_wrap=4, height=1.5)

# Draw a horizontal line to show the starting point
grid.map(plt.axhline, y=0, ls=":", c=".5")
grid.refline(y=0, linestyle=":")

# Draw a line plot to show the trajectory of each random walk
grid.map(plt.plot, "step", "position", marker="o")
Expand Down
76 changes: 76 additions & 0 deletions seaborn/axisgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -957,6 +957,38 @@ def set_titles(self, template=None, row_template=None, col_template=None,
self.axes.flat[i].set_title(title, **kwargs)
return self

def refline(self, *, x=None, y=None, color='.5', linestyle='--', **line_kws):
"""Add a reference line(s) to each facet.
Parameters
----------
x, y : numeric
Value(s) to draw the line(s) at.
color : :mod:`matplotlib color <matplotlib.colors>`
Specifies the color of the reference line(s). Pass ``color=None`` to
use ``hue`` mapping.
linestyle : str
Specifies the style of the reference line(s).
line_kws : key, value mappings
Other keyword arguments are passed to :meth:`matplotlib.axes.Axes.axvline`
when ``x`` is not None and :meth:`matplotlib.axes.Axes.axhline` when ``y``
is not None.
Returns
-------
:class:`FacetGrid` instance
Returns ``self`` for easy method chaining.
"""
line_kws['color'] = color
line_kws['linestyle'] = linestyle

if x is not None:
self.map(plt.axvline, x=x, **line_kws)

if y is not None:
self.map(plt.axhline, y=y, **line_kws)

# ------ Properties that are part of the public API and documented by Sphinx

@property
Expand Down Expand Up @@ -1797,6 +1829,50 @@ def plot_marginals(self, func, **kwargs):

return self

def refline(
self, *, x=None, y=None, joint=True, marginal=True,
color='.5', linestyle='--', **line_kws
):
"""Add a reference line(s) to joint and/or marginal axes.
Parameters
----------
x, y : numeric
Value(s) to draw the line(s) at.
joint, marginal : bools
Whether to add the reference line(s) to the joint/marginal axes.
color : :mod:`matplotlib color <matplotlib.colors>`
Specifies the color of the reference line(s).
linestyle : str
Specifies the style of the reference line(s).
line_kws : key, value mappings
Other keyword arguments are passed to :meth:`matplotlib.axes.Axes.axvline`
when ``x`` is not None and :meth:`matplotlib.axes.Axes.axhline` when ``y``
is not None.
Returns
-------
:class:`JointGrid` instance
Returns ``self`` for easy method chaining.
"""
line_kws['color'] = color
line_kws['linestyle'] = linestyle

if x is not None:
if joint:
self.ax_joint.axvline(x, **line_kws)
if marginal:
self.ax_marg_x.axvline(x, **line_kws)

if y is not None:
if joint:
self.ax_joint.axhline(y, **line_kws)
if marginal:
self.ax_marg_y.axhline(y, **line_kws)

return self

def set_axis_labels(self, xlabel="", ylabel="", **kwargs):
"""Set axis labels on the bivariate axes.
Expand Down
66 changes: 66 additions & 0 deletions seaborn/tests/test_axisgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,30 @@ def test_categorical_warning(self):
with pytest.warns(UserWarning):
g.map(pointplot, "b", "x")

def test_refline(self):

g = ag.FacetGrid(self.df, row="a", col="b")
g.refline()
for ax in g.axes.ravel():
assert not ax.lines

refx = refy = 0.5
hline = np.array([[0, refy], [1, refy]])
vline = np.array([[refx, 0], [refx, 1]])
g.refline(x=refx, y=refy)
for ax in g.axes.ravel():
assert ax.lines[0].get_color() == '.5'
assert ax.lines[0].get_linestyle() == '--'
assert len(ax.lines) == 2
npt.assert_array_equal(ax.lines[0].get_xydata(), vline)
npt.assert_array_equal(ax.lines[1].get_xydata(), hline)

color, linestyle = 'red', '-'
g.refline(x=refx, color=color, linestyle=linestyle)
npt.assert_array_equal(g.axes[0, 0].lines[-1].get_xydata(), vline)
assert g.axes[0, 0].lines[-1].get_color() == color
assert g.axes[0, 0].lines[-1].get_linestyle() == linestyle


class TestPairGrid:

Expand Down Expand Up @@ -1542,6 +1566,48 @@ def test_hue(self, long_df, as_vector):
assert_plots_equal(g.ax_marg_x, g2.ax_marg_x, labels=False)
assert_plots_equal(g.ax_marg_y, g2.ax_marg_y, labels=False)

def test_refline(self):

g = ag.JointGrid(x="x", y="y", data=self.data)
g.plot(scatterplot, histplot)
g.refline()
assert not g.ax_joint.lines and not g.ax_marg_x.lines and not g.ax_marg_y.lines

refx = refy = 0.5
hline = np.array([[0, refy], [1, refy]])
vline = np.array([[refx, 0], [refx, 1]])
g.refline(x=refx, y=refy, joint=False, marginal=False)
assert not g.ax_joint.lines and not g.ax_marg_x.lines and not g.ax_marg_y.lines

g.refline(x=refx, y=refy)
assert g.ax_joint.lines[0].get_color() == '.5'
assert g.ax_joint.lines[0].get_linestyle() == '--'
assert len(g.ax_joint.lines) == 2
assert len(g.ax_marg_x.lines) == 1
assert len(g.ax_marg_y.lines) == 1
npt.assert_array_equal(g.ax_joint.lines[0].get_xydata(), vline)
npt.assert_array_equal(g.ax_joint.lines[1].get_xydata(), hline)
npt.assert_array_equal(g.ax_marg_x.lines[0].get_xydata(), vline)
npt.assert_array_equal(g.ax_marg_y.lines[0].get_xydata(), hline)

color, linestyle = 'red', '-'
g.refline(x=refx, marginal=False, color=color, linestyle=linestyle)
npt.assert_array_equal(g.ax_joint.lines[-1].get_xydata(), vline)
assert g.ax_joint.lines[-1].get_color() == color
assert g.ax_joint.lines[-1].get_linestyle() == linestyle
assert len(g.ax_marg_x.lines) == len(g.ax_marg_y.lines)

g.refline(x=refx, joint=False)
npt.assert_array_equal(g.ax_marg_x.lines[-1].get_xydata(), vline)
assert len(g.ax_marg_x.lines) == len(g.ax_marg_y.lines) + 1

g.refline(y=refy, joint=False)
npt.assert_array_equal(g.ax_marg_y.lines[-1].get_xydata(), hline)
assert len(g.ax_marg_x.lines) == len(g.ax_marg_y.lines)

g.refline(y=refy, marginal=False)
npt.assert_array_equal(g.ax_joint.lines[-1].get_xydata(), hline)
assert len(g.ax_marg_x.lines) == len(g.ax_marg_y.lines)

class TestJointPlot:

Expand Down

0 comments on commit 70cb3d9

Please sign in to comment.