Skip to content

Commit

Permalink
Merge branch 'main' into lee1043-patch-1
Browse files Browse the repository at this point in the history
  • Loading branch information
lee1043 committed May 29, 2023
2 parents 95140e9 + f55afc5 commit 24d4575
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 17 deletions.
20 changes: 10 additions & 10 deletions pcmdi_metrics/graphics/portrait_plot/portrait_plot_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def portrait_plot(
# ----------------
# Prepare plotting
# ----------------
data, num_divide = prepare_data(data, xaxis_labels, yaxis_labels, debug)
data, num_divide = prepare_data(data, xaxis_labels, yaxis_labels, debug=debug)

if num_divide not in [1, 2, 4]:
sys.exit("Error: Number of (stacked) array is not 1, 2, or 4.")
Expand All @@ -117,7 +117,7 @@ def portrait_plot(
num_divide_annotate = num_divide
else:
annotate_data, num_divide_annotate = prepare_data(
annotate_data, xaxis_labels, yaxis_labels, debug
annotate_data, xaxis_labels, yaxis_labels, debug=debug
)
if num_divide_annotate != num_divide:
sys.exit("Error: annotate_data does not have same size as data")
Expand Down Expand Up @@ -154,8 +154,8 @@ def portrait_plot(
if num_divide == 1:
ax, im = heatmap(
data,
yaxis_labels,
xaxis_labels,
yaxis_labels,
ax=ax,
invert_yaxis=invert_yaxis,
cmap=cmap,
Expand Down Expand Up @@ -335,7 +335,7 @@ def portrait_plot(
# ======================================================================
# Prepare data
# ----------------------------------------------------------------------
def prepare_data(data, xaxis_labels, yaxis_labels, debug):
def prepare_data(data, xaxis_labels, yaxis_labels, debug=False):
# In case data was given as list of arrays, convert it to numpy (stacked) array
if type(data) == list:
if debug:
Expand All @@ -362,7 +362,7 @@ def prepare_data(data, xaxis_labels, yaxis_labels, debug):
sys.exit("Error: Number of elements in yaxis_label mismatchs to the data")

if type(data) == np.ndarray:
data = np.squeeze(data)
# data = np.squeeze(data)
if len(data.shape) == 2:
num_divide = 1
elif len(data.shape) == 3:
Expand All @@ -386,17 +386,17 @@ def prepare_data(data, xaxis_labels, yaxis_labels, debug):
# Portrait plot 1: heatmap-style (no triangle)
# (Inspired from: https://matplotlib.org/devdocs/gallery/images_contours_and_fields/image_annotated_heatmap.html)
# ----------------------------------------------------------------------
def heatmap(data, row_labels, col_labels, ax=None, invert_yaxis=False, **kwargs):
def heatmap(data, xaxis_labels, yaxis_labels, ax=None, invert_yaxis=False, **kwargs):
"""
Create a heatmap from a numpy array and two lists of labels.
Parameters
----------
data
A 2D numpy array of shape (M, N).
row_labels
yaxis_labels
A list or array of length M with the labels for the rows.
col_labels
xaxis_labels
A list or array of length N with the labels for the columns.
ax
A `matplotlib.axes.Axes` instance to which the heatmap is plotted. If
Expand All @@ -419,8 +419,8 @@ def heatmap(data, row_labels, col_labels, ax=None, invert_yaxis=False, **kwargs)
# Show all ticks and label them with the respective list entries.
ax.set_xticks(np.arange(data.shape[1]) + 0.5, minor=False)
ax.set_yticks(np.arange(data.shape[0]) + 0.5, minor=False)
ax.set_xticklabels(col_labels)
ax.set_yticklabels(row_labels)
ax.set_xticklabels(xaxis_labels)
ax.set_yticklabels(yaxis_labels)
ax.tick_params(which="minor", bottom=False, left=False)

return ax, im
Expand Down
18 changes: 11 additions & 7 deletions pcmdi_metrics/graphics/share/read_json_mean_clim.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import numpy as np
import pandas as pd

from pcmdi_metrics.variability_mode.lib import sort_human


def read_mean_clim_json_files(
json_list, regions=None, stats=None, mip=None, debug=False
Expand Down Expand Up @@ -121,7 +123,7 @@ def extract_stat(var, results_dict_var):

def extract_region_stat(var, results_dict_var):
model_list = sorted(list(results_dict_var["RESULTS"].keys()))
run_list = sorted(
run_list = sort_human(
list(results_dict_var["RESULTS"][model_list[0]]["default"].keys())
)
if "source" in run_list:
Expand All @@ -144,18 +146,20 @@ def extract_data(results_dict, var_list, region, stat, season, mip, debug=False)
Return a pandas dataframe for metric numbers at given region/stat/season.
Rows: models, Columns: variables (i.e., 2d array)
"""
if "rlut" in list(results_dict["rlut"]["RESULTS"].keys()):
model_list = sorted(list(results_dict["rlut"]["RESULTS"].keys()))
if "rlut" in list(results_dict.keys()):
if "rlut" in list(results_dict["rlut"]["RESULTS"].keys()):
model_list = sorted(list(results_dict["rlut"]["RESULTS"].keys()))
else:
model_list = sorted(list(results_dict[var_list[0]]["RESULTS"].keys()))

data_list = []
for model in model_list:
if "rlut" in list(results_dict["rlut"]["RESULTS"].keys()):
run_list = list(results_dict["rlut"]["RESULTS"][model]["default"].keys())
if "rlut" in list(results_dict.keys()):
if "rlut" in list(results_dict["rlut"]["RESULTS"].keys()):
run_list = sort_human(list(results_dict["rlut"]["RESULTS"][model]["default"].keys()))
else:
run_list = list(
results_dict[var_list[0]]["RESULTS"][model]["default"].keys()
run_list = sort_human(list(
results_dict[var_list[0]]["RESULTS"][model]["default"].keys())
)

if debug:
Expand Down

0 comments on commit 24d4575

Please sign in to comment.