diff --git a/pcmdi_metrics/graphics/parallel_coordinate_plot/parallel_coordinate_plot_lib.py b/pcmdi_metrics/graphics/parallel_coordinate_plot/parallel_coordinate_plot_lib.py index 490ba7214..8c37181a3 100644 --- a/pcmdi_metrics/graphics/parallel_coordinate_plot/parallel_coordinate_plot_lib.py +++ b/pcmdi_metrics/graphics/parallel_coordinate_plot/parallel_coordinate_plot_lib.py @@ -15,8 +15,11 @@ def parallel_coordinate_plot( metric_names, model_names, models_to_highlight=list(), + models_to_highlight_by_line=True, models_to_highlight_colors=None, models_to_highlight_labels=None, + models_to_highlight_markers=["s", "o", "^", "*"], + models_to_highlight_markers_size=10, fig=None, ax=None, figsize=(15, 5), @@ -37,7 +40,10 @@ def parallel_coordinate_plot( group2_name="group2", comparing_models=None, fill_between_lines=False, - fill_between_lines_colors=("green", "red"), + fill_between_lines_colors=("red", "green"), + arrow_between_lines=False, + arrow_between_lines_colors=("red", "green"), + arrow_alpha=1, vertical_center=None, vertical_center_line=False, vertical_center_line_label=None, @@ -50,9 +56,12 @@ def parallel_coordinate_plot( - `data`: 2-d numpy array for metrics - `metric_names`: list, names of metrics for individual vertical axes (axis=1) - `model_names`: list, name of models for markers/lines (axis=0) - - `models_to_highlight`: list, default=None, List of models to highlight as lines + - `models_to_highlight`: list, default=None, List of models to highlight as lines or marker + - `models_to_highlight_by_line`: bool, default=True, highlight as lines. If False, as marker - `models_to_highlight_colors`: list, default=None, List of colors for models to highlight as lines - `models_to_highlight_labels`: list, default=None, List of string labels for models to highlight as lines + - `models_to_highlight_markers`: list, matplotlib markers for models to highlight if as marker + - `models_to_highlight_markers_size`: float, size of matplotlib markers for models to highlight if as marker - `fig`: `matplotlib.figure` instance to which the parallel coordinate plot is plotted. If not provided, use current axes or create a new one. Optional. - `ax`: `matplotlib.axes.Axes` instance to which the parallel coordinate plot is plotted. @@ -76,7 +85,10 @@ def parallel_coordinate_plot( - `group2_name`: string, needed for violin plot legend if splited to two groups, for the 2nd group. Default is 'group2'. - `comparing_models`: tuple or list containing two strings for models to compare with colors filled between the two lines. - `fill_between_lines`: bool, default=False, fill color between lines for models in comparing_models - - `fill_between_lines_colors`: tuple or list containing two strings for colors filled between the two lines. Default=('green', 'red') + - `fill_between_lines_colors`: tuple or list containing two strings of colors for filled between the two lines. Default=('red', 'green') + - `arrow_between_lines`: bool, default=False, place arrows between two lines for models in comparing_models + - `arrow_between_lines_colors`: tuple or list containing two strings of colors for arrow between the two lines. Default=('red', 'green') + - `arrow_alpha`: float, default=1, transparency of arrow (faction between 0 to 1) - `vertical_center`: string ("median", "mean")/float/integer, default=None, adjust range of vertical axis to set center of vertical axis as median, mean, or given number - `vertical_center_line`: bool, default=False, show median as line - `vertical_center_line_label`: str, default=None, label in legend for the horizontal vertical center line. If not given, it will be automatically assigned. It can be turned off by "off" @@ -231,7 +243,18 @@ def parallel_coordinate_plot( else: label = model - ax.plot(range(N), zs[j, :], "-", c=color, label=label, lw=3) + if models_to_highlight_by_line: + ax.plot(range(N), zs[j, :], "-", c=color, label=label, lw=3) + else: + ax.plot( + range(N), + zs[j, :], + models_to_highlight_markers[mh_index], + c=color, + label=label, + markersize=models_to_highlight_markers_size, + ) + mh_index += 1 else: if identify_all_models: @@ -251,8 +274,8 @@ def parallel_coordinate_plot( vertical_center_line_label = None ax.plot(range(N), zs_middle, "-", c="k", label=vertical_center_line_label, lw=1) - # Fill between lines - if fill_between_lines and (comparing_models is not None): + # Compare two models + if comparing_models is not None: if isinstance(comparing_models, tuple) or ( isinstance(comparing_models, list) and len(comparing_models) == 2 ): @@ -261,24 +284,49 @@ def parallel_coordinate_plot( m2 = model_names.index(comparing_models[1]) y1 = zs[m1, :] y2 = zs[m2, :] - ax.fill_between( - x, - y1, - y2, - where=y2 >= y1, - facecolor=fill_between_lines_colors[0], - interpolate=True, - alpha=0.5, - ) - ax.fill_between( - x, - y1, - y2, - where=y2 <= y1, - facecolor=fill_between_lines_colors[1], - interpolate=True, - alpha=0.5, - ) + + # Fill between lines + if fill_between_lines: + ax.fill_between( + x, + y1, + y2, + where=(y2 > y1), + facecolor=fill_between_lines_colors[0], + interpolate=False, + alpha=0.5, + ) + ax.fill_between( + x, + y1, + y2, + where=(y2 < y1), + facecolor=fill_between_lines_colors[1], + interpolate=False, + alpha=0.5, + ) + + if arrow_between_lines: + # Add vertical arrows + for xi, yi1, yi2 in zip(x, y1, y2): + if yi2 > yi1: + arrow_color = arrow_between_lines_colors[0] + elif yi2 < yi1: + arrow_color = arrow_between_lines_colors[1] + else: + arrow_color = None + arrow_length = yi2 - yi1 + ax.arrow( + xi, + yi1, + 0, + arrow_length, + color=arrow_color, + length_includes_head=True, + alpha=arrow_alpha, + width=0.05, + head_width=0.15, + ) ax.set_xlim(-0.5, N - 0.5) ax.set_xticks(range(N))