Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mean clim patch #920

Merged
merged 19 commits into from
Apr 16, 2023
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def parallel_coordinate_plot(
metric_names,
model_names,
models_to_highlight=list(),
models_to_highlight_colors=None,
fig=None,
ax=None,
figsize=(15, 5),
Expand All @@ -23,7 +24,6 @@ def parallel_coordinate_plot(
violin_colors=("lightgrey", "pink"),
title=None,
identify_all_models=True,
xtick_labels=None,
xtick_labelsize=None,
ytick_labelsize=None,
colormap="viridis",
Expand All @@ -37,6 +37,8 @@ def parallel_coordinate_plot(
comparing_models=None,
fill_between_lines=False,
fill_between_lines_colors=("green", "red"),
median_centered=False,
median_line=False,
):
"""
Parameters
Expand All @@ -45,6 +47,7 @@ def parallel_coordinate_plot(
- `metric_names`: list, names of metrics for individual vertical axes (axis=1)
- `model_names`: list, name of models for markers/lines (axis=0)
- `models_to_highlight`: list, default=None, List of models to highlight as lines
- `models_to_highlight_colors`: list, default=None, List of colors for models to highlight as lines
- `fig`: `matplotlib.figure` instance to which the parallel coordinate plot is plotted.
If not provided, use current axes or create a new one. Optional.
- `ax`: `matplotlib.axes.Axes` instance to which the parallel coordinate plot is plotted.
Expand All @@ -69,14 +72,18 @@ def parallel_coordinate_plot(
- `comparing_models`: tuple or list containing two strings for models to compare with colors filled between the two lines.
- `fill_between_lines`: bool, default=False, fill color between lines for models in comparing_models
- `fill_between_lines_colors`: tuple or list containing two strings for colors filled between the two lines. Default=('green', 'red')
- `median_centered`: bool, default=False, adjust range of vertical axis to set center of vertical axis as median
- `median_line`: bool, default=False, show median as line

Return
------
- `fig`: matplotlib component for figure
- `ax`: matplotlib component for axis

Author: Jiwoo Lee @ LLNL (2021. 7)
Last update: 2022. 9
Update history:
2022-09 violin plots added
2023-03 median centered option added
Inspired by https://stackoverflow.com/questions/8230638/parallel-coordinates-plot-in-matplotlib
"""
params = {
Expand All @@ -92,13 +99,14 @@ def parallel_coordinate_plot(
_quick_qc(data, model_names, metric_names, model_names2=model_names2)

# Transform data for plotting
zs, N, ymins, ymaxs, df_stacked, df2_stacked = _data_transform(
zs, zs_meds, N, ymins, ymaxs, df_stacked, df2_stacked = _data_transform(
data,
metric_names,
model_names,
model_names2=model_names2,
group1_name=group1_name,
group2_name=group2_name,
median_centered=median_centered,
)

# Prepare plot
Expand All @@ -123,8 +131,6 @@ def parallel_coordinate_plot(

if fig is None and ax is None:
fig, ax = plt.subplots(figsize=figsize)
else:
ax = ax

axes = [ax] + [ax.twinx() for i in range(N - 1)]

Expand Down Expand Up @@ -168,7 +174,11 @@ def parallel_coordinate_plot(
showextrema=False,
)
for pc in violin["bodies"]:
pc.set_facecolor(violin_colors[0])
if isinstance(violin_colors, tuple) or isinstance(violin_colors, list):
violin_color = violin_colors[0]
else:
violin_color = violin_colors
pc.set_facecolor(violin_color)
pc.set_edgecolor("None")
pc.set_alpha(0.8)
else:
Expand All @@ -194,10 +204,17 @@ def parallel_coordinate_plot(
marker_types = ["o", "s", "*", "^", "X", "D", "p"]
markers = list(flatten([[marker] * len(colors) for marker in marker_types]))
colors *= len(marker_types)
mh_index = 0
for j, model in enumerate(model_names):
# to just draw straight lines between the axes:
if model in models_to_highlight:
ax.plot(range(N), zs[j, :], "-", c=colors[j], label=model, lw=3)

if models_to_highlight_colors is not None:
color = models_to_highlight_colors[mh_index]
else:
color = colors[j]
ax.plot(range(N), zs[j, :], "-", c=color, label=model, lw=3)
mh_index += 1
else:
if identify_all_models:
ax.plot(
Expand All @@ -208,6 +225,9 @@ def parallel_coordinate_plot(
label=model,
clip_on=False,
)

if median_line:
ax.plot(range(N), zs_meds, "-", c="k", label="median", lw=1)

# Fill between lines
if fill_between_lines and (comparing_models is not None):
Expand All @@ -226,6 +246,7 @@ def parallel_coordinate_plot(
where=y2 >= y1,
facecolor=fill_between_lines_colors[0],
interpolate=True,
alpha=0.5,
)
ax.fill_between(
x,
Expand All @@ -234,6 +255,7 @@ def parallel_coordinate_plot(
where=y2 <= y1,
facecolor=fill_between_lines_colors[1],
interpolate=True,
alpha=0.5,
)

ax.set_xlim(-0.5, N - 0.5)
Expand Down Expand Up @@ -287,12 +309,19 @@ def _data_transform(
model_names2=None,
group1_name="group1",
group2_name="group2",
median_centered=False,
):
# Data to plot
ys = data # stacked y-axis values
N = ys.shape[1] # number of vertical axis (i.e., =len(metric_names))
ymins = np.nanmin(ys, axis=0) # minimum (ignore nan value)
ymaxs = np.nanmax(ys, axis=0) # maximum (ignore nan value)
ymeds = np.nanmedian(ys, axis=0) # median
if median_centered:
for i in range(0, N):
max_distance_from_median = max(abs(ymaxs[i] - ymeds[i]), abs(ymeds[i] - ymins[i]))
ymaxs[i] = ymeds[i] + max_distance_from_median
ymins[i] = ymeds[i] - max_distance_from_median
dys = ymaxs - ymins
ymins -= dys * 0.05 # add 5% padding below and above
ymaxs += dys * 0.05
Expand All @@ -302,6 +331,8 @@ def _data_transform(
zs = np.zeros_like(ys)
zs[:, 0] = ys[:, 0]
zs[:, 1:] = (ys[:, 1:] - ymins[1:]) / dys[1:] * dys[0] + ymins[0]

zs_meds = (ymeds[:] - ymins[:]) / dys[:] * dys[0] + ymins[0]

if model_names2 is not None:
print("Models in the second group:", model_names2)
Expand All @@ -324,7 +355,7 @@ def _data_transform(
group2_name=group2_name,
)

return zs, N, ymins, ymaxs, df_stacked, df2_stacked
return zs, zs_meds, N, ymins, ymaxs, df_stacked, df2_stacked


def _to_pd_dataframe(
Expand Down