Skip to content

Commit

Permalink
Merge pull request #1006 from PCMDI/feature/1005_lee1043_land-sea-mask
Browse files Browse the repository at this point in the history
Land sea mask generation
  • Loading branch information
lee1043 committed Dec 18, 2023
2 parents b8f818d + 2a7d98f commit c6d5398
Show file tree
Hide file tree
Showing 6 changed files with 325 additions and 119 deletions.
165 changes: 56 additions & 109 deletions pcmdi_metrics/mean_climate/mean_climate_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,6 @@
from collections import OrderedDict
from re import split

import cdms2
import cdutil
import numpy as np
import xcdat as xc

from pcmdi_metrics import resources
from pcmdi_metrics.io import load_regions_specs, region_subset
from pcmdi_metrics.mean_climate.lib import (
Expand All @@ -19,8 +14,11 @@
load_and_regrid,
mean_climate_metrics_to_json,
)
from pcmdi_metrics.utils import apply_landmask, create_land_sea_mask, create_target_grid
from pcmdi_metrics.variability_mode.lib import sort_human, tree

print("--- prepare mean climate metrics calculation ---")

parser = create_mean_climate_parser()
parameter = parser.get_parameter(argparse_vals_only=False)

Expand Down Expand Up @@ -78,103 +76,45 @@
regions_specs = load_regions_specs()

