Skip to content

Commit

Permalink
Merge branch 'main' into feature/1012_lee1043_stats-MoV_xcdat
Browse files Browse the repository at this point in the history
  • Loading branch information
lee1043 committed Apr 30, 2024
2 parents 5aac6f7 + f7d38c8 commit abbbd9a
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,16 @@ def parallel_coordinate_plot(
arrow_between_lines=False,
arrow_between_lines_colors=("red", "green"),
arrow_alpha=1,
arrow_width=0.05,
arrow_linewidth=0,
arrow_head_width=0.15,
arrow_head_length=0.15,
vertical_center=None,
vertical_center_line=False,
vertical_center_line_label=None,
ymax=None,
ymin=None,
debug=False,
):
"""
Parameters
Expand Down Expand Up @@ -99,11 +104,15 @@ def parallel_coordinate_plot(
- `arrow_between_lines`: bool, default=False, place arrows between two lines for models in comparing_models
- `arrow_between_lines_colors`: tuple or list containing two strings of colors for arrow between the two lines. Default=('red', 'green')
- `arrow_alpha`: float, default=1, transparency of arrow (faction between 0 to 1)
- `arrow_width`: float, default is 0.05, width of arrow
- `arrow_linewidth`: float, default is 0, width of arrow edge line
- `arrow_head_width`: float, default is 0.15, widht of arrow head
- `arrow_head_length`: float, default is 0.15, length of arrow head
- `vertical_center`: string ("median", "mean")/float/integer, default=None, adjust range of vertical axis to set center of vertical axis as median, mean, or given number
- `vertical_center_line`: bool, default=False, show median as line
- `vertical_center_line_label`: str, default=None, label in legend for the horizontal vertical center line. If not given, it will be automatically assigned. It can be turned off by "off"
- `ymax`: int or float, default=None, specify value of vertical axis top
- `ymin`: int or float, default=None, specify value of vertical axis bottom
- `ymax`: int or float or string ('percentile'), default=None, specify value of vertical axis top. If percentile, 95th percentile or extended for top
- `ymin`: int or float or string ('percentile'), default=None, specify value of vertical axis bottom. If percentile, 5th percentile or extended for bottom
Return
------
Expand All @@ -117,6 +126,7 @@ def parallel_coordinate_plot(
2023-03 median centered option added
2023-04 vertical center option diversified (median, mean, or given number)
2024-03 parameter added for violin plot label
2024-04 parameters added for arrow and option added for ymax/ymin setting
"""
params = {
"legend.fontsize": "large",
Expand All @@ -143,6 +153,10 @@ def parallel_coordinate_plot(
ymin=ymin,
)

if debug:
print("ymins:", ymins)
print("ymaxs:", ymaxs)

# Prepare plot
if N > 20:
if xtick_labelsize is None:
Expand Down Expand Up @@ -317,8 +331,8 @@ def parallel_coordinate_plot(
alpha=0.5,
)

# Add vertical arrows
if arrow_between_lines:
# Add vertical arrows
for xi, yi1, yi2 in zip(x, y1, y2):
if yi2 > yi1:
arrow_color = arrow_between_lines_colors[0]
Expand All @@ -335,8 +349,11 @@ def parallel_coordinate_plot(
color=arrow_color,
length_includes_head=True,
alpha=arrow_alpha,
width=0.05,
head_width=0.15,
width=arrow_width,
linewidth=arrow_linewidth,
head_width=arrow_head_width,
head_length=arrow_head_length,
zorder=999,
)

ax.set_xlim(-0.5, N - 0.5)
Expand Down Expand Up @@ -421,15 +438,28 @@ def _data_transform(
# Data to plot
ys = data # stacked y-axis values
N = ys.shape[1] # number of vertical axis (i.e., =len(metric_names))

if ymax is None:
ymaxs = np.nanmax(ys, axis=0) # maximum (ignore nan value)
else:
ymaxs = np.repeat(ymax, N)
try:
if isinstance(ymax, str) and ymax == "percentile":
ymaxs = np.nanpercentile(ys, 95, axis=0)
else:
ymaxs = np.repeat(ymax, N)
except ValueError:
print(f"Invalid input for ymax: {ymax}")

if ymin is None:
ymins = np.nanmin(ys, axis=0) # minimum (ignore nan value)
else:
ymins = np.repeat(ymin, N)
try:
if isinstance(ymin, str) and ymin == "percentile":
ymins = np.nanpercentile(ys, 5, axis=0)
else:
ymins = np.repeat(ymin, N)
except ValueError:
print(f"Invalid input for ymin: {ymin}")

ymeds = np.nanmedian(ys, axis=0) # median
ymean = np.nanmean(ys, axis=0) # mean
Expand All @@ -439,14 +469,17 @@ def _data_transform(
ymids = ymeds
elif vertical_center == "mean":
ymids = ymean
else:
elif isinstance(vertical_center, float) or isinstance(vertical_center, int):
ymids = np.repeat(vertical_center, N)
else:
raise ValueError(f"vertical center {vertical_center} unknown.")

for i in range(0, N):
max_distance_from_middle = max(
distance_from_middle = max(
abs(ymaxs[i] - ymids[i]), abs(ymids[i] - ymins[i])
)
ymaxs[i] = ymids[i] + max_distance_from_middle
ymins[i] = ymids[i] - max_distance_from_middle
ymaxs[i] = ymids[i] + distance_from_middle
ymins[i] = ymids[i] - distance_from_middle

dys = ymaxs - ymins
if ymin is None:
Expand Down
24 changes: 12 additions & 12 deletions pcmdi_metrics/sea_ice/lib/sea_ice_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,18 +61,18 @@ def write(self):
# ------------------------------------
# Define region coverage in functions
# ------------------------------------
def central_arctic(ds, ds_var, xvar, yvar):
def central_arctic(ds, ds_var, xvar, yvar, pole):
if (ds[xvar] > 180).any(): # 0 to 360
data_ca1 = ds[ds_var].where(
(
(ds[yvar] > 80)
& (ds[yvar] <= 87.2)
& (ds[yvar] <= pole)
& ((ds[xvar] > 240) | (ds[xvar] <= 90))
),
0,
)
data_ca2 = ds[ds_var].where(
((ds[yvar] > 65) & (ds[yvar] < 87.2))
((ds[yvar] > 65) & (ds[yvar] < pole))
& ((ds[xvar] > 90) & (ds[xvar] <= 240)),
0,
)
Expand All @@ -81,14 +81,14 @@ def central_arctic(ds, ds_var, xvar, yvar):
data_ca1 = ds[ds_var].where(
(
(ds[yvar] > 80)
& (ds[yvar] <= 87.2)
& (ds[yvar] <= pole)
& (ds[xvar] > -120)
& (ds[xvar] <= 90)
),
0,
)
data_ca2 = ds[ds_var].where(
((ds[yvar] > 65) & (ds[yvar] < 87.2))
((ds[yvar] > 65) & (ds[yvar] < pole))
& ((ds[xvar] > 90) | (ds[xvar] <= -120)),
0,
)
Expand Down Expand Up @@ -180,8 +180,8 @@ def indian_ocean(ds, ds_var, xvar, yvar):
return data_io


def arctic(ds, ds_var, xvar, yvar):
data_arctic = ds[ds_var].where(ds[yvar] > 0, 0)
def arctic(ds, ds_var, xvar, yvar, pole):
data_arctic = ds[ds_var].where((ds[yvar] > 0) & (ds[yvar] < pole), 0)
return data_arctic


Expand All @@ -190,13 +190,13 @@ def antarctic(ds, ds_var, xvar, yvar):
return data_antarctic


def choose_region(region, ds, ds_var, xvar, yvar):
def choose_region(region, ds, ds_var, xvar, yvar, pole):
if region == "arctic":
return arctic(ds, ds_var, xvar, yvar)
return arctic(ds, ds_var, xvar, yvar, pole)
elif region == "na":
return north_atlantic(ds, ds_var, xvar, yvar)
elif region == "ca":
return central_arctic(ds, ds_var, xvar, yvar)
return central_arctic(ds, ds_var, xvar, yvar, pole)
elif region == "np":
return north_pacific(ds, ds_var, xvar, yvar)
elif region == "antarctic":
Expand Down Expand Up @@ -236,14 +236,14 @@ def get_clim(total_extent, ds, ds_var):
return clim


def process_by_region(ds, ds_var, ds_area):
def process_by_region(ds, ds_var, ds_area, pole):
regions_list = ["arctic", "antarctic", "ca", "na", "np", "sa", "sp", "io"]
clims = {}
means = {}
for region in regions_list:
xvar = find_lon(ds)
yvar = find_lat(ds)
data = choose_region(region, ds, ds_var, xvar, yvar)
data = choose_region(region, ds, ds_var, xvar, yvar, pole)
total_extent, te_mean = get_total_extent(data, ds_area)
clim = get_clim(total_extent, ds, ds_var)
clims[region] = clim
Expand Down
7 changes: 7 additions & 0 deletions pcmdi_metrics/sea_ice/lib/sea_ice_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,4 +223,11 @@ def create_sea_ice_parser():
default=True,
help="Option for generate individual plots for models: True (default) / False",
)

parser.add_argument(
"--pole",
type=float,
default=90.1,
help="Set to a latitude value to exclude sea ice at North pole. Must be > 80.",
)
return parser
7 changes: 4 additions & 3 deletions pcmdi_metrics/sea_ice/sea_ice_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
osyear = parameter.osyear
oeyear = parameter.oeyear
plot = parameter.plot
pole = parameter.pole

print("Model list:", model_list)
model_list.sort()
Expand Down Expand Up @@ -106,7 +107,7 @@
mask = create_land_sea_mask(obs, lon_key=xvar, lat_key=yvar)
obs[obs_var] = obs[obs_var].where(mask < 1)
# Get regions
clims, means = lib.process_by_region(obs, obs_var, area_val)
clims, means = lib.process_by_region(obs, obs_var, area_val, pole)

arctic_clims = {
"arctic": clims["arctic"],
Expand Down Expand Up @@ -149,7 +150,7 @@
# Remove land areas (including lakes)
mask = create_land_sea_mask(obs, lon_key="lon", lat_key="lat")
obs[obs_var] = obs[obs_var].where(mask < 1)
clims, means = lib.process_by_region(obs, obs_var, area_val)
clims, means = lib.process_by_region(obs, obs_var, area_val, pole)
antarctic_clims = {
"antarctic": clims["antarctic"],
"io": clims["io"],
Expand Down Expand Up @@ -358,7 +359,7 @@
ds[var] = ds[var].where(mask < 1)

# Get regions
clims, means = lib.process_by_region(ds, var, area[area_var].data)
clims, means = lib.process_by_region(ds, var, area[area_var].data, pole)

ds.close()
# Running sum of all realizations
Expand Down

0 comments on commit abbbd9a

Please sign in to comment.