From b35606ccd64fc6cba5bbbcf239e3242e0d1b86f4 Mon Sep 17 00:00:00 2001 From: Jiwoo Lee Date: Tue, 27 Feb 2024 21:55:04 -0800 Subject: [PATCH 1/7] add pcmdi method as an option --- pcmdi_metrics/utils/create_land_sea_mask.py | 454 +++++++++++++++++++- setup.py | 1 + 2 files changed, 444 insertions(+), 11 deletions(-) diff --git a/pcmdi_metrics/utils/create_land_sea_mask.py b/pcmdi_metrics/utils/create_land_sea_mask.py index d40c9fd87..7eecbaf42 100644 --- a/pcmdi_metrics/utils/create_land_sea_mask.py +++ b/pcmdi_metrics/utils/create_land_sea_mask.py @@ -1,16 +1,23 @@ +import os +import sys +import time import warnings from typing import Union +import numpy as np import regionmask import xarray as xr import xcdat as xc +from pcmdi_metrics import resources + def create_land_sea_mask( obj: Union[xr.Dataset, xr.DataArray], lon_key: str = None, lat_key: str = None, as_boolean: bool = False, + method: str = "regionmask", ) -> xr.DataArray: """Generate a land-sea mask (1 for land, 0 for sea) for a given xarray Dataset or DataArray. @@ -24,6 +31,8 @@ def create_land_sea_mask( Name of DataArray for latitude, by default None as_boolean : bool, optional Set mask value to True (land) or False (ocean), by default False, thus 1 (land) and 0 (ocean). + method : str, optional + Method to use for creating the mask, either 'regionmask' or 'pcmdi', by default 'regionmask'. Returns ------- @@ -43,21 +52,35 @@ def create_land_sea_mask( Generate land-sea mask (land: True, sea: False): >>> mask = create_land_sea_mask(ds, as_boolean=True) + + Use PCMDI method: + + >>> mask = create_land_sea_mask(ds, method="pcmdi") + """ - # Create a land-sea mask using regionmask - land_mask = regionmask.defined_regions.natural_earth_v5_0_0.land_110 - # Get the longitude and latitude from the xarray dataset - if lon_key is None: - lon_key = xc.axis.get_dim_keys(obj, axis="X") - if lat_key is None: - lat_key = xc.axis.get_dim_keys(obj, axis="Y") + # Create a land-sea mask + if method.lower() == "regionmask": + # Use regionmask + land_mask = regionmask.defined_regions.natural_earth_v5_0_0.land_110 - lon = obj[lon_key] - lat = obj[lat_key] + # Get the longitude and latitude from the xarray dataset + if lon_key is None: + lon_key = xc.axis.get_dim_keys(obj, axis="X") + if lat_key is None: + lat_key = xc.axis.get_dim_keys(obj, axis="Y") - # Mask the land-sea mask to match the dataset's coordinates - land_sea_mask = land_mask.mask(lon, lat=lat) + lon = obj[lon_key] + lat = obj[lat_key] + + # Mask the land-sea mask to match the dataset's coordinates + land_sea_mask = land_mask.mask(lon, lat=lat) + + elif method.lower() == "pcmdi": + # Use the PCMDI method developed by Taylor and Doutriaux (2000) + land_sea_mask = generate_land_sea_mask__pcmdi(obj) + else: + raise ValueError("Unknown method '%s'. Please choose 'regionmask' or 'pcmdi'") if not as_boolean: # Convert the land-sea mask to a boolean mask @@ -182,3 +205,412 @@ def apply_landmask( data_array = data_array.where(landfrac <= ocean_criteria) return data_array + + +def generate_land_sea_mask__pcmdi( + target_grid, + source=None, + data_var="sftlf", + maskname="lsmask", + regridTool="regrid2", + threshold_1=0.2, + threshold_2=0.3, + debug=False, +): + """Generates a best guess mask on any rectilinear grid, using the method described in `PCMDI's report #58`_ + + Parameters + ---------- + target_grid : xarray.Dataset + Either a xcdat/xarray Dataset with a grid, or a xcdat grid (rectilinear grid only) + source : xarray.Dataset, optional + A xcdat/xarray Dataset that contains a DataArray of a fractional (0.0 to 1.0) land sea mask, + where 1 means all land., by default None + data_var : str, optional + name of DataArray for land sea fraction/mask variable in `source`, by default "sftlf" + maskname : str, optional + Variable name for returning DataArray, by default "lsmask" + regridTool : str, optional + Which xcdat regridder tool to use, by default "regrid2" + threshold_1 : float, optional + Criteria for detecting cells with possible increment see report for detail difference threshold, by default 0.2 + threshold_2 : float, optional + Criteria for detecting cells with possible increment see report for detail water/land content threshold, by default 0.3 + debug : bool, optional + Switch to print more interim outputs to help debugging, by default False + + Returns + ------- + xarray.DataArray + landsea mask on target grid (1: land, 0: water). + + Raises + ------ + ValueError + _description_ + + References + ---------- + .. _PCMDI's report #58: https://pcmdi.llnl.gov/report/ab58.html + + History + ------- + 2023-06 The [original code](https://github.com/CDAT/cdutil/blob/master/cdutil/create_landsea_mask.py) was rewritten using xarray and xcdat by Jiwoo Lee + """ + + if source is None: + egg_pth = resources.resource_path() + source_path = os.path.join(egg_pth, "navy_land.nc") + if not os.path.isfile(source_path): + # pip install process places data files in different place, so checking here as well + source_path = os.path.join( + sys.prefix, "share/pcmdi_metrics", "navy_land.nc" + ) + ds = xc.open_dataset(source_path, decode_times=False).load() + else: + ds = source.copy() + if not isinstance(ds, xr.Dataset): + raise ValueError( + "ERROR: type of source, ", + type(source), + " is not acceptable. It should be ", + ) + + # Regrid + if target_grid.equals(ds): + ds_regrid = ds.copy() # testing purpose + else: + start_time_r = time.time() + ds_regrid = ds.regridder.horizontal(data_var, target_grid, tool=regridTool) + end_time_r = time.time() + + if debug: + print( + "Elapsed time (regridder.horizontal):", + end_time_r - start_time_r, + "seconds", + ) + + # Add missed information during the regrid process + # (this might be a bug... will report it to xcdat team later) + if "axis" not in ds_regrid[data_var].lat.attrs.keys(): + ds_regrid[data_var].lat.attrs["axis"] = "Y" + if "axis" not in ds_regrid[data_var].lon.attrs.keys(): + ds_regrid[data_var].lon.attrs["axis"] = "X" + if "bounds" not in ds_regrid[data_var].lat.attrs.keys(): + ds_regrid[data_var].lat.attrs["bounds"] = "lat_bnds" + if "bounds" not in ds_regrid[data_var].lon.attrs.keys(): + ds_regrid[data_var].lon.attrs["bounds"] = "lon_bnds" + if "units" not in ds_regrid[data_var].lat.attrs: + ds_regrid[data_var].lat.attrs["units"] = "degrees_north" + + # re-generate lat lon bounds (original bounds are 2d arrays where 1d array for each is expected) + ds_regrid = ( + ds_regrid.drop_vars( + [ + ds_regrid[data_var].lat.attrs["bounds"], + ds_regrid[data_var].lon.attrs["bounds"], + ] + ) + .bounds.add_bounds("Y") + .bounds.add_bounds("X") + ) + + # First guess, anything greater than 50% is land to ignore rivers and lakes + mask = xr.where(ds_regrid[data_var] > 0.5, 1, 0) + + if debug: + ds_regrid[data_var + "_regrid"] = ds_regrid[data_var].copy() + ds_regrid[data_var + "_firstGuess"] = mask + + # Improve + UL, UC, UR, ML, MR, LL, LC, LR = _create_surrounds( + ds_regrid, data_var=data_var, debug=debug + ) + + cont = True + i = 0 + + while cont: + mask_improved = _improve( + mask, + ds_regrid, + UL, + UC, + UR, + ML, + MR, + LL, + LC, + LR, + data_var=data_var, + threshold_1=threshold_1, + threshold_2=threshold_2, + regridTool=regridTool, + debug=debug, + ) + + if mask_improved.equals(mask) or i > 25: + cont = False + + mask = mask_improved.astype("i") + + if debug: + print("test i:", i) + + i += 1 + + mask = mask.rename(maskname) + + # Reverse the values (0 to 1 and 1 to 0) + reversed_mask = xr.where(mask == 0, 1, 0) + + return reversed_mask + + +def _create_surrounds(ds, data_var="sftlf", debug=False): + start_time = time.time() + data = ds[data_var].data + sh = list(data.shape) + L = ds["lon"] + bL = ds[ds.lon.attrs["bounds"]].data + + L_isCircular = _isCircular(L) + L_modulo = 360 + + if _isCircular(L) and bL[-1][1] - bL[0][0] % L_modulo == 0: + sh[0] = sh[0] - 2 + else: + sh[0] = sh[0] - 2 + sh[1] = sh[1] - 2 + + UL = np.ones(sh) + UC = np.ones(sh) + UR = np.ones(sh) + ML = np.ones(sh) + MR = np.ones(sh) + LL = np.ones(sh) + LC = np.ones(sh) + LR = np.ones(sh) + + if L_isCircular and bL[-1][1] - bL[0][0] % L_modulo == 0: + UC[:, :] = data[2:] + LC[:, :] = data[:-2] + ML[:, 1:] = data[1:-1, :-1] + ML[:, 0] = data[1:-1, -1] + MR[:, :-1] = data[1:-1, 1:] + MR[:, -1] = data[1:-1, 0] + UL[:, 1:] = data[2:, :-1] + UL[:, 0] = data[2:, -1] + UR[:, :-1] = data[2:, 1:] + UR[:, -1] = data[2:, 0] + LL[:, 1:] = data[:-2, :-1] + LL[:, 0] = data[:-2, -1] + LR[:, :-1] = data[:-2, 1:] + LR[:, -1] = data[:-2, 0] + else: + UC[:, :] = data[2:, 1:-1] + LC[:, :] = data[:-2, 1:-1] + ML[:, :] = data[1:-1, :-2] + MR[:, :] = data[1:-1, 2:] + UL[:, :] = data[2:, :-2] + UR[:, :] = data[2:, 2:] + LL[:, :] = data[:-2, :-2] + LR[:, :] = data[:-2, 2:] + + end_time = time.time() + if debug: + elapsed_time = end_time - start_time + print("Elapsed time (_create_surrounds):", elapsed_time, "seconds") + + return UL, UC, UR, ML, MR, LL, LC, LR + + +def _isCircular(lons): + baxis = lons[0] # beginning of axis + eaxis = lons[-1] # end of axis + deltaend = lons[-1] - lons[-2] # delta between two end points + eaxistest = eaxis + deltaend - baxis # test end axis + tol = 0.01 * deltaend + if abs(eaxistest - 360) < tol: + return True + else: + return False + + +def _improve( + mask, + ds_regrid, + UL, + UC, + UR, + ML, + MR, + LL, + LC, + LR, + data_var="sftlf", + threshold_1=0.2, + threshold_2=0.3, + regridTool="regrid2", + debug=False, +): + start_time = time.time() + + ds_mask_approx = _map2four( + mask, ds_regrid, data_var=data_var, regridTool=regridTool, debug=debug + ) + diff = ds_regrid[data_var] - ds_mask_approx[data_var] + + # Land point conversion + c1 = np.greater(diff, threshold_1) # xr.DataArray + c2 = np.greater(ds_regrid[data_var], threshold_2) # xr.DataArray + c = np.logical_and(c1, c2) + ds_regrid["c"] = c + + # Now figures out local maxima + cUL, cUC, cUR, cML, cMR, cLL, cLC, cLR = _create_surrounds(ds_regrid, data_var="c") + + L = ds_regrid["lon"] + bL = ds_regrid[ds_regrid.lon.attrs["bounds"]].data + + L_modulo = 360 + L_isCircular = _isCircular(L) + + if L_isCircular and bL[-1][1] - bL[0][0] % L_modulo == 0: + c = c[1:-1] # elimnitates north and south poles + tmp = ds_regrid[data_var].data[1:-1] + else: + c = c[1:-1, 1:-1] # elimnitates north and south poles + tmp = ds_regrid[data_var].data[1:-1, 1:-1] + m = np.logical_and(c, np.greater(tmp, np.where(cUL, UL, 0.0))) + m = np.logical_and(m, np.greater(tmp, np.where(cUC, UC, 0.0))) + m = np.logical_and(m, np.greater(tmp, np.where(cUR, UR, 0.0))) + m = np.logical_and(m, np.greater(tmp, np.where(cML, ML, 0.0))) + m = np.logical_and(m, np.greater(tmp, np.where(cMR, MR, 0.0))) + m = np.logical_and(m, np.greater(tmp, np.where(cLL, LL, 0.0))) + m = np.logical_and(m, np.greater(tmp, np.where(cLC, LC, 0.0))) + m = np.logical_and(m, np.greater(tmp, np.where(cLR, LR, 0.0))) + # Ok now update the mask by setting these points to land + mask2 = mask * 1.0 + if _isCircular(L) and bL[-1][1] - bL[0][0] % L_modulo == 0: + mask2[1:-1] = xr.where(m, 1, mask[1:-1]) + else: + mask2[1:-1, 1:-1] = xr.where(m, 1, mask[1:-1, 1:-1]) + + # ocean point conversion + c1 = np.less(diff, -threshold_1) + c2 = np.less(ds_regrid[data_var], 1.0 - threshold_2) + c = np.logical_and(c1, c2) + ds_regrid["c"] = c + + cUL, cUC, cUR, cML, cMR, cLL, cLC, cLR = _create_surrounds(ds_regrid, data_var="c") + + if L_isCircular and bL[-1][1] - bL[0][0] % L_modulo == 0: + c = c[1:-1] # elimnitates north and south poles + tmp = ds_regrid[data_var].data[1:-1] + else: + c = c[1:-1, 1:-1] # elimnitates north and south poles + tmp = ds_regrid[data_var].data[1:-1, 1:-1] + # Now figures out local maxima + m = np.logical_and(c, np.less(tmp, np.where(cUL, UL, 1.0))) + m = np.logical_and(m, np.less(tmp, np.where(cUC, UC, 1.0))) + m = np.logical_and(m, np.less(tmp, np.where(cUR, UR, 1.0))) + m = np.logical_and(m, np.less(tmp, np.where(cML, ML, 1.0))) + m = np.logical_and(m, np.less(tmp, np.where(cMR, MR, 1.0))) + m = np.logical_and(m, np.less(tmp, np.where(cLL, LL, 1.0))) + m = np.logical_and(m, np.less(tmp, np.where(cLC, LC, 1.0))) + m = np.logical_and(m, np.less(tmp, np.where(cLR, LR, 1.0))) + # Ok now update the mask by setting these points to ocean + if L_isCircular and bL[-1][1] - bL[0][0] % L_modulo == 0: + mask2[1:-1] = xr.where(m, 0, mask2[1:-1]) + else: + mask2[1:-1, 1:-1] = xr.where(m, 0, mask2[1:-1, 1:-1]) + + end_time = time.time() + if debug: + elapsed_time = end_time - start_time + print("Elapsed time (_improve):", elapsed_time, "seconds") + + return mask2 + + +def _map2four(mask, ds_regrid, data_var="sftlf", regridTool="regrid2", debug=False): + if debug: + print("mask.shape:", mask.shape) + print("ds_regrid[data_var].shape:", ds_regrid[data_var].shape) + + ds_tmp = ds_regrid.copy() + ds_tmp[data_var] = mask + + start_time_c = time.time() + + lons = ds_regrid.lon.data + lats = ds_regrid.lat.data + lonso = lons[::2] + lonse = lons[1::2] + latso = lats[::2] + latse = lats[1::2] + + lat_delta = (lats[-1] - lats[0]) / len(lats) * 2 + lon_delta = (lons[-1] - lons[0]) / len(lons) * 2 + + oo = xc.create_uniform_grid( + latso[0], latso[-1], lat_delta, lonso[0], lonso[-1], lon_delta + ) + oe = xc.create_uniform_grid( + latso[0], latso[-1], lat_delta, lonse[0], lonse[-1], lon_delta + ) + eo = xc.create_uniform_grid( + latse[0], latse[-1], lat_delta, lonso[0], lonso[-1], lon_delta + ) + ee = xc.create_uniform_grid( + latse[0], latse[-1], lat_delta, lonse[0], lonse[-1], lon_delta + ) + + end_time_c = time.time() + + doo = ds_tmp.regridder.horizontal(data_var, oo, tool=regridTool) + doe = ds_tmp.regridder.horizontal(data_var, oe, tool=regridTool) + deo = ds_tmp.regridder.horizontal(data_var, eo, tool=regridTool) + dee = ds_tmp.regridder.horizontal(data_var, ee, tool=regridTool) + + end_time_r = time.time() + + out = np.zeros(mask.shape, dtype="f") + + if debug: + print("out.shape:", out.shape) + print("doo.shape:", doo[data_var].data.shape) + print("doe.shape:", doe[data_var].data.shape) + print("deo.shape:", deo[data_var].data.shape) + print("dee.shape:", dee[data_var].data.shape) + + out[::2, ::2] = doo[data_var].data + out[::2, 1::2] = doe[data_var].data + out[1::2, ::2] = deo[data_var].data + out[1::2, 1::2] = dee[data_var].data + + ds_out = ds_regrid.copy() + ds_out[data_var] = (("lat", "lon"), out) + + end_time_o = time.time() + + end_time = time.time() + + if debug: + elapsed_time = end_time - start_time_c + print("Elapsed time (_map2four):", elapsed_time, "seconds") + print( + "Elapsed time (_map2four, create_uniform_grid):", + end_time_c - start_time_c, + "seconds", + ) + print( + "Elapsed time (_map2four, regridder.horizontal):", + end_time_r - end_time_c, + "seconds", + ) + print("Elapsed time (_map2four, out):", end_time_o - end_time_r, "seconds") + + return ds_out diff --git a/setup.py b/setup.py index eeca95a9d..76018bec3 100644 --- a/setup.py +++ b/setup.py @@ -87,6 +87,7 @@ "share/default_regions.py", "share/DefArgsCIA.json", "pcmdi_metrics/precip_distribution/lib/cluster3_pdf.amt_regrid.360x180_IMERG_ALL_90S90N.nc", + "share/data/navy_land.nc", ), ), ) From 816eee8c0a78af7c3e5dfd3f3b9fe894f3f154ba Mon Sep 17 00:00:00 2001 From: Jiwoo Lee Date: Tue, 27 Feb 2024 22:15:35 -0800 Subject: [PATCH 2/7] clean up --- pcmdi_metrics/utils/create_land_sea_mask.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/pcmdi_metrics/utils/create_land_sea_mask.py b/pcmdi_metrics/utils/create_land_sea_mask.py index 7eecbaf42..9579299a4 100644 --- a/pcmdi_metrics/utils/create_land_sea_mask.py +++ b/pcmdi_metrics/utils/create_land_sea_mask.py @@ -76,15 +76,17 @@ def create_land_sea_mask( # Mask the land-sea mask to match the dataset's coordinates land_sea_mask = land_mask.mask(lon, lat=lat) + if not as_boolean: + # Convert the land-sea mask to a boolean mask + land_sea_mask = xr.where(land_sea_mask, 0, 1) + elif method.lower() == "pcmdi": # Use the PCMDI method developed by Taylor and Doutriaux (2000) land_sea_mask = generate_land_sea_mask__pcmdi(obj) else: raise ValueError("Unknown method '%s'. Please choose 'regionmask' or 'pcmdi'") - if not as_boolean: - # Convert the land-sea mask to a boolean mask - land_sea_mask = xr.where(land_sea_mask, 0, 1) + return land_sea_mask @@ -363,9 +365,9 @@ def generate_land_sea_mask__pcmdi( mask = mask.rename(maskname) # Reverse the values (0 to 1 and 1 to 0) - reversed_mask = xr.where(mask == 0, 1, 0) + #mask = xr.where(mask == 0, 1, 0) - return reversed_mask + return mask def _create_surrounds(ds, data_var="sftlf", debug=False): From d925ee1e92058aeaaaa99f3ead0d9fa07b72fc8b Mon Sep 17 00:00:00 2001 From: Jiwoo Lee Date: Tue, 27 Feb 2024 22:25:50 -0800 Subject: [PATCH 3/7] fix as_boolen option for the pcmdi method --- pcmdi_metrics/utils/create_land_sea_mask.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pcmdi_metrics/utils/create_land_sea_mask.py b/pcmdi_metrics/utils/create_land_sea_mask.py index 9579299a4..e8f6e4619 100644 --- a/pcmdi_metrics/utils/create_land_sea_mask.py +++ b/pcmdi_metrics/utils/create_land_sea_mask.py @@ -83,6 +83,10 @@ def create_land_sea_mask( elif method.lower() == "pcmdi": # Use the PCMDI method developed by Taylor and Doutriaux (2000) land_sea_mask = generate_land_sea_mask__pcmdi(obj) + + if as_boolean: + land_sea_mask = xr.where(land_sea_mask==1, True, False) + else: raise ValueError("Unknown method '%s'. Please choose 'regionmask' or 'pcmdi'") From 1259bf029c7216fc6cd28aa67a75da3447a78fdf Mon Sep 17 00:00:00 2001 From: Jiwoo Lee Date: Tue, 27 Feb 2024 22:27:44 -0800 Subject: [PATCH 4/7] fix as_boolen option for the pcmdi method --- pcmdi_metrics/utils/create_land_sea_mask.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/pcmdi_metrics/utils/create_land_sea_mask.py b/pcmdi_metrics/utils/create_land_sea_mask.py index e8f6e4619..8dd743a20 100644 --- a/pcmdi_metrics/utils/create_land_sea_mask.py +++ b/pcmdi_metrics/utils/create_land_sea_mask.py @@ -77,21 +77,20 @@ def create_land_sea_mask( land_sea_mask = land_mask.mask(lon, lat=lat) if not as_boolean: - # Convert the land-sea mask to a boolean mask + # Convert the boolean land-sea mask to a 0/1 mask land_sea_mask = xr.where(land_sea_mask, 0, 1) elif method.lower() == "pcmdi": # Use the PCMDI method developed by Taylor and Doutriaux (2000) land_sea_mask = generate_land_sea_mask__pcmdi(obj) - + if as_boolean: - land_sea_mask = xr.where(land_sea_mask==1, True, False) - + # Convert the 0/1 land-sea mask to a boolean mask + land_sea_mask = xr.where(land_sea_mask == 1, True, False) + else: raise ValueError("Unknown method '%s'. Please choose 'regionmask' or 'pcmdi'") - - return land_sea_mask @@ -369,7 +368,7 @@ def generate_land_sea_mask__pcmdi( mask = mask.rename(maskname) # Reverse the values (0 to 1 and 1 to 0) - #mask = xr.where(mask == 0, 1, 0) + # mask = xr.where(mask == 0, 1, 0) return mask From fa89155272e6744278155a729f4b28c685ecd798 Mon Sep 17 00:00:00 2001 From: Jiwoo Lee Date: Tue, 27 Feb 2024 22:34:32 -0800 Subject: [PATCH 5/7] bug fix --- pcmdi_metrics/utils/create_land_sea_mask.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pcmdi_metrics/utils/create_land_sea_mask.py b/pcmdi_metrics/utils/create_land_sea_mask.py index 8dd743a20..f6b2021d0 100644 --- a/pcmdi_metrics/utils/create_land_sea_mask.py +++ b/pcmdi_metrics/utils/create_land_sea_mask.py @@ -86,7 +86,7 @@ def create_land_sea_mask( if as_boolean: # Convert the 0/1 land-sea mask to a boolean mask - land_sea_mask = xr.where(land_sea_mask == 1, True, False) + land_sea_mask = land_sea_mask.astype(bool) else: raise ValueError("Unknown method '%s'. Please choose 'regionmask' or 'pcmdi'") From 3ca0475c645b1d78ca839a8be004ad00d61cade5 Mon Sep 17 00:00:00 2001 From: Jiwoo Lee Date: Tue, 27 Feb 2024 22:50:52 -0800 Subject: [PATCH 6/7] bug fix --- pcmdi_metrics/utils/create_land_sea_mask.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pcmdi_metrics/utils/create_land_sea_mask.py b/pcmdi_metrics/utils/create_land_sea_mask.py index f6b2021d0..4150a3ebb 100644 --- a/pcmdi_metrics/utils/create_land_sea_mask.py +++ b/pcmdi_metrics/utils/create_land_sea_mask.py @@ -76,7 +76,10 @@ def create_land_sea_mask( # Mask the land-sea mask to match the dataset's coordinates land_sea_mask = land_mask.mask(lon, lat=lat) - if not as_boolean: + if as_boolean: + # Convert the 0/nan land-sea mask to a boolean mask + land_sea_mask = xr.where(land_sea_mask, False, True) + else: # Convert the boolean land-sea mask to a 0/1 mask land_sea_mask = xr.where(land_sea_mask, 0, 1) From ab2c5929fd08a00685ee7a66bf1169928d72479b Mon Sep 17 00:00:00 2001 From: Jiwoo Lee Date: Tue, 27 Feb 2024 23:06:42 -0800 Subject: [PATCH 7/7] clean up comments --- pcmdi_metrics/utils/create_land_sea_mask.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pcmdi_metrics/utils/create_land_sea_mask.py b/pcmdi_metrics/utils/create_land_sea_mask.py index 4150a3ebb..fc3d76d62 100644 --- a/pcmdi_metrics/utils/create_land_sea_mask.py +++ b/pcmdi_metrics/utils/create_land_sea_mask.py @@ -77,10 +77,10 @@ def create_land_sea_mask( land_sea_mask = land_mask.mask(lon, lat=lat) if as_boolean: - # Convert the 0/nan land-sea mask to a boolean mask + # Convert the 0 (land) & nan (ocean) land-sea mask to a boolean mask land_sea_mask = xr.where(land_sea_mask, False, True) else: - # Convert the boolean land-sea mask to a 0/1 mask + # Convert the boolean land-sea mask to a 1/0 mask land_sea_mask = xr.where(land_sea_mask, 0, 1) elif method.lower() == "pcmdi": @@ -88,7 +88,7 @@ def create_land_sea_mask( land_sea_mask = generate_land_sea_mask__pcmdi(obj) if as_boolean: - # Convert the 0/1 land-sea mask to a boolean mask + # Convert the 1/0 land-sea mask to a boolean mask land_sea_mask = land_sea_mask.astype(bool) else: