Skip to content

Commit

Permalink
Merge pull request #1020 from PCMDI/feature/1012_lee1043_stats-MoV_xcdat
Browse files Browse the repository at this point in the history
Update MoV code to use xCDAT
  • Loading branch information
lee1043 committed May 2, 2024
2 parents 81635a7 + 2bd8aec commit 6d6b688
Show file tree
Hide file tree
Showing 38 changed files with 2,643 additions and 1,671 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ repos:
- id: black

- repo: https://github.com/timothycrosley/isort
rev: 5.12.0
rev: 5.13.2
hooks:
- id: isort
args: ["--honor-noqa"]
Expand All @@ -34,7 +34,7 @@ repos:
# Python linting
# =======================
- repo: https://github.com/pycqa/flake8
rev: 6.0.0
rev: 7.0.0
hooks:
- id: flake8
args: ["--config=setup.cfg"]
Expand Down
4 changes: 2 additions & 2 deletions conda-env/dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ dependencies:
- genutil=8.2.1
- cdutil=8.2.1
- cdp=1.7.0
- eofs=1.4.0
- eofs=1.4.1
- seaborn=0.12.2
- enso_metrics=1.1.1
- xcdat>=0.6.1
- xcdat>=0.7.0
- xmltodict=0.13.0
- setuptools=67.7.2
- netcdf4=1.6.3
Expand Down
1,102 changes: 452 additions & 650 deletions doc/jupyter/Demo/Demo_4_modes_of_variability.ipynb

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions pcmdi_metrics/graphics/portrait_plot/portrait_plot_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,12 +342,12 @@ def portrait_plot(
# ----------------------------------------------------------------------
def prepare_data(data, xaxis_labels, yaxis_labels, debug=False):
# In case data was given as list of arrays, convert it to numpy (stacked) array
if type(data) == list:
if isinstance(data, list):
if debug:
print("data type is list")
print("len(data):", len(data))
if len(data) == 1: # list has only 1 array as element
if (type(data[0]) == np.ndarray) and (len(data[0].shape) == 2):
if isinstance(data[0], np.ndarray) and (len(data[0].shape) == 2):
data = data[0]
num_divide = 1
else:
Expand All @@ -366,7 +366,7 @@ def prepare_data(data, xaxis_labels, yaxis_labels, debug=False):
if data.shape[-2] != len(yaxis_labels) and len(yaxis_labels) > 0:
sys.exit("Error: Number of elements in yaxis_label mismatchs to the data")

if type(data) == np.ndarray:
if isinstance(data, np.ndarray):
# data = np.squeeze(data)
if len(data.shape) == 2:
num_divide = 1
Expand Down
8 changes: 5 additions & 3 deletions pcmdi_metrics/io/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
# init for pcmdi_metrics.io
from .xcdat_openxml import xcdat_open # noqa # isort:skip
from .string_constructor import StringConstructor, fill_template # noqa # isort:skip
from . import base # noqa
from .base import MV2Json # noqa
from .default_regions_define import load_regions_specs # noqa
from .default_regions_define import region_subset # noqa
from .xcdat_dataset_io import ( # noqa
from .xcdat_dataset_io import ( # noqa # isort:skip
da_to_ds,
get_axis_list,
get_data_list,
get_grid,
get_latitude_bounds_key,
get_latitude_key,
get_latitude,
Expand All @@ -21,3 +22,4 @@
get_time_key,
select_subset,
)
from .regions import load_regions_specs, region_subset # noqa
2 changes: 1 addition & 1 deletion pcmdi_metrics/io/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import pcmdi_metrics
from pcmdi_metrics import LOG_LEVEL
from pcmdi_metrics.utils import StringConstructor
from pcmdi_metrics.io import StringConstructor

value = 0
cdms2.setNetcdfShuffleFlag(value) # where value is either 0 or 1
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from typing import Union

import xarray as xr
import xcdat as xc

from pcmdi_metrics.io import da_to_ds, get_longitude, select_subset


def load_regions_specs():
def load_regions_specs() -> dict:
regions_specs = {
# Mean Climate
"global": {},
Expand Down Expand Up @@ -35,7 +40,10 @@ def load_regions_specs():
"NAO": {"domain": {"latitude": (20.0, 80), "longitude": (-90, 40)}},
"SAM": {"domain": {"latitude": (-20.0, -90), "longitude": (0, 360)}},
"PNA": {"domain": {"latitude": (20.0, 85), "longitude": (120, 240)}},
"NPO": {"domain": {"latitude": (20.0, 85), "longitude": (120, 240)}},
"PDO": {"domain": {"latitude": (20.0, 70), "longitude": (110, 260)}},
"NPGO": {"domain": {"latitude": (20.0, 70), "longitude": (110, 260)}},
"AMO": {"domain": {"latitude": (0.0, 70), "longitude": (-80, 0)}},
# Monsoon domains for Wang metrics
# All monsoon domains
"AllMW": {"domain": {"latitude": (-40.0, 45.0), "longitude": (0.0, 360.0)}},
Expand All @@ -45,7 +53,8 @@ def load_regions_specs():
# South American Monsoon
"SAMM": {"domain": {"latitude": (-45.0, 0.0), "longitude": (240.0, 330.0)}},
# North African Monsoon
"NAFM": {"domain": {"latitude": (0.0, 45.0), "longitude": (310.0, 60.0)}},
# "NAFM": {"domain": {"latitude": (0.0, 45.0), "longitude": (310.0, 60.0)}},
"NAFM": {"domain": {"latitude": (0.0, 45.0), "longitude": (-50.0, 60.0)}},
# South African Monsoon
"SAFM": {"domain": {"latitude": (-45.0, 0.0), "longitude": (0.0, 90.0)}},
# Asian Summer Monsoon
Expand All @@ -70,55 +79,77 @@ def load_regions_specs():
return regions_specs


def region_subset(ds, regions_specs, region=None):
"""
d: xarray.Dataset
regions_specs: dict
region: string
def region_subset(
ds: Union[xr.Dataset, xr.DataArray],
region: str,
data_var: str = "variable",
regions_specs: dict = None,
debug: bool = False,
) -> Union[xr.Dataset, xr.DataArray]:
"""_summary_
Parameters
----------
ds : Union[xr.Dataset, xr.DataArray]
_description_
region : str
_description_
data_var : str, optional
_description_, by default None
regions_specs : dict, optional
_description_, by default None
debug: bool, optional
Turn on debug print, by default False
Returns
-------
Union[xr.Dataset, xr.DataArray]
_description_
"""
if isinstance(ds, xr.DataArray):
is_dataArray = True
ds = da_to_ds(ds, data_var)
else:
is_dataArray = False

if regions_specs is None:
regions_specs = load_regions_specs()

if "domain" in regions_specs[region]:
if "latitude" in regions_specs[region]["domain"]:
lat0 = regions_specs[region]["domain"]["latitude"][0]
lat1 = regions_specs[region]["domain"]["latitude"][1]
# proceed subset
ds = select_subset(ds, lat=(min(lat0, lat1), max(lat0, lat1)))
if debug:
print("region_subset, latitude subsetted, ds:", ds)

if "longitude" in regions_specs[region]["domain"]:
lon0 = regions_specs[region]["domain"]["longitude"][0]
lon1 = regions_specs[region]["domain"]["longitude"][1]

# check original dataset longitude range
lon_min = get_longitude(ds).min().values.item()
lon_max = get_longitude(ds).max().values.item()

# Check if longitude range swap is needed
if min(lon0, lon1) < 0:
# when subset region lon is defined in (-180, 180) range
if min(lon_min, lon_max) < 0:
# if original data lon range is (-180, 180), no treatment needed
pass
else:
# if original data lon range is (0, 360), convert and swap lon
ds = xc.swap_lon_axis(ds, to=(-180, 180))

# proceed subset
# ds = select_subset(ds, lon=(min(lon0, lon1), max(lon0, lon1)))
ds = select_subset(ds, lon=(lon0, lon1))
if debug:
print("region_subset, longitude subsetted, ds:", ds)

if (region is None) or (
(region is not None) and (region not in list(regions_specs.keys()))
):
print("Error: region not defined")
# return the same type
if is_dataArray:
return ds[data_var]
else:
if "domain" in list(regions_specs[region].keys()):
if "latitude" in list(regions_specs[region]["domain"].keys()):
lat0 = regions_specs[region]["domain"]["latitude"][0]
lat1 = regions_specs[region]["domain"]["latitude"][1]
# proceed subset
if "latitude" in (ds.coords.dims):
ds = ds.sel(latitude=slice(lat0, lat1))
elif "lat" in (ds.coords.dims):
ds = ds.sel(lat=slice(lat0, lat1))

if "longitude" in list(regions_specs[region]["domain"].keys()):
lon0 = regions_specs[region]["domain"]["longitude"][0]
lon1 = regions_specs[region]["domain"]["longitude"][1]

# check original dataset longitude range
if "longitude" in (ds.coords.dims):
lon_min = ds.longitude.min()
lon_max = ds.longitude.max()
elif "lon" in (ds.coords.dims):
lon_min = ds.lon.min()
lon_max = ds.lon.max()

# longitude range swap if needed
if (
min(lon0, lon1) < 0
): # when subset region lon is defined in (-180, 180) range
if (
min(lon_min, lon_max) < 0
): # if original data lon range is (-180, 180) no treatment needed
pass
else: # if original data lon range is (0, 360), convert swap lon
ds = xc.swap_lon_axis(ds, to=(-180, 180))

# proceed subset
if "longitude" in (ds.coords.dims):
ds = ds.sel(longitude=slice(lon0, lon1))
elif "lon" in (ds.coords.dims):
ds = ds.sel(lon=slice(lon0, lon1))

return ds
return ds
99 changes: 99 additions & 0 deletions pcmdi_metrics/io/string_constructor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import warnings


class StringConstructor:
    """
    Template filler that locates ``%(keyword)`` placeholders in a string
    and substitutes values for them.

    Every keyword discovered in the template becomes an instance attribute
    (initially the empty string), so values may be supplied either by
    assigning attributes or by passing keyword arguments to ``construct``.
    """

    def __init__(self, template=None):
        """Create a constructor bound to *template* (``None`` is allowed)."""
        self.template = template
        # Pre-create one attribute per placeholder so it can be filled later.
        for key in self.keys():
            setattr(self, key, "")

    def keys(self, template=None):
        """
        Return the unique placeholder names found in *template*
        (defaults to ``self.template``), in order of first appearance.
        """
        if template is None:
            template = self.template
        if template is None:
            return []
        # Everything after a "%(" up to the next ")" is a placeholder name;
        # dict.fromkeys de-duplicates while preserving first-seen order.
        names = (chunk.split(")")[0] for chunk in template.split("%(")[1:])
        return list(dict.fromkeys(names))

    def construct(self, template=None, **kwargs):
        """
        Fill *template* (defaults to ``self.template``) by replacing each
        ``%(key)`` with ``kwargs[key]``, falling back to the matching
        instance attribute (empty string by default).  A warning is emitted
        for every key absent from *kwargs*, even when an attribute value is
        used as the fallback.
        """
        if template is None:
            template = self.template
        for key in self.keys():
            if key not in kwargs:
                warnings.warn(f"Keyword '{key}' not provided for filling the template.")
            template = template.replace(
                f"%({key})", kwargs.get(key, getattr(self, key, ""))
            )
        return template

    def reverse(self, name, debug=False):
        """
        Infer placeholder values from *name*, a string produced from this
        object's template.  Returns a dict mapping each key to its value;
        raises ``ValueError`` when *name* cannot be rebuilt from the
        recovered values.
        """
        recovered = {}
        working = self.template
        for key in self.keys():
            placeholder = f"%({key})"
            pieces = working.split(placeholder)
            head, tail = pieces[0], pieces[1]
            start = name.find(head) + len(head)
            next_marker = tail.find("%(")
            if next_marker == -1:
                # Last placeholder: the value runs to the end of *name*,
                # or up to the fixed trailing text if there is any.
                value = name[start:] if tail == "" else name[start : name.find(tail)]
            else:
                # The value ends where the literal text preceding the next
                # placeholder begins.
                span = name[start:].find(tail[:next_marker])
                value = name[start : start + span]
            working = working.replace(placeholder, value)
            recovered[key] = value
        # Round-trip check: the recovered values must rebuild *name* exactly.
        if self.construct(self.template, **recovered) != name:
            raise ValueError("Invalid pattern sent")
        return recovered

    def __call__(self, *args, **kwargs):
        """Calling the object is shorthand for :meth:`construct`."""
        return self.construct(*args, **kwargs)


def fill_template(template: str, **kwargs) -> str:
    """
    Replace every ``%(keyword)`` placeholder in *template* with the value
    supplied as the matching keyword argument.

    Parameters
    ----------
    template : str
        Template string containing placeholders of the form ``%(keyword)``.
    **kwargs
        Values substituted for the placeholders; a warning is emitted for
        any placeholder that has no matching keyword argument.

    Returns
    -------
    str
        The template with all placeholders filled in.

    Examples
    --------
    >>> template = "This is a %(adjective) %(noun) that %(verb)."
    >>> fill_template(template, adjective="great", noun="example", verb="works")
    'This is a great example that works.'
    """
    return StringConstructor(template).construct(**kwargs)
Loading

0 comments on commit 6d6b688

Please sign in to comment.