Skip to content

Commit

Permalink
Merge branch 'main' into ao_add_pole_opt
Browse files Browse the repository at this point in the history
  • Loading branch information
acordonez committed Apr 30, 2024
2 parents 0965bb9 + fc86586 commit cb8d207
Showing 1 changed file with 44 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,16 @@ def parallel_coordinate_plot(
arrow_between_lines=False,
arrow_between_lines_colors=("red", "green"),
arrow_alpha=1,
arrow_width=0.05,
arrow_linewidth=0,
arrow_head_width=0.15,
arrow_head_length=0.15,
vertical_center=None,
vertical_center_line=False,
vertical_center_line_label=None,
ymax=None,
ymin=None,
debug=False,
):
"""
Parameters
Expand Down Expand Up @@ -99,11 +104,15 @@ def parallel_coordinate_plot(
- `arrow_between_lines`: bool, default=False, place arrows between two lines for models in comparing_models
- `arrow_between_lines_colors`: tuple or list containing two strings of colors for arrow between the two lines. Default=('red', 'green')
- `arrow_alpha`: float, default=1, transparency of arrow (faction between 0 to 1)
- `arrow_width`: float, default is 0.05, width of arrow
- `arrow_linewidth`: float, default is 0, width of arrow edge line
- `arrow_head_width`: float, default is 0.15, widht of arrow head
- `arrow_head_length`: float, default is 0.15, length of arrow head
- `vertical_center`: string ("median", "mean")/float/integer, default=None, adjust range of vertical axis to set center of vertical axis as median, mean, or given number
- `vertical_center_line`: bool, default=False, show median as line
- `vertical_center_line_label`: str, default=None, label in legend for the horizontal vertical center line. If not given, it will be automatically assigned. It can be turned off by "off"
- `ymax`: int or float, default=None, specify value of vertical axis top
- `ymin`: int or float, default=None, specify value of vertical axis bottom
- `ymax`: int or float or string ('percentile'), default=None, specify value of vertical axis top. If percentile, 95th percentile or extended for top
- `ymin`: int or float or string ('percentile'), default=None, specify value of vertical axis bottom. If percentile, 5th percentile or extended for bottom
Return
------
Expand All @@ -117,6 +126,7 @@ def parallel_coordinate_plot(
2023-03 median centered option added
2023-04 vertical center option diversified (median, mean, or given number)
2024-03 parameter added for violin plot label
2024-04 parameters added for arrow and option added for ymax/ymin setting
"""
params = {
"legend.fontsize": "large",
Expand All @@ -143,6 +153,10 @@ def parallel_coordinate_plot(
ymin=ymin,
)

if debug:
print("ymins:", ymins)
print("ymaxs:", ymaxs)

# Prepare plot
if N > 20:
if xtick_labelsize is None:
Expand Down Expand Up @@ -317,8 +331,8 @@ def parallel_coordinate_plot(
alpha=0.5,
)

# Add vertical arrows
if arrow_between_lines:
# Add vertical arrows
for xi, yi1, yi2 in zip(x, y1, y2):
if yi2 > yi1:
arrow_color = arrow_between_lines_colors[0]
Expand All @@ -335,8 +349,11 @@ def parallel_coordinate_plot(
color=arrow_color,
length_includes_head=True,
alpha=arrow_alpha,
width=0.05,
head_width=0.15,
width=arrow_width,
linewidth=arrow_linewidth,
head_width=arrow_head_width,
head_length=arrow_head_length,
zorder=999,
)

ax.set_xlim(-0.5, N - 0.5)
Expand Down Expand Up @@ -421,15 +438,28 @@ def _data_transform(
# Data to plot
ys = data # stacked y-axis values
N = ys.shape[1] # number of vertical axis (i.e., =len(metric_names))

if ymax is None:
ymaxs = np.nanmax(ys, axis=0) # maximum (ignore nan value)
else:
ymaxs = np.repeat(ymax, N)
try:
if isinstance(ymax, str) and ymax == "percentile":
ymaxs = np.nanpercentile(ys, 95, axis=0)
else:
ymaxs = np.repeat(ymax, N)
except ValueError:
print(f"Invalid input for ymax: {ymax}")

if ymin is None:
ymins = np.nanmin(ys, axis=0) # minimum (ignore nan value)
else:
ymins = np.repeat(ymin, N)
try:
if isinstance(ymin, str) and ymin == "percentile":
ymins = np.nanpercentile(ys, 5, axis=0)
else:
ymins = np.repeat(ymin, N)
except ValueError:
print(f"Invalid input for ymin: {ymin}")

ymeds = np.nanmedian(ys, axis=0) # median
ymean = np.nanmean(ys, axis=0) # mean
Expand All @@ -439,14 +469,17 @@ def _data_transform(
ymids = ymeds
elif vertical_center == "mean":
ymids = ymean
else:
elif isinstance(vertical_center, float) or isinstance(vertical_center, int):
ymids = np.repeat(vertical_center, N)
else:
raise ValueError(f"vertical center {vertical_center} unknown.")

for i in range(0, N):
max_distance_from_middle = max(
distance_from_middle = max(
abs(ymaxs[i] - ymids[i]), abs(ymids[i] - ymins[i])
)
ymaxs[i] = ymids[i] + max_distance_from_middle
ymins[i] = ymids[i] - max_distance_from_middle
ymaxs[i] = ymids[i] + distance_from_middle
ymins[i] = ymids[i] - distance_from_middle

dys = ymaxs - ymins
if ymin is None:
Expand Down

0 comments on commit cb8d207

Please sign in to comment.