Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parallel coordinate plot advance #996

Merged
merged 2 commits into from
Nov 7, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@ def parallel_coordinate_plot(
metric_names,
model_names,
models_to_highlight=list(),
models_to_highlight_by_line=True,
models_to_highlight_colors=None,
models_to_highlight_labels=None,
models_to_highlight_markers=["s", "o", "^", "*"],
models_to_highlight_markers_size=10,
fig=None,
ax=None,
figsize=(15, 5),
Expand All @@ -37,7 +40,10 @@ def parallel_coordinate_plot(
group2_name="group2",
comparing_models=None,
fill_between_lines=False,
fill_between_lines_colors=("green", "red"),
fill_between_lines_colors=("red", "green"),
arrow_between_lines=False,
arrow_between_lines_colors=("red", "green"),
arrow_alpha=1,
vertical_center=None,
vertical_center_line=False,
vertical_center_line_label=None,
Expand All @@ -50,9 +56,12 @@ def parallel_coordinate_plot(
- `data`: 2-d numpy array for metrics
- `metric_names`: list, names of metrics for individual vertical axes (axis=1)
- `model_names`: list, name of models for markers/lines (axis=0)
- `models_to_highlight`: list, default=None, List of models to highlight as lines
- `models_to_highlight`: list, default=None, List of models to highlight as lines or marker
- `models_to_highlight_by_line`: bool, default=True, highlight as lines. If False, as marker
- `models_to_highlight_colors`: list, default=None, List of colors for models to highlight as lines
- `models_to_highlight_labels`: list, default=None, List of string labels for models to highlight as lines
- `models_to_highlight_markers`: list, matplotlib markers for models to highlight if as marker
- `models_to_highlight_markers_size`: float, size of matplotlib markers for models to highlight if as marker
- `fig`: `matplotlib.figure` instance to which the parallel coordinate plot is plotted.
If not provided, use current axes or create a new one. Optional.
- `ax`: `matplotlib.axes.Axes` instance to which the parallel coordinate plot is plotted.
Expand All @@ -76,7 +85,10 @@ def parallel_coordinate_plot(
- `group2_name`: string, needed for violin plot legend if splited to two groups, for the 2nd group. Default is 'group2'.
- `comparing_models`: tuple or list containing two strings for models to compare with colors filled between the two lines.
- `fill_between_lines`: bool, default=False, fill color between lines for models in comparing_models
- `fill_between_lines_colors`: tuple or list containing two strings for colors filled between the two lines. Default=('green', 'red')
- `fill_between_lines_colors`: tuple or list containing two strings of colors for filled between the two lines. Default=('red', 'green')
- `arrow_between_lines`: bool, default=False, place arrows between two lines for models in comparing_models
- `arrow_between_lines_colors`: tuple or list containing two strings of colors for arrow between the two lines. Default=('red', 'green')
- `arrow_alpha`: float, default=1, transparency of arrow (faction between 0 to 1)
- `vertical_center`: string ("median", "mean")/float/integer, default=None, adjust range of vertical axis to set center of vertical axis as median, mean, or given number
- `vertical_center_line`: bool, default=False, show median as line
- `vertical_center_line_label`: str, default=None, label in legend for the horizontal vertical center line. If not given, it will be automatically assigned. It can be turned off by "off"
Expand Down Expand Up @@ -231,7 +243,18 @@ def parallel_coordinate_plot(
else:
label = model

ax.plot(range(N), zs[j, :], "-", c=color, label=label, lw=3)
if models_to_highlight_by_line:
ax.plot(range(N), zs[j, :], "-", c=color, label=label, lw=3)
else:
ax.plot(
range(N),
zs[j, :],
models_to_highlight_markers[mh_index],
c=color,
label=label,
markersize=models_to_highlight_markers_size,
)

mh_index += 1
else:
if identify_all_models:
Expand All @@ -251,8 +274,8 @@ def parallel_coordinate_plot(
vertical_center_line_label = None
ax.plot(range(N), zs_middle, "-", c="k", label=vertical_center_line_label, lw=1)

# Fill between lines
if fill_between_lines and (comparing_models is not None):
# Compare two models
if comparing_models is not None:
if isinstance(comparing_models, tuple) or (
isinstance(comparing_models, list) and len(comparing_models) == 2
):
Expand All @@ -261,24 +284,49 @@ def parallel_coordinate_plot(
m2 = model_names.index(comparing_models[1])
y1 = zs[m1, :]
y2 = zs[m2, :]
ax.fill_between(
x,
y1,
y2,
where=y2 >= y1,
facecolor=fill_between_lines_colors[0],
interpolate=True,
alpha=0.5,
)
ax.fill_between(
x,
y1,
y2,
where=y2 <= y1,
facecolor=fill_between_lines_colors[1],
interpolate=True,
alpha=0.5,
)

# Fill between lines
if fill_between_lines:
ax.fill_between(
x,
y1,
y2,
where=(y2 > y1),
facecolor=fill_between_lines_colors[0],
interpolate=False,
alpha=0.5,
)
ax.fill_between(
x,
y1,
y2,
where=(y2 < y1),
facecolor=fill_between_lines_colors[1],
interpolate=False,
alpha=0.5,
)

if arrow_between_lines:
# Add vertical arrows
for xi, yi1, yi2 in zip(x, y1, y2):
if yi2 > yi1:
arrow_color = arrow_between_lines_colors[0]
elif yi2 < yi1:
arrow_color = arrow_between_lines_colors[1]
else:
arrow_color = None
arrow_length = yi2 - yi1
ax.arrow(
xi,
yi1,
0,
arrow_length,
color=arrow_color,
length_includes_head=True,
alpha=arrow_alpha,
width=0.05,
head_width=0.15,
)

ax.set_xlim(-0.5, N - 0.5)
ax.set_xticks(range(N))
Expand Down