default_regions = ["global", "NHEX", "SHEX", "TROPICS"]
print(
"case_id: ",
case_id,
"\n",
"test_data_set:",
test_data_set,
"\n",
"realization:",
realization,
"\n",
"vars:",
vars,
"\n",
"varname_in_test_data:",
varname_in_test_data,
"\n",
"reference_data_set:",
reference_data_set,
"\n",
"target_grid:",
target_grid,
"\n",
"regrid_tool:",
regrid_tool,
"\n",
"regrid_tool_ocn:",
regrid_tool_ocn,
"\n",
"save_test_clims:",
save_test_clims,
"\n",
"test_clims_interpolated_output:",
test_clims_interpolated_output,
"\n",
"filename_template:",
filename_template,
"\n",
"sftlf_filename_template:",
sftlf_filename_template,
"\n",
"generate_sftlf:",
generate_sftlf,
"\n",
"regions_specs:",
regions_specs,
"\n",
"regions:",
regions,
"\n",
"test_data_path:",
test_data_path,
"\n",
"reference_data_path:",
reference_data_path,
"\n",
"metrics_output_path:",
metrics_output_path,
"\n",
"diagnostics_output_path:",
diagnostics_output_path,
"\n",
"debug:",
debug,
"\n",

config_variables = OrderedDict(
[
("case_id", case_id),
("test_data_set", test_data_set),
("realization", realization),
("vars", vars),
("varname_in_test_data", varname_in_test_data),
("reference_data_set", reference_data_set),
("target_grid", target_grid),
("regrid_tool", regrid_tool),
("regrid_tool_ocn", regrid_tool_ocn),
("save_test_clims", save_test_clims),
("test_clims_interpolated_output", test_clims_interpolated_output),
("filename_template", filename_template),
("sftlf_filename_template", sftlf_filename_template),
("generate_sftlf", generate_sftlf),
("regions_specs", regions_specs),
("regions", regions),
("test_data_path", test_data_path),
("reference_data_path", reference_data_path),
("metrics_output_path", metrics_output_path),
("diagnostics_output_path", diagnostics_output_path),
("debug", debug),
]
)

print("--- prepare mean climate metrics calculation ---")
for key, value in config_variables.items():
print(f"{key}: {value}")

# generate target grid
res = target_grid.split("x")
lat_res = float(res[0])
lon_res = float(res[1])
start_lat = -90.0 + lat_res / 2
start_lon = 0.0
end_lat = 90.0 - lat_res / 2
end_lon = 360.0 - lon_res
nlat = ((end_lat - start_lat) * 1.0 / lat_res) + 1
nlon = ((end_lon - start_lon) * 1.0 / lon_res) + 1
t_grid = xc.create_uniform_grid(
start_lat, end_lat, lat_res, start_lon, end_lon, lon_res
)
if debug:
print(
"type(t_grid):", type(t_grid)
) # Expected type is 'xarray.core.dataset.Dataset'
print("t_grid:", t_grid)
# identical target grid in cdms2 to use generateLandSeaMask function that is yet to exist in xcdat
t_grid_cdms2 = cdms2.createUniformGrid(
start_lat, nlat, lat_res, start_lon, nlon, lon_res
)
t_grid = create_target_grid(target_grid_resolution=target_grid)

# generate land sea mask for the target grid
sft = cdutil.generateLandSeaMask(t_grid_cdms2)
if debug:
print("sft:", sft)
print("sft.getAxisList():", sft.getAxisList())
sft = create_land_sea_mask(t_grid)

# add sft to target grid dataset
t_grid["sftlf"] = (["lat", "lon"], np.array(sft))
t_grid["sftlf"] = sft

if debug:
print("t_grid (after sftlf added):", t_grid)
t_grid.to_netcdf("target_grid.nc")
Expand All @@ -188,8 +128,6 @@
obs_file_path = os.path.join(egg_pth, obs_file_name)
with open(obs_file_path) as fo:
obs_dict = json.loads(fo.read())
# if debug:
# print('obs_dict:', json.dumps(obs_dict, indent=4, sort_keys=True))

print("--- start mean climate metrics calculation ---")

Expand Down Expand Up @@ -353,26 +291,35 @@
print("region:", region)

# land/sea mask -- conduct masking only for variable data array, not entire data
if ("land" in region.split("_")) or (
"ocean" in region.split("_")
if any(
keyword in region.split("_")
for keyword in ["land", "ocean"]
):
ds_test_tmp = ds_test.copy(deep=True)
ds_ref_tmp = ds_ref.copy(deep=True)
if "land" in region.split("_"):
ds_test_tmp[varname] = ds_test[varname].where(
t_grid["sftlf"] != 0.0
ds_test_tmp[varname] = apply_landmask(
ds_test[varname],
landfrac=t_grid["sftlf"],
keep_over="land",
)
ds_ref_tmp[varname] = ds_ref[varname].where(
t_grid["sftlf"] != 0.0
ds_ref_tmp[varname] = apply_landmask(
ds_ref[varname],
landfrac=t_grid["sftlf"],
keep_over="land",
)
elif "ocean" in region.split("_"):
ds_test_tmp[varname] = ds_test[varname].where(
t_grid["sftlf"] == 0.0
ds_test_tmp[varname] = apply_landmask(
ds_test[varname],
landfrac=t_grid["sftlf"],
keep_over="ocean",
)
ds_ref_tmp[varname] = ds_ref[varname].where(
t_grid["sftlf"] == 0.0
ds_ref_tmp[varname] = apply_landmask(
ds_ref[varname],
landfrac=t_grid["sftlf"],
keep_over="ocean",
)
print("mask done")
print("mask done")
else:
ds_test_tmp = ds_test
ds_ref_tmp = ds_ref
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import os
import sys

import cdutil
import numpy as np
import pandas as pd
import xarray as xr
Expand All @@ -12,7 +11,8 @@
from scipy.stats import chi2
from xcdat.regridder import grid

import pcmdi_metrics
from pcmdi_metrics.io.base import Base
from pcmdi_metrics.utils import create_land_sea_mask


# ==================================================================================
Expand Down Expand Up @@ -94,9 +94,7 @@ def precip_variability_across_timescale(
outfilename = (
"PS_pr." + str(dfrq) + "_regrid.180x90_area.freq.mean_" + dat + ".json"
)
JSON = pcmdi_metrics.io.base.Base(
outdir.replace("%(output_type)", "metrics_results"), outfilename
)
JSON = Base(outdir.replace("%(output_type)", "metrics_results"), outfilename)
JSON.write(
psdmfm,
json_structure=["model+realization", "variability type", "domain", "frequency"],
Expand Down Expand Up @@ -389,9 +387,8 @@ def Avg_PS_DomFrq(d, frequency, ntd, dat, mip, frc):
else:
sys.exit("ERROR: frc " + frc + " is not defined!")

d_cdms = xr.DataArray.to_cdms2(d[0])
mask = cdutil.generateLandSeaMask(d_cdms)
mask = xr.DataArray.from_cdms2(mask)
# generate land sea mask
mask = create_land_sea_mask(d[0])

psdmfm = {}
for dom in domains:
Expand All @@ -405,8 +402,8 @@ def Avg_PS_DomFrq(d, frequency, ntd, dat, mip, frc):
dmask = d

dmask = dmask.to_dataset(name="ps")
dmask = dmask.bounds.add_bounds(axis="X", width=0.5)
dmask = dmask.bounds.add_bounds(axis="Y", width=0.5)
dmask = dmask.bounds.add_bounds(axis="X")
dmask = dmask.bounds.add_bounds(axis="Y")

if "50S50N" in dom:
am = dmask.sel(lat=slice(-50, 50)).spatial.average(
Expand Down
3 changes: 3 additions & 0 deletions pcmdi_metrics/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .create_land_sea_mask import apply_landmask, create_land_sea_mask
from .create_target_grid import create_target_grid
from .sort_human import sort_human
Loading

0 comments on commit c6d5398

Please sign in to comment.