Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mean clim patch #920

Merged
merged 19 commits into from
Apr 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def parallel_coordinate_plot(
metric_names,
model_names,
models_to_highlight=list(),
models_to_highlight_colors=None,
fig=None,
ax=None,
figsize=(15, 5),
Expand All @@ -23,7 +24,6 @@ def parallel_coordinate_plot(
violin_colors=("lightgrey", "pink"),
title=None,
identify_all_models=True,
xtick_labels=None,
xtick_labelsize=None,
ytick_labelsize=None,
colormap="viridis",
Expand All @@ -37,6 +37,11 @@ def parallel_coordinate_plot(
comparing_models=None,
fill_between_lines=False,
fill_between_lines_colors=("green", "red"),
vertical_center=None,
vertical_center_line=False,
vertical_center_line_label=None,
ymax=None,
ymin=None,
):
"""
Parameters
Expand All @@ -45,6 +50,7 @@ def parallel_coordinate_plot(
- `metric_names`: list, names of metrics for individual vertical axes (axis=1)
- `model_names`: list, name of models for markers/lines (axis=0)
- `models_to_highlight`: list, default=empty list, List of models to highlight as lines
- `models_to_highlight_colors`: list, default=None, List of colors for models to highlight as lines
- `fig`: `matplotlib.figure` instance to which the parallel coordinate plot is plotted.
If not provided, use current axes or create a new one. Optional.
- `ax`: `matplotlib.axes.Axes` instance to which the parallel coordinate plot is plotted.
Expand All @@ -69,15 +75,23 @@ def parallel_coordinate_plot(
- `comparing_models`: tuple or list containing two strings for models to compare with colors filled between the two lines.
- `fill_between_lines`: bool, default=False, fill color between lines for models in comparing_models
- `fill_between_lines_colors`: tuple or list containing two strings for colors filled between the two lines. Default=('green', 'red')
- `vertical_center`: string ("median", "mean")/float/integer, default=None, adjust range of vertical axis to set center of vertical axis as median, mean, or given number
- `vertical_center_line`: bool, default=False, show the vertical center (median, mean, or given number, per `vertical_center`) as a line
- `vertical_center_line_label`: str, default=None, legend label for the vertical center line. If not given, it is set to `str(vertical_center)`. Pass "off" to omit the label from the legend.
- `ymax`: int or float, default=None, specify value of vertical axis top
- `ymin`: int or float, default=None, specify value of vertical axis bottom

Return
------
- `fig`: matplotlib component for figure
- `ax`: matplotlib component for axis

Author: Jiwoo Lee @ LLNL (2021. 7)
Last update: 2022. 9
Inspired by https://stackoverflow.com/questions/8230638/parallel-coordinates-plot-in-matplotlib
Update history:
2021-07 Plotting code created. Inspired by https://stackoverflow.com/questions/8230638/parallel-coordinates-plot-in-matplotlib
2022-09 violin plots added
2023-03 median centered option added
2023-04 vertical center option diversified (median, mean, or given number)
"""
params = {
"legend.fontsize": "large",
Expand All @@ -92,13 +106,16 @@ def parallel_coordinate_plot(
_quick_qc(data, model_names, metric_names, model_names2=model_names2)

# Transform data for plotting
zs, N, ymins, ymaxs, df_stacked, df2_stacked = _data_transform(
zs, zs_middle, N, ymins, ymaxs, df_stacked, df2_stacked = _data_transform(
data,
metric_names,
model_names,
model_names2=model_names2,
group1_name=group1_name,
group2_name=group2_name,
vertical_center=vertical_center,
ymax=ymax,
ymin=ymin,
)

# Prepare plot
Expand All @@ -123,8 +140,6 @@ def parallel_coordinate_plot(

if fig is None and ax is None:
fig, ax = plt.subplots(figsize=figsize)
else:
ax = ax

axes = [ax] + [ax.twinx() for i in range(N - 1)]

Expand Down Expand Up @@ -168,7 +183,11 @@ def parallel_coordinate_plot(
showextrema=False,
)
for pc in violin["bodies"]:
pc.set_facecolor(violin_colors[0])
if isinstance(violin_colors, tuple) or isinstance(violin_colors, list):
violin_color = violin_colors[0]
else:
violin_color = violin_colors
pc.set_facecolor(violin_color)
pc.set_edgecolor("None")
pc.set_alpha(0.8)
else:
Expand All @@ -194,10 +213,17 @@ def parallel_coordinate_plot(
marker_types = ["o", "s", "*", "^", "X", "D", "p"]
markers = list(flatten([[marker] * len(colors) for marker in marker_types]))
colors *= len(marker_types)
mh_index = 0
for j, model in enumerate(model_names):
# to just draw straight lines between the axes:
if model in models_to_highlight:
ax.plot(range(N), zs[j, :], "-", c=colors[j], label=model, lw=3)

if models_to_highlight_colors is not None:
color = models_to_highlight_colors[mh_index]
else:
color = colors[j]
ax.plot(range(N), zs[j, :], "-", c=color, label=model, lw=3)
mh_index += 1
else:
if identify_all_models:
ax.plot(
Expand All @@ -208,6 +234,13 @@ def parallel_coordinate_plot(
label=model,
clip_on=False,
)

if vertical_center_line:
if vertical_center_line_label is None:
vertical_center_line_label = str(vertical_center)
elif vertical_center_line_label == "off":
vertical_center_line_label = None
ax.plot(range(N), zs_middle, "-", c="k", label=vertical_center_line_label, lw=1)

# Fill between lines
if fill_between_lines and (comparing_models is not None):
Expand All @@ -226,6 +259,7 @@ def parallel_coordinate_plot(
where=y2 >= y1,
facecolor=fill_between_lines_colors[0],
interpolate=True,
alpha=0.5,
)
ax.fill_between(
x,
Expand All @@ -234,6 +268,7 @@ def parallel_coordinate_plot(
where=y2 <= y1,
facecolor=fill_between_lines_colors[1],
interpolate=True,
alpha=0.5,
)

ax.set_xlim(-0.5, N - 0.5)
Expand Down Expand Up @@ -287,21 +322,54 @@ def _data_transform(
model_names2=None,
group1_name="group1",
group2_name="group2",
vertical_center=None,
ymax=None,
ymin=None,
):
# Data to plot
ys = data # stacked y-axis values
N = ys.shape[1] # number of vertical axis (i.e., =len(metric_names))
ymins = np.nanmin(ys, axis=0) # minimum (ignore nan value)
ymaxs = np.nanmax(ys, axis=0) # maximum (ignore nan value)
if ymax is None:
ymaxs = np.nanmax(ys, axis=0) # maximum (ignore nan value)
else:
ymaxs = np.repeat(ymax, N)

if ymin is None:
ymins = np.nanmin(ys, axis=0) # minimum (ignore nan value)
else:
ymins = np.repeat(ymin, N)

ymeds = np.nanmedian(ys, axis=0) # median
ymean = np.nanmean(ys, axis=0) # mean

if vertical_center is not None:
if vertical_center == "median":
ymids = ymeds
elif vertical_center == "mean":
ymids = ymean
else:
ymids = np.repeat(vertical_center, N)
for i in range(0, N):
max_distance_from_middle = max(abs(ymaxs[i] - ymids[i]), abs(ymids[i] - ymins[i]))
ymaxs[i] = ymids[i] + max_distance_from_middle
ymins[i] = ymids[i] - max_distance_from_middle

dys = ymaxs - ymins
ymins -= dys * 0.05 # add 5% padding below and above
ymaxs += dys * 0.05
if ymin is None:
ymins -= dys * 0.05 # add 5% padding below and above
if ymax is None:
ymaxs += dys * 0.05
dys = ymaxs - ymins

# Transform all data to be compatible with the main axis
zs = np.zeros_like(ys)
zs[:, 0] = ys[:, 0]
zs[:, 1:] = (ys[:, 1:] - ymins[1:]) / dys[1:] * dys[0] + ymins[0]

if vertical_center is not None:
zs_middle = (ymids[:] - ymins[:]) / dys[:] * dys[0] + ymins[0]
else:
zs_middle = (ymaxs[:] - ymins[:]) / 2 / dys[:] * dys[0] + ymins[0]

if model_names2 is not None:
print("Models in the second group:", model_names2)
Expand All @@ -324,7 +392,7 @@ def _data_transform(
group2_name=group2_name,
)

return zs, N, ymins, ymaxs, df_stacked, df2_stacked
return zs, zs_middle, N, ymins, ymaxs, df_stacked, df2_stacked


def _to_pd_dataframe(
Expand Down
19 changes: 15 additions & 4 deletions pcmdi_metrics/graphics/share/read_json_mean_clim.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,10 @@ def read_mean_clim_json_files(
var += "-" + str(int(dict_temp["Variable"]["level"] / 100.0)) # Pa to hPa
results_dict[var] = dict_temp
unit = extract_unit(var, results_dict[var])
var_unit = var + " [" + unit + "]"
if unit is not None:
var_unit = var + " [" + unit + "]"
else:
var_unit = var
var_list.append(var)
var_unit_list.append(var_unit)
var_ref_dict[var] = extract_ref(var, results_dict[var])
Expand Down Expand Up @@ -90,13 +93,19 @@ def read_mean_clim_json_files(

def extract_unit(var, results_dict_var):
    """Return the units recorded for the first model (in sorted order).

    Parameters
    ----------
    var : str
        Variable name. Not used in the lookup; kept for interface
        compatibility with existing callers.
    results_dict_var : dict
        Results dictionary for one variable; must contain a "RESULTS"
        mapping keyed by model name.

    Returns
    -------
    The value stored under ``RESULTS[<first model>]["units"]``, or ``None``
    when there are no models or the first model carries no "units" entry.

    Raises
    ------
    KeyError
        If ``results_dict_var`` has no "RESULTS" key.
    """
    model_list = sorted(results_dict_var["RESULTS"])
    try:
        # Only the guarded lookup remains: an unguarded copy of this line
        # ahead of the try would raise before the fallback could apply.
        units = results_dict_var["RESULTS"][model_list[0]]["units"]
    except (IndexError, KeyError, TypeError):
        # IndexError: no models; KeyError: no "units" entry;
        # TypeError: the model entry is not subscriptable by string.
        units = None
    return units


def extract_ref(var, results_dict_var):
    """Return the default reference source recorded for the first model (in sorted order).

    Parameters
    ----------
    var : str
        Variable name. Not used in the lookup; kept for interface
        compatibility with existing callers.
    results_dict_var : dict
        Results dictionary for one variable; must contain a "RESULTS"
        mapping keyed by model name.

    Returns
    -------
    The value stored under ``RESULTS[<first model>]["default"]["source"]``,
    or ``None`` when there are no models or that entry is missing.

    Raises
    ------
    KeyError
        If ``results_dict_var`` has no "RESULTS" key.
    """
    model_list = sorted(results_dict_var["RESULTS"])
    try:
        # Only the guarded lookup remains: an unguarded copy of this line
        # ahead of the try would raise before the fallback could apply.
        ref = results_dict_var["RESULTS"][model_list[0]]["default"]["source"]
    except (IndexError, KeyError, TypeError):
        # IndexError: no models; KeyError: "default"/"source" missing;
        # TypeError: an intermediate entry is not subscriptable by string.
        ref = None
    return ref


Expand Down Expand Up @@ -152,7 +161,9 @@ def extract_data(results_dict, var_list, region, stat, season, mip, debug=False)
if debug:
print("model, run_list:", model, run_list)

run_list.remove("source")
if "source" in run_list:
run_list.remove("source")

for run in run_list:
tmp_list = []
for var in var_list:
Expand Down
10 changes: 9 additions & 1 deletion pcmdi_metrics/mean_climate/lib/create_mean_climate_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@ def create_mean_climate_parser():
required=False,
)

parser.add_argument(
"--varname_in_test_data",
type=ast.literal_eval,
dest="varname_in_test_data",
help="Variable name in input model file",
required=False,
)

parser.add_argument(
"--regions",
type=ast.literal_eval,
Expand Down Expand Up @@ -256,4 +264,4 @@ def create_mean_climate_parser():
required=False,
)

return parser
return parser
27 changes: 22 additions & 5 deletions pcmdi_metrics/mean_climate/lib/load_and_regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
import xcdat as xc
import numpy as np

def load_and_regrid(data_path, varname, level=None, t_grid=None, decode_times=True, regrid_tool='regrid2', debug=False):
def load_and_regrid(data_path, varname, varname_in_file=None, level=None, t_grid=None, decode_times=True, regrid_tool='regrid2', debug=False):
"""Load data and regrid to target grid

Args:
data_path (str): full data path for nc or xml file
varname (str): variable name
varname_in_file (str): variable name if data array named differently
level (float): level to extract (unit in hPa)
t_grid (xarray.core.dataset.Dataset): target grid to regrid
decode_times (bool): Default is True. decode_times=False will be removed once obs4MIP written using xcdat
Expand All @@ -17,9 +18,12 @@ def load_and_regrid(data_path, varname, level=None, t_grid=None, decode_times=Tr
"""
if debug:
print('load_and_regrid start')


if varname_in_file is None:
varname_in_file = varname

# load data
ds = xcdat_open(data_path, data_var=varname, decode_times=decode_times) # NOTE: decode_times=False will be removed once obs4MIP written using xcdat
ds = xcdat_open(data_path, data_var=varname_in_file, decode_times=decode_times) # NOTE: decode_times=False will be removed once obs4MIP written using xcdat

# calendar quality check
if "calendar" in list(ds.time.attrs.keys()):
Expand Down Expand Up @@ -51,11 +55,24 @@ def load_and_regrid(data_path, varname, level=None, t_grid=None, decode_times=Tr

# regrid
if regrid_tool == 'regrid2':
ds_regridded = ds.regridder.horizontal(varname, t_grid, tool=regrid_tool)
ds_regridded = ds.regridder.horizontal(varname_in_file, t_grid, tool=regrid_tool)
elif regrid_tool in ['esmf', 'xesmf']:
regrid_tool = 'xesmf'
regrid_method = 'bilinear'
ds_regridded = ds.regridder.horizontal(varname, t_grid, tool=regrid_tool, method=regrid_method)
ds_regridded = ds.regridder.horizontal(varname_in_file, t_grid, tool=regrid_tool, method=regrid_method)

if varname != varname_in_file:
ds_regridded[varname] = ds_regridded[varname_in_file]

# preserve units
try:
units = ds[varname].units
except Exception as e:
print(e)
units = ""
print('units:', units)

ds_regridded[varname] = ds_regridded[varname].assign_attrs({'units': units})

if debug:
print('ds_regridded:', ds_regridded)
Expand Down
Loading