diff --git a/docs/source/conf.py b/docs/source/conf.py index 9cb6b53c3..f53862d13 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -57,6 +57,8 @@ "sphinx.ext.napoleon", "sphinx_rtd_theme", 'sphinx_click.ext', + "sphinx_tabs.tabs", + "sphinx_copybutton", ] intersphinx_mapping = { @@ -81,7 +83,7 @@ # # This is also used if you do content translation via gettext catalogs. # Usually you set "language" from the command line for these cases. -language = None +language = 'en' # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. @@ -140,7 +142,7 @@ # -- Options for HTMLHelp output --------------------------------------------- # Output file base name for HTML help builder. -htmlhelp_basename = 'rexdoc' +htmlhelp_basename = 'sup3rdoc' # -- Options for LaTeX output ------------------------------------------------ diff --git a/sup3r/__init__.py b/sup3r/__init__.py index f2ba4e292..8e37520df 100644 --- a/sup3r/__init__.py +++ b/sup3r/__init__.py @@ -2,7 +2,9 @@ """Super Resolving Renewable Energy Resource Data (SUP3R)""" import os from sup3r.version import __version__ -import sup3r.cli # import sets up CLI commands +# Next import sets up CLI commands +# This line could be "import sup3r.cli" but that breaks sphinx as of 12/11/2023 +from sup3r.cli import main __author__ = """Brandon Benton""" __email__ = "brandon.benton@nrel.gov" diff --git a/sup3r/bias/__init__.py b/sup3r/bias/__init__.py index e1f477055..944c9f864 100644 --- a/sup3r/bias/__init__.py +++ b/sup3r/bias/__init__.py @@ -2,4 +2,5 @@ """Bias calculation and correction modules.""" from .bias_transforms import (global_linear_bc, local_linear_bc, monthly_local_linear_bc) -from .bias_calc import LinearCorrection, MonthlyLinearCorrection +from .bias_calc import (LinearCorrection, MonthlyLinearCorrection, + SkillAssessment) diff --git a/sup3r/bias/bias_calc.py b/sup3r/bias/bias_calc.py index 2ac31677f..c20f25b23 100644 --- a/sup3r/bias/bias_calc.py +++ b/sup3r/bias/bias_calc.py @@ -19,6 +19,7 @@ from scipy.spatial import KDTree import sup3r.preprocessing.data_handling +from sup3r.preprocessing.data_handling.base import DataHandler from sup3r.utilities import VERSION_RECORD, ModuleName from sup3r.utilities.cli import BaseCLI from sup3r.utilities.utilities import nn_fill_array @@ -36,10 +37,12 @@ def __init__(self, bias_fps, base_dset, bias_feature, + distance_upper_bound=None, target=None, shape=None, base_handler='Resource', bias_handler='DataHandlerNCforCC', + base_handler_kwargs=None, bias_handler_kwargs=None, decimals=None): """ @@ -60,6 +63,11 @@ def __init__(self, bias_feature : str This is the biased feature from bias_fps to retrieve. This should be a single feature name corresponding to base_dset + distance_upper_bound : float + Upper bound on the nearest neighbor distance in decimal degrees. + This should be the approximate resolution of the low-resolution + bias data. None (default) will calculate this based on the median + distance between points in bias_fps target : tuple (lat, lon) lower left corner of raster to retrieve from bias_fps. If None then the lower left corner of the full domain will be used. @@ -67,11 +75,17 @@ def __init__(self, (rows, cols) grid size to retrieve from bias_fps. If None then the full domain shape will be used. base_handler : str - Name of rex resource handler class to be retrieved from the rex - library. + Name of rex resource handler or sup3r.preprocessing.data_handling + class to be retrieved from the rex/sup3r library. 
If a + sup3r.preprocessing.data_handling class is used, all data will be + loaded in this class' initialization and the subsequent bias + calculation will be done in serial bias_handler : str Name of the bias data handler class to be retrieved from the sup3r.preprocessing.data_handling library. + base_handler_kwargs : dict | None + Optional kwargs to send to the initialization of the base_handler + class bias_handler_kwargs : dict | None Optional kwargs to send to the initialization of the bias_handler class @@ -92,36 +106,59 @@ def __init__(self, self.target = target self.shape = shape self.decimals = decimals - bias_handler_kwargs = bias_handler_kwargs or {} + self.base_handler_kwargs = base_handler_kwargs or {} + self.bias_handler_kwargs = bias_handler_kwargs or {} + self.bad_bias_gids = [] + self._distance_upper_bound = distance_upper_bound if isinstance(self.base_fps, str): self.base_fps = sorted(glob(self.base_fps)) if isinstance(self.bias_fps, str): self.bias_fps = sorted(glob(self.bias_fps)) - self.base_handler = getattr(rex, base_handler) + base_sup3r_handler = getattr(sup3r.preprocessing.data_handling, + base_handler, None) + base_rex_handler = getattr(rex, base_handler, None) + + if base_rex_handler is not None: + self.base_handler = base_rex_handler + self.base_dh = self.base_handler(self.base_fps[0], + **self.base_handler_kwargs) + elif base_sup3r_handler is not None: + self.base_handler = base_sup3r_handler + self.base_handler_kwargs['features'] = [self.base_dset] + self.base_dh = self.base_handler(self.base_fps, + **self.base_handler_kwargs) + msg = ('Base data handler opened with a sup3r DataHandler class ' + 'must load cached data!') + assert self.base_dh.data is not None, msg + else: + msg = f'Could not retrieve "{base_handler}" from sup3r or rex!' 
+ logger.error(msg) + raise RuntimeError(msg) + self.bias_handler = getattr(sup3r.preprocessing.data_handling, bias_handler) - - with self.base_handler(self.base_fps[0]) as res: - self.base_meta = res.meta - self.base_tree = KDTree(self.base_meta[['latitude', 'longitude']]) - + self.base_meta = self.base_dh.meta self.bias_dh = self.bias_handler(self.bias_fps, [self.bias_feature], target=self.target, shape=self.shape, val_split=0.0, - **bias_handler_kwargs) + **self.bias_handler_kwargs) lats = self.bias_dh.lat_lon[..., 0].flatten() - lons = self.bias_dh.lat_lon[..., 1].flatten() - self.bias_meta = pd.DataFrame({'latitude': lats, 'longitude': lons}) + self.bias_meta = self.bias_dh.meta self.bias_ti = self.bias_dh.time_index raster_shape = self.bias_dh.lat_lon[..., 0].shape - self.bias_tree = KDTree(self.bias_meta[['latitude', 'longitude']]) + bias_lat_lon = self.bias_meta[['latitude', 'longitude']].values + self.bias_tree = KDTree(bias_lat_lon) self.bias_gid_raster = np.arange(lats.size) self.bias_gid_raster = self.bias_gid_raster.reshape(raster_shape) + self.nn_dist, self.nn_ind = self.bias_tree.query( + self.base_meta[['latitude', 'longitude']], k=1, + distance_upper_bound=self.distance_upper_bound) + self.out = None self._init_out() logger.info('Finished initializing DataRetrievalBase.') @@ -144,6 +181,19 @@ def meta(self): 'version_record': VERSION_RECORD} return meta + @property + def distance_upper_bound(self): + """Maximum distance (float) to map high-resolution data from exo_source + to the low-resolution file_paths input.""" + if self._distance_upper_bound is None: + diff = np.diff(self.bias_meta[['latitude', 'longitude']].values, + axis=0) + diff = np.max(np.median(diff, axis=0)) + self._distance_upper_bound = diff + logger.info('Set distance upper bound to {:.4f}' + .format(self._distance_upper_bound)) + return self._distance_upper_bound + @staticmethod def compare_dists(base_data, bias_data, adder=0, scalar=1): """Compare two distributions using the two-sample Kolmogorov-Smirnov. @@ -238,7 +288,7 @@ def get_bias_gid(self, coord): bias_gid = self.bias_gid_raster.flatten()[i] return bias_gid, d - def get_base_gid(self, bias_gid, knn): + def get_base_gid(self, bias_gid): """Get one or more base gid(s) corresponding to a bias gid. Parameters @@ -247,36 +297,34 @@ def get_base_gid(self, bias_gid, knn): gid of the data to retrieve in the bias data source raster data. The gids for this data source are the enumerated indices of the flattened coordinate array. - knn : int - Number of nearest neighbors to aggregate from the base data when - comparing to a single site from the bias data. Returns ------- dist : np.ndarray - Array of nearest neighbor distances with length == knn + Array of nearest neighbor distances with length equal to the number + of high-resolution baseline gids that map to the low resolution + bias gid pixel. base_gid : np.ndarray Array of base gids that are the nearest neighbors of bias_gid with - length == knn + length equal to the number of high-resolution baseline gids that + map to the low resolution bias gid pixel. """ - coord = self.bias_meta.loc[bias_gid, ['latitude', 'longitude']] - dist, base_gid = self.base_tree.query(coord, k=knn) + base_gid = np.where(self.nn_ind == bias_gid)[0] + dist = self.nn_dist[base_gid] return dist, base_gid - def get_data_pair(self, coord, knn, daily_reduction='avg'): + def get_data_pair(self, coord, daily_reduction='avg'): """Get base and bias data observations based on a single bias gid. 
Parameters ---------- coord : tuple (lat, lon) to get data for. - knn : int - Number of nearest neighbors to aggregate from the base data when - comparing to a single site from the bias data. daily_reduction : None | str Option to do a reduction of the hourly+ source base data to daily data. Can be None (no reduction, keep source time frequency), "avg" - (daily average), "max" (daily max), or "min" (daily min) + (daily average), "max" (daily max), "min" (daily min), + "sum" (daily sum/total) Returns ------- @@ -287,12 +335,13 @@ def get_data_pair(self, coord, knn, daily_reduction='avg'): 1D array of temporal data at the requested gid. base_dist : np.ndarray Array of nearest neighbor distances from coord to the base data - sites with length == knn + sites with length equal to the number of high-resolution baseline + gids that map to the low resolution bias gid pixel. bias_dist : Float Nearest neighbor distance from coord to the bias data site """ bias_gid, bias_dist = self.get_bias_gid(coord) - base_dist, base_gid = self.get_base_gid(bias_gid, knn) + base_dist, base_gid = self.get_base_gid(bias_gid) bias_data = self.get_bias_data(bias_gid) base_data = self.get_base_data(self.base_fps, self.base_dset, @@ -343,8 +392,10 @@ def get_base_data(cls, base_dset, base_gid, base_handler, + base_handler_kwargs=None, daily_reduction='avg', - decimals=None): + decimals=None, + base_dh_inst=None): """Get data from the baseline data source, possibly for many high-res base gids corresponding to a single coarse low-res bias gid. @@ -360,20 +411,29 @@ def get_base_data(cls, One or more spatial gids to retrieve from base_fps. The data will be spatially averaged across all of these sites. base_handler : rex.Resource - A rex data handler similar to rex.Resource + A rex data handler similar to rex.Resource or sup3r.DataHandler + classes (if using the latter, must also input base_dh_inst) + base_handler_kwargs : dict | None + Optional kwargs to send to the initialization of the base_handler + class daily_reduction : None | str Option to do a reduction of the hourly+ source base data to daily data. Can be None (no reduction, keep source time frequency), "avg" - (daily average), "max" (daily max), or "min" (daily min) + (daily average), "max" (daily max), "min" (daily min), + "sum" (daily sum/total) decimals : int | None Option to round bias and base data to this number of decimals, this gets passed to np.around(). If decimals is negative, it specifies the number of positions to the left of the decimal point. + base_dh_inst : sup3r.DataHandler + Instantiated DataHandler class that has already loaded the base + data (required if base files are .nc and are not being opened by a + rex Resource handler). Returns ------- - out : np.ndarray + out_data : np.ndarray 1D array of base data spatially averaged across the base_gid input and possibly daily-averaged or min/max'd as well. out_ti : pd.DatetimeIndex @@ -381,37 +441,82 @@ def get_base_data(cls, output data. 
""" - out = [] + out_data = [] out_ti = [] - for fp in base_fps: - with base_handler(fp) as res: - base_ti = res.time_index - - base_data, base_cs_ghi = cls._read_base_data( - res, base_dset, base_gid) - if daily_reduction is not None: - base_data = cls._reduce_base_data( - base_ti, - base_data, - base_cs_ghi, - base_dset, - daily_reduction, - ) - base_ti = np.array(sorted(set(base_ti.date))) + all_cs_ghi = [] + base_handler_kwargs = base_handler_kwargs or {} - out.append(base_data) - out_ti.append(base_ti) + if issubclass(base_handler, DataHandler) and base_dh_inst is None: + msg = ('The method `get_base_data()` is only to be used with ' + '`base_handler` as a `sup3r.DataHandler` subclass if ' + '`base_dh_inst` is also provided!') + logger.error(msg) + raise RuntimeError(msg) - out = np.hstack(out) + if issubclass(base_handler, DataHandler) and base_dh_inst is not None: + out_ti = base_dh_inst.time_index + out_data = cls._read_base_sup3r_data(base_dh_inst, base_dset, + base_gid) + all_cs_ghi = np.ones(len(out_data), dtype=np.float32) * np.nan + else: + for fp in base_fps: + with base_handler(fp, **base_handler_kwargs) as res: + base_ti = res.time_index + temp_out = cls._read_base_rex_data(res, base_dset, + base_gid) + base_data, base_cs_ghi = temp_out + + out_data.append(base_data) + out_ti.append(base_ti) + all_cs_ghi.append(base_cs_ghi) + + out_data = np.hstack(out_data) + out_ti = pd.DatetimeIndex(np.hstack(out_ti)) + all_cs_ghi = np.hstack(all_cs_ghi) + + if daily_reduction is not None: + out_data, out_ti = cls._reduce_base_data(out_ti, + out_data, + all_cs_ghi, + base_dset, + daily_reduction) if decimals is not None: - out = np.around(out, decimals=decimals) + out_data = np.around(out_data, decimals=decimals) + + return out_data, out_ti + + @staticmethod + def _read_base_sup3r_data(dh, base_dset, base_gid): + """Read baseline data from a sup3r DataHandler + + Parameters + ---------- + dh : sup3r.DataHandler + sup3r DataHandler that is an open file handler of the base file(s) + base_dset : str + A single dataset from the base_fps to retrieve. + base_gid : int | np.ndarray + One or more spatial gids to retrieve from base_fps. The data will + be spatially averaged across all of these sites. - return out, pd.DatetimeIndex(np.hstack(out_ti)) + Returns + ------- + base_data : np.ndarray + 1D array of base data spatially averaged across the base_gid input + """ + idf = dh.features.index(base_dset) + gid_raster = np.arange(len(dh.meta)) + gid_raster = gid_raster.reshape(dh.shape[:2]) + idy, idx = np.where(np.isin(gid_raster, base_gid)) + base_data = dh.data[idy, idx, :, idf] + assert base_data.shape[0] == len(base_gid) + assert base_data.shape[1] == len(dh.time_index) + return base_data.mean(axis=0) @staticmethod - def _read_base_data(res, base_dset, base_gid): - """Read baseline data from the resource handler with extra logic for + def _read_base_rex_data(res, base_dset, base_gid): + """Read baseline data from a rex resource handler with extra logic for special datasets (e.g. u/v wind components or clearsky_ratio) Parameters @@ -429,11 +534,15 @@ def _read_base_data(res, base_dset, base_gid): ------- base_data : np.ndarray 1D array of base data spatially averaged across the base_gid input - base_cs_ghi : np.ndarray | None + base_cs_ghi : np.ndarray If base_dset == "clearsky_ratio", the base_data array is GHI and - this base_cs_ghi is clearsky GHI. Otherwise this is None + this base_cs_ghi is clearsky GHI. 
Otherwise this is an array with + same length as base_data but full of np.nan """ + msg = '`res` input must not be a `DataHandler` subclass!' + assert not issubclass(res.__class__, DataHandler), msg + base_cs_ghi = None if base_dset.startswith(('U_', 'V_')): @@ -460,6 +569,9 @@ def _read_base_data(res, base_dset, base_gid): if base_cs_ghi is not None: base_cs_ghi = np.nanmean(base_cs_ghi, axis=1) + if base_cs_ghi is None: + base_cs_ghi = np.ones(len(base_data), dtype=np.float32) * np.nan + return base_data, base_cs_ghi @staticmethod @@ -474,45 +586,66 @@ def _reduce_base_data(base_ti, base_data, base_cs_ghi, base_dset, Time index associated with base_data base_data : np.ndarray 1D array of base data spatially averaged across the base_gid input - base_cs_ghi : np.ndarray | None + base_cs_ghi : np.ndarray If base_dset == "clearsky_ratio", the base_data array is GHI and - this base_cs_ghi is clearsky GHI. Otherwise this is None + this base_cs_ghi is clearsky GHI. Otherwise this is an array with + same length as base_data but full of np.nan base_dset : str A single dataset from the base_fps to retrieve. daily_reduction : str Option to do a reduction of the hourly+ source base data to daily data. Can be None (no reduction, keep source time frequency), "avg" - (daily average), "max" (daily max), or "min" (daily min) + (daily average), "max" (daily max), "min" (daily min), + "sum" (daily sum/total) Returns ------- base_data : np.ndarray 1D array of base data spatially averaged across the base_gid input and possibly daily-averaged or min/max'd as well. + daily_ti : pd.DatetimeIndex + Daily DatetimeIndex corresponding to the daily base_data """ if daily_reduction is None: return base_data - slices = [ - np.where(base_ti.date == date) - for date in sorted(set(base_ti.date)) - ] + daily_ti = pd.DatetimeIndex(sorted(set(base_ti.date))) + df = pd.DataFrame({'date': base_ti.date, + 'base_data': base_data, + 'base_cs_ghi': base_cs_ghi}) + + cs_ratio = (daily_reduction.lower() in ('avg', 'average', 'mean') + and base_dset == 'clearsky_ratio') + + if cs_ratio: + daily_ghi = df.groupby('date').sum()['base_data'].values + daily_cs_ghi = df.groupby('date').sum()['base_cs_ghi'].values + base_data = daily_ghi / daily_cs_ghi + msg = ('Could not calculate daily average "clearsky_ratio" with ' + 'base_data and base_cs_ghi inputs: \n{}, \n{}' + .format(base_data, base_cs_ghi)) + assert not np.isnan(base_data).any(), msg - if base_dset == 'clearsky_ratio' and daily_reduction.lower() == 'avg': - base_data = np.array( - [base_data[s0].sum() / base_cs_ghi[s0].sum() for s0 in slices]) + elif daily_reduction.lower() in ('avg', 'average', 'mean'): + base_data = df.groupby('date').mean()['base_data'].values - elif daily_reduction.lower() == 'avg': - base_data = np.array([base_data[s0].mean() for s0 in slices]) + elif daily_reduction.lower() in ('max', 'maximum'): + base_data = df.groupby('date').max()['base_data'].values - elif daily_reduction.lower() == 'max': - base_data = np.array([base_data[s0].max() for s0 in slices]) + elif daily_reduction.lower() in ('min', 'minimum'): + base_data = df.groupby('date').min()['base_data'].values - elif daily_reduction.lower() == 'min': - base_data = np.array([base_data[s0].min() for s0 in slices]) + elif daily_reduction.lower() in ('sum', 'total'): + base_data = df.groupby('date').sum()['base_data'].values - return base_data + msg = (f'Daily reduced base data shape {base_data.shape} does not ' + f'match daily time index shape {daily_ti.shape}, ' + 'something went wrong!') + assert 
len(base_data.shape) == 1, msg + assert base_data.shape == daily_ti.shape, msg + + return base_data, daily_ti class LinearCorrection(DataRetrievalBase): @@ -593,7 +726,8 @@ def _run_single(cls, base_handler, daily_reduction, bias_ti, - decimals): + decimals, + base_dh_inst=None): """Find the nominal scalar + adder combination to bias correct data at a single site""" @@ -602,7 +736,8 @@ def _run_single(cls, base_gid, base_handler, daily_reduction=daily_reduction, - decimals=decimals) + decimals=decimals, + base_dh_inst=base_dh_inst) out = cls.get_linear_correction(bias_data, base_data, bias_feature, base_dset) @@ -651,6 +786,10 @@ def fill_and_smooth(self, like: bias_data * scalar + adder. Each value is of shape (lat, lon, time). """ + if len(self.bad_bias_gids) > 0: + logger.info('Found {} bias gids that are out of bounds: {}' + .format(len(self.bad_bias_gids), self.bad_bias_gids)) + for key, arr in out.items(): nan_mask = np.isnan(arr[..., 0]) for idt in range(arr.shape[-1]): @@ -661,6 +800,9 @@ def fill_and_smooth(self, and fill_extend) or smooth_interior > 0 if needs_fill: + logger.info('Filling NaN values outside of valid spatial ' + 'extent for dataset "{}" for timestep {}' + .format(key, idt)) arr_smooth = nn_fill_array(arr_smooth) arr_smooth_int = arr_smooth_ext = arr_smooth @@ -714,8 +856,6 @@ def write_outputs(self, fp_out, out): 'Wrote scalar adder factors to file: {}'.format(fp_out)) def run(self, - knn, - threshold=0.6, fp_out=None, max_workers=None, daily_reduction='avg', @@ -727,14 +867,6 @@ def run(self, Parameters ---------- - knn : int - Number of nearest neighbors to aggregate from the base data when - comparing to a single site from the bias data. - threshold : float - If the bias data coordinate is on average further from the base - data coordinates than this threshold, no bias correction factors - will be calculated directly and will just be filled from nearest - neighbor (if fill_extend=True, else it will be nan). fp_out : str | None Optional .h5 output file to write scalar and adder arrays. max_workers : int @@ -743,15 +875,16 @@ def run(self, daily_reduction : None | str Option to do a reduction of the hourly+ source base data to daily data. Can be None (no reduction, keep source time frequency), "avg" - (daily average), "max" (daily max), or "min" (daily min) + (daily average), "max" (daily max), "min" (daily min), + "sum" (daily sum/total) fill_extend : bool - Flag to fill data past threshold using spatial nearest neighbor. If - False, the extended domain will be left as NaN. + Flag to fill data past distance_upper_bound using spatial nearest + neighbor. If False, the extended domain will be left as NaN. smooth_extend : float Option to smooth the scalar/adder data outside of the spatial - domain set by the threshold input. This alleviates the weird seams - far from the domain of interest. This value is the standard - deviation for the gaussian_filter kernel + domain set by the distance_upper_bound input. This alleviates the + weird seams far from the domain of interest. This value is the + standard deviation for the gaussian_filter kernel smooth_interior : float Option to smooth the scalar/adder data within the valid spatial domain. 
This can reduce the affect of extreme values within @@ -770,14 +903,22 @@ def run(self, logger.info('Initialized scalar / adder with shape: {}' .format(self.bias_gid_raster.shape)) + self.bad_bias_gids = [] + + # sup3r DataHandler opening base files will load all data in parallel + # during the init and should not be passed in parallel to workers + if isinstance(self.base_dh, DataHandler): + max_workers = 1 + if max_workers == 1: logger.debug('Running serial calculation.') - for i, (bias_gid, row) in enumerate(self.bias_meta.iterrows()): + for i, bias_gid in enumerate(self.bias_meta.index): raster_loc = np.where(self.bias_gid_raster == bias_gid) - coord = row[['latitude', 'longitude']] - dist, base_gid = self.base_tree.query(coord, k=knn) + _, base_gid = self.get_base_gid(bias_gid) - if np.mean(dist) < threshold: + if not base_gid.any(): + self.bad_bias_gids.append(bias_gid) + else: bias_data = self.get_bias_data(bias_gid) single_out = self._run_single( bias_data, @@ -789,6 +930,7 @@ def run(self, daily_reduction, self.bias_ti, self.decimals, + base_dh_inst=self.base_dh, ) for key, arr in single_out.items(): self.out[key][raster_loc] = arr @@ -802,14 +944,14 @@ def run(self, max_workers)) with ProcessPoolExecutor(max_workers=max_workers) as exe: futures = {} - for bias_gid, bias_row in self.bias_meta.iterrows(): + for bias_gid in self.bias_meta.index: raster_loc = np.where(self.bias_gid_raster == bias_gid) - coord = bias_row[['latitude', 'longitude']] - dist, base_gid = self.base_tree.query(coord, k=knn) + _, base_gid = self.get_base_gid(bias_gid) - if np.mean(dist) < threshold: + if not base_gid.any(): + self.bad_bias_gids.append(bias_gid) + else: bias_data = self.get_bias_data(bias_gid) - future = exe.submit( self._run_single, bias_data, @@ -863,7 +1005,8 @@ def _run_single(cls, base_handler, daily_reduction, bias_ti, - decimals): + decimals, + base_dh_inst=None): """Find the nominal scalar + adder combination to bias correct data at a single site""" @@ -872,7 +1015,8 @@ def _run_single(cls, base_gid, base_handler, daily_reduction=daily_reduction, - decimals=decimals) + decimals=decimals, + base_dh_inst=base_dh_inst) base_arr = np.full(cls.NT, np.nan, dtype=np.float32) out = {} @@ -1031,13 +1175,14 @@ def _run_skill_eval(cls, bias_data, base_data, bias_feature, base_dset): @classmethod def _run_single(cls, bias_data, base_fps, bias_feature, base_dset, base_gid, base_handler, daily_reduction, bias_ti, - decimals): + decimals, base_dh_inst=None): """Do a skill assessment at a single site""" base_data, base_ti = cls.get_base_data(base_fps, base_dset, base_gid, base_handler, daily_reduction=daily_reduction, - decimals=decimals) + decimals=decimals, + base_dh_inst=base_dh_inst) arr = np.full(cls.NT, np.nan, dtype=np.float32) out = {f'bias_{bias_feature}_mean_monthly': arr.copy(), diff --git a/sup3r/bias/bias_correct_means.py b/sup3r/bias/bias_correct_means.py index 7e63858cf..bd9428aff 100644 --- a/sup3r/bias/bias_correct_means.py +++ b/sup3r/bias/bias_correct_means.py @@ -13,7 +13,6 @@ import numpy as np import pandas as pd -import rioxarray import xarray as xr from rex import Resource from scipy.interpolate import interp1d @@ -116,6 +115,11 @@ def convert_month_height_tif(self, month, height): os.remove(outfile) if not os.path.exists(outfile) or self.overwrite: + try: + import rioxarray + except ImportError as e: + msg = 'Need special installation of "rioxarray" to run this!' 
+ raise ImportError(msg) from e tmp = rioxarray.open_rasterio(infile) ds = tmp.to_dataset("band") ds = ds.rename( diff --git a/sup3r/preprocessing/data_handling/exo_extraction.py b/sup3r/preprocessing/data_handling/exo_extraction.py index 4896b3384..b0700c017 100644 --- a/sup3r/preprocessing/data_handling/exo_extraction.py +++ b/sup3r/preprocessing/data_handling/exo_extraction.py @@ -5,7 +5,9 @@ import pickle import shutil from abc import ABC, abstractmethod +from warnings import warn +import pandas as pd import numpy as np from rex import Resource from rex.utilities.solar_position import SolarPosition @@ -15,7 +17,8 @@ from sup3r.postprocessing.file_handling import OutputHandler from sup3r.preprocessing.data_handling.h5_data_handling import DataHandlerH5 from sup3r.preprocessing.data_handling.nc_data_handling import DataHandlerNC -from sup3r.utilities.utilities import generate_random_string, get_source_type +from sup3r.utilities.utilities import (generate_random_string, get_source_type, + nn_fill_array) logger = logging.getLogger(__name__) @@ -32,7 +35,6 @@ def __init__(self, exo_source, s_enhance, t_enhance, - s_agg_factor, t_agg_factor, target=None, shape=None, @@ -43,6 +45,7 @@ def __init__(self, cache_data=True, cache_dir='./exo_cache/', ti_workers=1, + distance_upper_bound=None, res_kwargs=None): """Parameters ---------- @@ -53,9 +56,14 @@ def __init__(self, typically low-res WRF output or GCM netcdf data files that is source low-resolution data intended to be sup3r resolved. exo_source : str - Filepath to source wtk or nsrdb file to get hi-res (2km or 4km) - elevation data from which will be mapped to the enhanced grid of - the file_paths input + Filepath to source data file to get hi-res elevation data from + which will be mapped to the enhanced grid of the file_paths input. + Pixels from this exo_source will be mapped to their nearest low-res + pixel in the file_paths input. Accordingly, exo_source should be a + significantly higher resolution than file_paths. Warnings will be + raised if the low-resolution pixels in file_paths do not have + unique nearest pixels from exo_source. File format can be .h5 for + TopoExtractH5 or .nc for TopoExtractNC s_enhance : int Factor by which the Sup3rGan model will enhance the spatial dimensions of low resolution data from file_paths input. For @@ -68,20 +76,14 @@ def __init__(self, example, if getting sza data, file_paths has hourly data, and t_enhance is 4, this class will output a sza raster corresponding to the file_paths temporally enhanced 4x to 15 min - s_agg_factor : int - Factor by which to aggregate the exo_source data to the resolution - of the file_paths input enhanced by s_enhance. For example, if - getting topography data, file_paths have 100km data, and s_enhance - is 4 resulting in a desired resolution of ~25km and topo_source_h5 - has a resolution of 4km, the s_agg_factor should be 36 so that 6x6 - 4km cells are averaged to the ~25km enhanced grid. t_agg_factor : int - Factor by which to aggregate the exo_source data to the resolution - of the file_paths input enhanced by t_enhance. For example, if - getting sza data, file_paths have hourly data, and t_enhance - is 4 resulting in a desired resolution of 5 min and exo_source - has a resolution of 5 min, the t_agg_factor should be 4 so that - every fourth timestep in the exo_source data is skipped. + Factor by which to aggregate / subsample the exo_source data to the + resolution of the file_paths input enhanced by t_enhance. 
For + example, if getting sza data, file_paths have hourly data, and + t_enhance is 4 resulting in a target resolution of 15 min and + exo_source has a resolution of 5 min, the t_agg_factor should be 3 + so that only timesteps that are a multiple of 15min are selected + e.g., [0, 5, 10, 15, 20, 25, 30][slice(0, None, 3)] = [0, 15, 30] target : tuple (lat, lon) lower left corner of raster. Either need target+shape or raster_file. @@ -118,6 +120,10 @@ def __init__(self, parallel and then concatenated to get the full time index. If input files do not all have time indices or if there are few input files this should be set to one. + distance_upper_bound : float | None + Maximum distance to map high-resolution data from exo_source to the + low-resolution file_paths input. None (default) will calculate this + based on the median distance between points in exo_source res_kwargs : dict | None Dictionary of kwargs passed to lowest level resource handler. e.g. xr.open_dataset(file_paths, **res_kwargs) @@ -128,13 +134,13 @@ def __init__(self, self._exo_source = exo_source self._s_enhance = s_enhance self._t_enhance = t_enhance - self._s_agg_factor = s_agg_factor self._t_agg_factor = t_agg_factor self._tree = None self._hr_lat_lon = None self._source_lat_lon = None self._hr_time_index = None self._src_time_index = None + self._distance_upper_bound = distance_upper_bound self.cache_data = cache_data self.cache_dir = cache_dir self.temporal_slice = temporal_slice @@ -179,8 +185,7 @@ def __init__(self, def source_data(self): """Get the 1D array of source data from the exo_source_h5""" - def get_cache_file(self, feature, s_enhance, t_enhance, s_agg_factor, - t_agg_factor): + def get_cache_file(self, feature, s_enhance, t_enhance, t_agg_factor): """Get cache file name Parameters @@ -193,9 +198,6 @@ def get_cache_file(self, feature, s_enhance, t_enhance, s_agg_factor, t_enhance : int Temporal enhancement for this exogeneous data step (cumulative for all model steps up to the current step). - s_agg_factor : int - Factor by which to aggregate the exo_source data to the spatial - resolution of the file_paths input enhanced by s_enhance. t_agg_factor : int Factor by which to aggregate the exo_source data to the temporal resolution of the file_paths input enhanced by t_enhance. 
@@ -210,7 +212,7 @@ def get_cache_file(self, feature, s_enhance, t_enhance, s_agg_factor, or self.temporal_slice.stop is None else self.temporal_slice.stop - self.temporal_slice.start) fn = f'exo_{feature}_{self.target}_{self.shape},{tsteps}' - fn += f'_sagg{s_agg_factor}_tagg{t_agg_factor}_{s_enhance}x_' + fn += f'_tagg{t_agg_factor}_{s_enhance}x_' fn += f'{t_enhance}x.pkl' fn = fn.replace('(', '').replace(')', '') fn = fn.replace('[', '').replace(']', '') @@ -233,14 +235,10 @@ def source_temporal_slice(self): @property def source_lat_lon(self): - """Get the 2D array (n, 2) of lat, lon data for the exo source""" - if self._source_lat_lon is None: - src_enhance = int(np.sqrt(self._s_agg_factor)) - src_shape = (self.hr_shape[0] * src_enhance, - self.hr_shape[1] * src_enhance) - self._source_lat_lon = OutputHandler.get_lat_lon( - self.lr_lat_lon, src_shape).reshape((-1, 2)) - return self._source_lat_lon + """Get the 2D array (n, 2) of lat, lon data from the exo_source_h5""" + with Resource(self._exo_source) as res: + source_lat_lon = res.lat_lon + return source_lat_lon @property def lr_shape(self): @@ -259,7 +257,7 @@ def hr_shape(self): def lr_lat_lon(self): """Lat lon grid for data in format (spatial_1, spatial_2, 2) Lat/Lon array with same ordering in last dimension. This corresponds to the raw - low-resolution meta data from the file_paths input. + meta data from the file_paths input. Returns ------- @@ -271,7 +269,7 @@ def lr_lat_lon(self): def hr_lat_lon(self): """Lat lon grid for data in format (spatial_1, spatial_2, 2) Lat/Lon array with same ordering in last dimension. This corresponds to the - enhanced high-res meta data from the file_paths input * s_enhance. + enhanced meta data from the file_paths input * s_enhance. Returns ------- @@ -308,22 +306,34 @@ def hr_time_index(self): self._hr_time_index = self.input_handler.time_index return self._hr_time_index + @property + def distance_upper_bound(self): + """Maximum distance (float) to map high-resolution data from exo_source + to the low-resolution file_paths input.""" + if self._distance_upper_bound is None: + diff = np.diff(self.source_lat_lon, axis=0) + diff = np.max(np.median(diff, axis=0)) + self._distance_upper_bound = diff + logger.info('Set distance upper bound to {:.4f}' + .format(self._distance_upper_bound)) + return self._distance_upper_bound + @property def tree(self): - """Get the KDTree built on the source lat lon data""" + """Get the KDTree built on the target lat lon data from the file_paths + input with s_enhance""" if self._tree is None: - self._tree = KDTree(self.source_lat_lon) + lat = self.hr_lat_lon[..., 0].flatten() + lon = self.hr_lat_lon[..., 1].flatten() + hr_meta = np.vstack((lat, lon)).T + self._tree = KDTree(hr_meta) return self._tree @property def nn(self): """Get the nearest neighbor indices""" - ll2 = np.vstack( - (self.hr_lat_lon[:, :, 0].flatten(), - self.hr_lat_lon[:, :, 1].flatten())).T - _, nn = self.tree.query(ll2, k=self._s_agg_factor) - if len(nn.shape) == 1: - nn = np.expand_dims(nn, 1) + _, nn = self.tree.query(self.source_lat_lon, k=1, + distance_upper_bound=self.distance_upper_bound) return nn @property @@ -335,7 +345,6 @@ def data(self): cache_fp = self.get_cache_file(feature=self.__class__.__name__, s_enhance=self._s_enhance, t_enhance=self._t_enhance, - s_agg_factor=self._s_agg_factor, t_agg_factor=self._t_agg_factor) tmp_fp = cache_fp + f'.{generate_random_string(10)}.tmp' if os.path.exists(cache_fp): @@ -355,27 +364,18 @@ def data(self): return data[..., np.newaxis] + @abstractmethod 
def get_data(self): """Get a raster of source values corresponding to the high-resolution grid (the file_paths input grid * s_enhance * - t_enhance). The shape is (lats, lons, temporal, 1) + t_enhance). The shape is (lats, lons, temporal) """ - nn = self.nn - hr_data = [] - for j in range(self._s_agg_factor): - out = self.source_data[nn[:, j], self.source_temporal_slice] - out = out.reshape(self.hr_shape) - hr_data.append(out[..., np.newaxis]) - hr_data = np.concatenate(hr_data, axis=-1).mean(axis=-1) - logger.info('Finished mapping raster from {}'.format(self._exo_source)) - return hr_data @classmethod def get_exo_raster(cls, file_paths, s_enhance, t_enhance, - s_agg_factor, t_agg_factor, exo_source=None, target=None, @@ -407,13 +407,6 @@ class will output a topography raster corresponding to the example, if getting sza data, file_paths has hourly data, and t_enhance is 4, this class will output a sza raster corresponding to the file_paths temporally enhanced 4x to 15 min - s_agg_factor : int - Factor by which to aggregate the exo_source data to the resolution - of the file_paths input enhanced by s_enhance. For example, if - getting topography data, file_paths have 100km data, and s_enhance - is 4 resulting in a desired resolution of ~25km and topo_source_h5 - has a resolution of 4km, the s_agg_factor should be 36 so that 6x6 - 4km cells are averaged to the ~25km enhanced grid. t_agg_factor : int Factor by which to aggregate the exo_source data to the resolution of the file_paths input enhanced by t_enhance. For example, if @@ -467,7 +460,6 @@ class will output a topography raster corresponding to the exo = cls(file_paths, s_enhance, t_enhance, - s_agg_factor, t_agg_factor, exo_source=exo_source, target=target, @@ -491,13 +483,6 @@ def source_data(self): elev = res.get_meta_arr('elevation') return elev[:, np.newaxis] - @property - def source_lat_lon(self): - """Get the 2D array (n, 2) of lat, lon data from the exo_source_h5""" - with Resource(self._exo_source) as res: - source_lat_lon = res.lat_lon - return source_lat_lon - @property def source_time_index(self): """Time index of the source exo data""" @@ -509,20 +494,42 @@ def source_time_index(self): def get_data(self): """Get a raster of source values corresponding to the high-resolution grid (the file_paths input grid * s_enhance * - t_enhance). The shape is (lats, lons, temporal, 1) + t_enhance). The shape is (lats, lons, 1) """ - nn = self.nn - hr_data = [] - for j in range(self._s_agg_factor): - out = self.source_data[nn[:, j]] - out = out.reshape((*self.hr_shape[:-1], -1)) - hr_data.append(out[..., np.newaxis]) - hr_data = np.concatenate(hr_data, axis=-1).mean(axis=-1) + + assert len(self.source_data.shape) == 2 + assert self.source_data.shape[1] == 1 + + df = pd.DataFrame({'topo': self.source_data.flatten(), + 'gid_target': self.nn}) + n_target = np.product(self.hr_shape[:-1]) + df = df[df['gid_target'] != n_target] + df = df.sort_values('gid_target') + df = df.groupby('gid_target').mean() + + missing = set(np.arange(n_target)) - set(df.index) + if any(missing): + msg = (f'{len(missing)} target pixels did not have unique ' + 'high-resolution source data to map from. If there are a ' + 'lot of target pixels missing source data this probably ' + 'means the source data is not high enough resolution. 
' + 'Filling raster with NN.') + logger.warning(msg) + warn(msg) + temp_df = pd.DataFrame({'topo': np.nan}, index=sorted(missing)) + df = pd.concat((df, temp_df)).sort_index() + + hr_data = df['topo'].values.reshape(self.hr_shape[:-1]) + if np.isnan(hr_data).any(): + hr_data = nn_fill_array(hr_data) + + hr_data = np.expand_dims(hr_data, axis=-1) + logger.info('Finished mapping raster from {}'.format(self._exo_source)) + return hr_data - def get_cache_file(self, feature, s_enhance, t_enhance, s_agg_factor, - t_agg_factor): + def get_cache_file(self, feature, s_enhance, t_enhance, t_agg_factor): """Get cache file name. This uses a time independent naming convention. Parameters @@ -535,9 +542,6 @@ def get_cache_file(self, feature, s_enhance, t_enhance, s_agg_factor, t_enhance : int Temporal enhancement for this exogeneous data step (cumulative for all model steps up to the current step). - s_agg_factor : int - Factor by which to aggregate the exo_source data to the spatial - resolution of the file_paths input enhanced by s_enhance. t_agg_factor : int Factor by which to aggregate the exo_source data to the temporal resolution of the file_paths input enhanced by t_enhance. @@ -548,7 +552,7 @@ def get_cache_file(self, feature, s_enhance, t_enhance, s_agg_factor, Name of cache file """ fn = f'exo_{feature}_{self.target}_{self.shape}' - fn += f'_sagg{s_agg_factor}_tagg{t_agg_factor}_{s_enhance}x_' + fn += f'_tagg{t_agg_factor}_{s_enhance}x_' fn += f'{t_enhance}x.pkl' fn = fn.replace('(', '').replace(')', '') fn = fn.replace('[', '').replace(']', '') @@ -605,7 +609,7 @@ def source_data(self): def get_data(self): """Get a raster of source values corresponding to the high-resolution grid (the file_paths input grid * s_enhance * - t_enhance). The shape is (lats, lons, temporal, 1) + t_enhance). The shape is (lats, lons, temporal) """ hr_data = self.source_data.reshape(self.hr_shape) logger.info('Finished computing SZA data') diff --git a/sup3r/preprocessing/data_handling/exogenous_data_handling.py b/sup3r/preprocessing/data_handling/exogenous_data_handling.py index 3149887e4..0c5519dd7 100644 --- a/sup3r/preprocessing/data_handling/exogenous_data_handling.py +++ b/sup3r/preprocessing/data_handling/exogenous_data_handling.py @@ -1,4 +1,5 @@ """Sup3r exogenous data handling""" +from inspect import signature import logging import re from typing import ClassVar @@ -244,9 +245,13 @@ def __init__(self, source. e.g. {'spatial': '4km', 'temporal': '60min'}. This is used only if agg factors are not provided in the steps list. source_file : str - Filepath to source wtk, nsrdb, or netcdf file to get hi-res (2km or - 4km) data from which will be mapped to the enhanced grid of the - file_paths input + Filepath to source wtk, nsrdb, or netcdf file to get hi-res data + from which will be mapped to the enhanced grid of the file_paths + input. Pixels from this file will be mapped to their nearest + low-res pixel in the file_paths input. Accordingly, the input + should be a significantly higher resolution than file_paths. + Warnings will be raised if the low-resolution pixels in file_paths + do not have unique nearest pixels from this exo source data. target : tuple (lat, lon) lower left corner of raster. Either need target+shape or raster_file. 
@@ -573,21 +578,24 @@ def get_exo_data(self, feature, s_enhance, t_enhance, s_agg_factor, exo_handler = self.get_exo_handler(feature, self.source_file, self.exo_handler) - data = exo_handler(self.file_paths, - self.source_file, - s_enhance=s_enhance, - t_enhance=t_enhance, - s_agg_factor=s_agg_factor, - t_agg_factor=t_agg_factor, - target=self.target, - shape=self.shape, - temporal_slice=self.temporal_slice, - raster_file=self.raster_file, - max_delta=self.max_delta, - input_handler=self.input_handler, - cache_data=self.cache_data, - cache_dir=self.cache_dir, - res_kwargs=self.res_kwargs).data + kwargs = dict(file_paths=self.file_paths, + exo_source=self.source_file, + s_enhance=s_enhance, + t_enhance=t_enhance, + s_agg_factor=s_agg_factor, + t_agg_factor=t_agg_factor, + target=self.target, + shape=self.shape, + temporal_slice=self.temporal_slice, + raster_file=self.raster_file, + max_delta=self.max_delta, + input_handler=self.input_handler, + cache_data=self.cache_data, + cache_dir=self.cache_dir, + res_kwargs=self.res_kwargs) + sig = signature(exo_handler) + kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters} + data = exo_handler(**kwargs).data return data @classmethod diff --git a/sup3r/preprocessing/data_handling/h5_data_handling.py b/sup3r/preprocessing/data_handling/h5_data_handling.py index fca438c0c..8b4e945a2 100644 --- a/sup3r/preprocessing/data_handling/h5_data_handling.py +++ b/sup3r/preprocessing/data_handling/h5_data_handling.py @@ -137,7 +137,9 @@ def extract_feature(cls, try: fdata = handle[(feature, time_slice, *(raster_index.flatten(),))] except ValueError as e: - msg = f'{feature} cannot be extracted from source data' + hfeatures = cls.get_handle_features(file_paths) + msg = (f'Requested feature "{feature}" cannot be extracted from ' + f'source data that has handle features: {hfeatures}.') logger.exception(msg) raise ValueError(msg) from e diff --git a/sup3r/preprocessing/data_handling/nc_data_handling.py b/sup3r/preprocessing/data_handling/nc_data_handling.py index 8a2ade1ba..cc2105b15 100644 --- a/sup3r/preprocessing/data_handling/nc_data_handling.py +++ b/sup3r/preprocessing/data_handling/nc_data_handling.py @@ -281,7 +281,9 @@ def extract_feature(cls, time_slice) else: - msg = f'{feature} cannot be extracted from source data.' 
+ hfeatures = cls.get_handle_features(file_paths) + msg = (f'Requested feature "{feature}" cannot be extracted from ' + f'source data that has handle features: {hfeatures}.') logger.exception(msg) raise ValueError(msg) diff --git a/sup3r/preprocessing/feature_handling.py b/sup3r/preprocessing/feature_handling.py index 609cec0aa..f83a8d74f 100644 --- a/sup3r/preprocessing/feature_handling.py +++ b/sup3r/preprocessing/feature_handling.py @@ -1223,11 +1223,13 @@ def get_height(feature): height to use for interpolation in meters """ - height = re.search(r'\d+m', feature) - if height: - height = height.group(0).strip('m') - if not height.isdigit(): - height = None + height = None + if isinstance(feature, str): + height = re.search(r'\d+m', feature) + if height: + height = height.group(0).strip('m') + if not height.isdigit(): + height = None return height @staticmethod @@ -1244,11 +1246,13 @@ def get_pressure(feature): float | None pressure to use for interpolation in pascals """ - pressure = re.search(r'\d+pa', feature) - if pressure: - pressure = pressure.group(0).strip('pa') - if not pressure.isdigit(): - pressure = None + pressure = None + if isinstance(feature, str): + pressure = re.search(r'\d+pa', feature) + if pressure: + pressure = pressure.group(0).strip('pa') + if not pressure.isdigit(): + pressure = None return pressure @@ -1752,10 +1756,11 @@ def _exact_lookup(cls, feature): Matching feature registry entry. """ out = None - for k, v in cls.FEATURE_REGISTRY.items(): - if k.lower() == feature.lower(): - out = v - break + if isinstance(feature, str): + for k, v in cls.FEATURE_REGISTRY.items(): + if k.lower() == feature.lower(): + out = v + break return out @classmethod @@ -1774,10 +1779,11 @@ def _pattern_lookup(cls, feature): Matching feature registry entry. 
""" out = None - for k, v in cls.FEATURE_REGISTRY.items(): - if re.match(k.lower(), feature.lower()): - out = v - break + if isinstance(feature, str): + for k, v in cls.FEATURE_REGISTRY.items(): + if re.match(k.lower(), feature.lower()): + out = v + break return out @classmethod @@ -1816,7 +1822,7 @@ def _lookup(cls, out, feature, handle_features=None): if pressure is not None: out = out.split('(.*)')[0] + f'{pressure}pa' - return lambda x: [out] + return lambda x: [out] if isinstance(out, str) else out @classmethod def lookup(cls, feature, attr_name, handle_features=None): @@ -1851,7 +1857,6 @@ def lookup(cls, feature, attr_name, handle_features=None): return getattr(out, attr_name, None) elif attr_name == 'inputs': - return cls._lookup(out, feature, handle_features) @classmethod @@ -1873,11 +1878,11 @@ def get_inputs_recursive(cls, feature, handle_features): """ raw_features = [] method = cls.lookup(feature, 'inputs', handle_features=handle_features) - lower_handle_features = [f.lower() for f in handle_features] + low_handle_features = [f.lower() for f in handle_features] + vhf = cls.valid_handle_features([feature.lower()], low_handle_features) check1 = feature not in raw_features - check2 = (cls.valid_handle_features( - [feature.lower()], lower_handle_features) or method is None) + check2 = (vhf or method is None) if check1 and check2: raw_features.append(feature) @@ -1924,6 +1929,7 @@ def get_raw_feature_list(cls, features, handle_features): f'Requested features: {req}') logger.error(msg) raise ValueError(msg) + return raw_features @classmethod diff --git a/sup3r/utilities/era_downloader.py b/sup3r/utilities/era_downloader.py index cc5ef654d..d4822ef59 100644 --- a/sup3r/utilities/era_downloader.py +++ b/sup3r/utilities/era_downloader.py @@ -26,26 +26,12 @@ logger = logging.getLogger(__name__) -try: - import cdsapi - - CDS_API_CLIENT = cdsapi.Client() -except ImportError as e: - msg = f'Could not import cdsapi package. {e}' - logger.error(msg) - class EraDownloader: """Class to handle ERA5 downloading, variable renaming, file combination, and interpolation. """ - msg = ('To download ERA5 data you need to have a ~/.cdsapirc file ' - 'with a valid url and api key. Follow the instructions here: ' - 'https://cds.climate.copernicus.eu/api-how-to') - req_file = os.path.join(os.path.expanduser('~'), '.cdsapirc') - assert os.path.exists(req_file), msg - # variables available on a single level (e.g. surface) SFC_VARS: ClassVar[list] = [ '10m_u_component_of_wind', '10m_v_component_of_wind', @@ -275,6 +261,26 @@ def prep_var_lists(self, variables): logger.warning(msg) warn(msg) + @staticmethod + def get_cds_client(): + """Get the copernicus climate data store (CDS) API object for ERA + downloads.""" + try: + import cdsapi + cds_api_client = cdsapi.Client() + except ImportError as e: + msg = f'Could not import cdsapi package. {e}' + logger.error(msg) + raise ImportError(msg) from e + + msg = ('To download ERA5 data you need to have a ~/.cdsapirc file ' + 'with a valid url and api key. 
Follow the instructions here: ' + 'https://cds.climate.copernicus.eu/api-how-to') + req_file = os.path.join(os.path.expanduser('~'), '.cdsapirc') + assert os.path.exists(req_file), msg + + return cds_api_client + def download_process_combine(self): """Run the download routine.""" sfc_check = len(self.sfc_file_variables) > 0 @@ -338,7 +344,8 @@ def download_file(cls, variables, time_dict, area, out_file, level_type, if level_type == 'pressure': entry['pressure_level'] = levels logger.info(f'Calling CDS-API with {entry}.') - CDS_API_CLIENT.retrieve( + cds_api_client = cls.get_cds_client() + cds_api_client.retrieve( f'reanalysis-era5-{level_type}-levels', entry, out_file) else: diff --git a/sup3r/utilities/plotting.py b/sup3r/utilities/plotting.py index a0d6f1c5e..61eac0270 100644 --- a/sup3r/utilities/plotting.py +++ b/sup3r/utilities/plotting.py @@ -2,7 +2,6 @@ """Utilities module for plotting data """ -import imageio import matplotlib from matplotlib import cm import matplotlib.pyplot as plt @@ -170,6 +169,12 @@ def make_movie(ntime, movieDir, movieName, fps=24): number of frame per second for the movie, by default 24 """ + try: + import imageio + except ImportError as e: + msg = f'Need extra installation to make movie "imageio": {e}' + raise ImportError(msg) from e + # initiate an empty list of "plotted" images myimages = [] # loops through available pngs diff --git a/sup3r/utilities/utilities.py b/sup3r/utilities/utilities.py index 17a85e4fa..0e155ba45 100644 --- a/sup3r/utilities/utilities.py +++ b/sup3r/utilities/utilities.py @@ -1460,12 +1460,12 @@ def get_input_handler_class(file_paths, input_handler_name): def np_to_pd_times(times): - """Convert np.bytes_ times to DatetimeIndex + """Convert `np.bytes_` times to DatetimeIndex Parameters ---------- times : ndarray | list - List of np.bytes_ objects for time indices + List of `np.bytes_` objects for time indices Returns ------- diff --git a/sup3r/version.py b/sup3r/version.py index 7e7d07482..d50352360 100644 --- a/sup3r/version.py +++ b/sup3r/version.py @@ -1,3 +1,3 @@ """SUP3R Version""" -__version__ = '0.1.1' +__version__ = '0.1.2' diff --git a/tests/bias/test_bias_correction.py b/tests/bias/test_bias_correction.py index 0f6c022ad..e0487ee9a 100644 --- a/tests/bias/test_bias_correction.py +++ b/tests/bias/test_bias_correction.py @@ -32,17 +32,20 @@ def test_smooth_interior_bc(): """Test linear bias correction with interior smoothing""" calc = LinearCorrection(FP_NSRDB, FP_CC, 'ghi', 'rsds', - TARGET, SHAPE, bias_handler='DataHandlerNCforCC') - out = calc.run(knn=1, threshold=0.6, fill_extend=False, max_workers=1) + target=TARGET, shape=SHAPE, + distance_upper_bound=0.7, + bias_handler='DataHandlerNCforCC') + out = calc.run(fill_extend=False, max_workers=1) og_scalar = out['rsds_scalar'] og_adder = out['rsds_adder'] nan_mask = np.isnan(og_scalar) assert np.isnan(og_adder[nan_mask]).all() calc = LinearCorrection(FP_NSRDB, FP_CC, 'ghi', 'rsds', - TARGET, SHAPE, bias_handler='DataHandlerNCforCC') - out = calc.run(knn=1, threshold=0.6, fill_extend=True, smooth_interior=0, - max_workers=1) + target=TARGET, shape=SHAPE, + distance_upper_bound=0.7, + bias_handler='DataHandlerNCforCC') + out = calc.run(fill_extend=True, smooth_interior=0, max_workers=1) scalar = out['rsds_scalar'] adder = out['rsds_adder'] # Make sure smooth_interior=0 does not change interior pixels @@ -53,9 +56,10 @@ def test_smooth_interior_bc(): # make sure smoothing affects the interior pixels but not the exterior calc = LinearCorrection(FP_NSRDB, FP_CC, 'ghi', 'rsds', 
- TARGET, SHAPE, bias_handler='DataHandlerNCforCC') - out = calc.run(knn=1, threshold=0.6, fill_extend=True, smooth_interior=1, - max_workers=1) + target=TARGET, shape=SHAPE, + distance_upper_bound=0.7, + bias_handler='DataHandlerNCforCC') + out = calc.run(fill_extend=True, smooth_interior=1, max_workers=1) smooth_scalar = out['rsds_scalar'] smooth_adder = out['rsds_adder'] @@ -69,25 +73,27 @@ def test_linear_bc(): """Test linear bias correction""" calc = LinearCorrection(FP_NSRDB, FP_CC, 'ghi', 'rsds', - TARGET, SHAPE, bias_handler='DataHandlerNCforCC') + target=TARGET, shape=SHAPE, + distance_upper_bound=0.7, + bias_handler='DataHandlerNCforCC') # test a known in-bounds gid bias_gid = 5 - dist, base_gid = calc.get_base_gid(bias_gid, 1) + dist, base_gid = calc.get_base_gid(bias_gid) bias_data = calc.get_bias_data(bias_gid) base_data, _ = calc.get_base_data(calc.base_fps, calc.base_dset, base_gid, calc.base_handler, daily_reduction='avg') - bias_coord = calc.bias_meta.loc[bias_gid, ['latitude', 'longitude']] + bias_coord = calc.bias_meta.loc[[bias_gid], ['latitude', 'longitude']] base_coord = calc.base_meta.loc[base_gid, ['latitude', 'longitude']] true_dist = bias_coord.values - base_coord.values - true_dist = np.hypot(true_dist[0], true_dist[1]) + true_dist = np.hypot(true_dist[:, 0], true_dist[:, 1]) assert np.allclose(true_dist, dist) - assert true_dist < 0.1 + assert (true_dist < 0.5).all() # horiz res of bias data is ~0.7 deg true_scalar = base_data.std() / bias_data.std() true_adder = base_data.mean() - bias_data.mean() * true_scalar - out = calc.run(knn=1, threshold=0.6, fill_extend=False, max_workers=1) + out = calc.run(fill_extend=False, max_workers=1) scalar = out['rsds_scalar'] adder = out['rsds_adder'] @@ -106,11 +112,14 @@ def test_linear_bc(): assert np.isnan(adder[corner]) nan_mask = np.isnan(scalar) assert np.isnan(adder[nan_mask]).all() + assert len(calc.bad_bias_gids) > 0 # make sure the NN fill works for out-of-bounds pixels calc = LinearCorrection(FP_NSRDB, FP_CC, 'ghi', 'rsds', - TARGET, SHAPE, bias_handler='DataHandlerNCforCC') - out = calc.run(knn=1, threshold=0.6, fill_extend=True, max_workers=1) + target=TARGET, shape=SHAPE, + distance_upper_bound=0.7, + bias_handler='DataHandlerNCforCC') + out = calc.run(fill_extend=True, max_workers=1) scalar = out['rsds_scalar'] adder = out['rsds_adder'] @@ -118,14 +127,16 @@ def test_linear_bc(): assert np.allclose(true_scalar, scalar[iloc]) assert np.allclose(true_adder, adder[iloc]) + assert len(calc.bad_bias_gids) > 0 assert not np.isnan(scalar[nan_mask]).any() assert not np.isnan(adder[nan_mask]).any() # make sure smoothing affects the out-of-bounds pixels but not the in-bound calc = LinearCorrection(FP_NSRDB, FP_CC, 'ghi', 'rsds', - TARGET, SHAPE, bias_handler='DataHandlerNCforCC') - out = calc.run(knn=1, threshold=0.6, fill_extend=True, smooth_extend=2, - max_workers=1) + target=TARGET, shape=SHAPE, + distance_upper_bound=0.7, + bias_handler='DataHandlerNCforCC') + out = calc.run(fill_extend=True, smooth_extend=2, max_workers=1) smooth_scalar = out['rsds_scalar'] smooth_adder = out['rsds_adder'] assert np.allclose(smooth_scalar[~nan_mask], scalar[~nan_mask]) @@ -135,9 +146,10 @@ def test_linear_bc(): # parallel test calc = LinearCorrection(FP_NSRDB, FP_CC, 'ghi', 'rsds', - TARGET, SHAPE, bias_handler='DataHandlerNCforCC') - out = calc.run(knn=1, threshold=0.6, fill_extend=True, smooth_extend=2, - max_workers=2) + target=TARGET, shape=SHAPE, + distance_upper_bound=0.7, + bias_handler='DataHandlerNCforCC') + out = 
calc.run(fill_extend=True, smooth_extend=2, max_workers=2) par_scalar = out['rsds_scalar'] par_adder = out['rsds_adder'] assert np.allclose(smooth_scalar, par_scalar) @@ -148,28 +160,29 @@ def test_monthly_linear_bc(): """Test linear bias correction on a month-by-month basis""" calc = MonthlyLinearCorrection(FP_NSRDB, FP_CC, 'ghi', 'rsds', - TARGET, SHAPE, + target=TARGET, shape=SHAPE, + distance_upper_bound=0.7, bias_handler='DataHandlerNCforCC') # test a known in-bounds gid bias_gid = 5 - dist, base_gid = calc.get_base_gid(bias_gid, 1) + dist, base_gid = calc.get_base_gid(bias_gid) bias_data = calc.get_bias_data(bias_gid) base_data, base_ti = calc.get_base_data(calc.base_fps, calc.base_dset, base_gid, calc.base_handler, daily_reduction='avg') - bias_coord = calc.bias_meta.loc[bias_gid, ['latitude', 'longitude']] + bias_coord = calc.bias_meta.loc[[bias_gid], ['latitude', 'longitude']] base_coord = calc.base_meta.loc[base_gid, ['latitude', 'longitude']] true_dist = bias_coord.values - base_coord.values - true_dist = np.hypot(true_dist[0], true_dist[1]) + true_dist = np.hypot(true_dist[:, 0], true_dist[:, 1]) assert np.allclose(true_dist, dist) - assert true_dist < 0.1 + assert (true_dist < 0.5).all() # horiz res of bias data is ~0.7 deg base_data = base_data[:31] # just take Jan for testing bias_data = bias_data[:31] # just take Jan for testing true_scalar = base_data.std() / bias_data.std() true_adder = base_data.mean() - bias_data.mean() * true_scalar - out = calc.run(knn=1, threshold=0.6, fill_extend=True, max_workers=1) + out = calc.run(fill_extend=True, max_workers=1) scalar = out['rsds_scalar'] adder = out['rsds_adder'] @@ -191,12 +204,13 @@ def test_monthly_linear_bc(): def test_linear_transform(): """Test the linear bc transform method""" calc = LinearCorrection(FP_NSRDB, FP_CC, 'ghi', 'rsds', - TARGET, SHAPE, bias_handler='DataHandlerNCforCC') + target=TARGET, shape=SHAPE, + distance_upper_bound=0.7, + bias_handler='DataHandlerNCforCC') lat_lon = calc.bias_dh.lat_lon with tempfile.TemporaryDirectory() as td: fp_out = os.path.join(td, 'bc.h5') - out = calc.run(knn=1, threshold=0.6, fill_extend=False, max_workers=1, - fp_out=fp_out) + out = calc.run(fill_extend=False, max_workers=1, fp_out=fp_out) scalar = out['rsds_scalar'] adder = out['rsds_adder'] test_data = np.ones_like(scalar) @@ -204,8 +218,7 @@ def test_linear_transform(): out = local_linear_bc(test_data, lat_lon, 'rsds', fp_out, lr_padded_slice=None, out_range=None) - out = calc.run(knn=1, threshold=0.6, fill_extend=True, max_workers=1, - fp_out=fp_out) + out = calc.run(fill_extend=True, max_workers=1, fp_out=fp_out) scalar = out['rsds_scalar'] adder = out['rsds_adder'] test_data = np.ones_like(scalar) @@ -235,7 +248,8 @@ def test_linear_transform(): def test_montly_linear_transform(): """Test the montly linear bc transform method""" calc = MonthlyLinearCorrection(FP_NSRDB, FP_CC, 'ghi', 'rsds', - TARGET, SHAPE, + target=TARGET, shape=SHAPE, + distance_upper_bound=0.7, bias_handler='DataHandlerNCforCC') lat_lon = calc.bias_dh.lat_lon _, base_ti = calc.get_base_data(calc.base_fps, calc.base_dset, @@ -243,8 +257,7 @@ def test_montly_linear_transform(): daily_reduction='avg') with tempfile.TemporaryDirectory() as td: fp_out = os.path.join(td, 'bc.h5') - out = calc.run(knn=1, threshold=0.6, fill_extend=True, max_workers=1, - fp_out=fp_out) + out = calc.run(fill_extend=True, max_workers=1, fp_out=fp_out) scalar = out['rsds_scalar'] adder = out['rsds_adder'] test_data = np.ones((scalar.shape[0], scalar.shape[1], len(base_ti))) @@ 
@@ -278,19 +291,23 @@ def test_clearsky_ratio():
                            'temporal_slice': [0, 30, 1]}
     calc = LinearCorrection(FP_NSRDB, FP_CC,
                             'clearsky_ratio', 'clearsky_ratio',
-                            TARGET, SHAPE,
+                            target=TARGET, shape=SHAPE,
+                            distance_upper_bound=0.7,
                             bias_handler_kwargs=bias_handler_kwargs,
                             bias_handler='DataHandlerNCforCC')
-    out = calc.run(knn=1, threshold=100, fill_extend=False, max_workers=1)
+    out = calc.run(fill_extend=True, max_workers=1)

     assert not np.isnan(out['clearsky_ratio_scalar']).any()
     assert not np.isnan(out['clearsky_ratio_adder']).any()

-    assert (out['base_clearsky_ratio_mean'] > 0.3).all()
-    assert (out['base_clearsky_ratio_mean'] < 1.0).all()
+    base_cs = out['base_clearsky_ratio_mean']
+    bias_cs = out['bias_clearsky_ratio_mean']

-    assert (out['bias_clearsky_ratio_mean'] > 0.3).all()
-    assert (out['bias_clearsky_ratio_mean'] < 1.0).all()
+    assert (base_cs > 0.3).all()
+    assert (base_cs < 1.0).all()
+
+    assert (bias_cs > 0.3).all()
+    assert (bias_cs < 1.0).all()


 def test_fwp_integration():
@@ -446,26 +463,28 @@ def test_qa_integration():
 def test_skill_assessment():
     """Test the skill assessment of a climate model vs. historical data"""
-    calc = SkillAssessment(FP_NSRDB, FP_CC, 'ghi', 'rsds', TARGET, SHAPE,
+    calc = SkillAssessment(FP_NSRDB, FP_CC, 'ghi', 'rsds',
+                           target=TARGET, shape=SHAPE,
+                           distance_upper_bound=0.7,
                            bias_handler='DataHandlerNCforCC')

     # test a known in-bounds gid
     bias_gid = 5
-    dist, base_gid = calc.get_base_gid(bias_gid, 1)
+    dist, base_gid = calc.get_base_gid(bias_gid)
     bias_data = calc.get_bias_data(bias_gid)
     base_data, _ = calc.get_base_data(calc.base_fps, calc.base_dset,
                                       base_gid, calc.base_handler,
                                       daily_reduction='avg')
-    bias_coord = calc.bias_meta.loc[bias_gid, ['latitude', 'longitude']]
+    bias_coord = calc.bias_meta.loc[[bias_gid], ['latitude', 'longitude']]
     base_coord = calc.base_meta.loc[base_gid, ['latitude', 'longitude']]
     true_dist = bias_coord.values - base_coord.values
-    true_dist = np.hypot(true_dist[0], true_dist[1])
+    true_dist = np.hypot(true_dist[:, 0], true_dist[:, 1])
     assert np.allclose(true_dist, dist)
-    assert true_dist < 0.1
+    assert (true_dist < 0.5).all()  # horiz res of bias data is ~0.7 deg

     iloc = np.where(calc.bias_gid_raster == bias_gid)
     iloc += (0, )

-    out = calc.run(knn=1, threshold=0.6, fill_extend=True, max_workers=1)
+    out = calc.run(fill_extend=True, max_workers=1)

     base_mean = base_data.mean()
     bias_mean = bias_data.mean()
@@ -476,3 +495,38 @@ def test_skill_assessment():
     ks = stats.ks_2samp(base_data - base_mean, bias_data - bias_mean)
     assert np.allclose(out['rsds_ks_stat'][iloc], ks.statistic)
     assert np.allclose(out['rsds_ks_p'][iloc], ks.pvalue)
+
+
+def test_nc_base_file():
+    """Test a base file being a .nc like ERA5"""
+    calc = SkillAssessment(FP_CC, FP_CC, 'rsds', 'rsds',
+                           target=TARGET, shape=SHAPE,
+                           distance_upper_bound=0.7,
+                           base_handler='DataHandlerNCforCC',
+                           bias_handler='DataHandlerNCforCC')
+
+    # test a known in-bounds gid
+    bias_gid = 5
+    dist, base_gid = calc.get_base_gid(bias_gid)
+    assert dist == 0
+    assert (calc.nn_dist == 0).all()
+
+    with pytest.raises(RuntimeError) as exc:
+        calc.get_base_data(calc.base_fps, calc.base_dset, base_gid,
+                           calc.base_handler, daily_reduction='avg')
+
+    good_err = 'only to be used with `base_handler` as a `sup3r.DataHandler` '
+    assert good_err in str(exc.value)
+
+    # make sure this doesn't raise an error now that calc.base_dh is provided
+    calc.get_base_data(calc.base_fps, calc.base_dset,
+                       base_gid, calc.base_handler,
+                       daily_reduction='avg',
+                       base_dh_inst=calc.base_dh)
+
+    out = calc.run(fill_extend=True, max_workers=1)
+
+    assert (out['rsds_scalar'] == 1).all()
+    assert (out['rsds_adder'] == 0).all()
+    assert np.allclose(out['base_rsds_mean'], out['bias_rsds_mean'])
+    assert np.allclose(out['base_rsds_std'], out['bias_rsds_std'])
diff --git a/tests/data_handling/test_exo_data_handling.py b/tests/data_handling/test_exo_data_handling.py
index 601d547b4..4caefd16d 100644
--- a/tests/data_handling/test_exo_data_handling.py
+++ b/tests/data_handling/test_exo_data_handling.py
@@ -9,6 +9,8 @@
 from sup3r import TEST_DATA_DIR
 from sup3r.preprocessing.data_handling import ExogenousDataHandler

+from test_utils_topo import make_topo_file
+
 FP_WTK = os.path.join(TEST_DATA_DIR, 'test_wtk_co_2012.h5')

 FILE_PATHS = [os.path.join(TEST_DATA_DIR, 'ua_test.nc'),
@@ -37,8 +39,9 @@ def test_exo_cache(feature):
                        'combine_type': 'input',
                        'model': 0})
     with TemporaryDirectory() as td:
+        fp_topo = make_topo_file(FILE_PATHS[0], td)
         base = ExogenousDataHandler(FILE_PATHS, feature,
-                                    source_file=FP_WTK,
+                                    source_file=fp_topo,
                                     steps=steps,
                                     target=TARGET, shape=SHAPE,
                                     input_handler='DataHandlerNCforCC',
diff --git a/tests/data_handling/test_utils_topo.py b/tests/data_handling/test_utils_topo.py
index 294bede19..01826a754 100644
--- a/tests/data_handling/test_utils_topo.py
+++ b/tests/data_handling/test_utils_topo.py
@@ -1,11 +1,15 @@
 # -*- coding: utf-8 -*-
 """pytests for topography utilities"""
 import os
+import shutil
+import tempfile
+import pandas as pd
 import matplotlib.pyplot as plt
 import numpy as np
 import pytest
-from scipy.spatial import KDTree
+from rex import Resource
+from rex import Outputs

 from sup3r import TEST_DATA_DIR
 from sup3r.preprocessing.data_handling.exo_extraction import (
@@ -21,83 +25,126 @@ WRF_SHAPE = (8, 8)


-@pytest.mark.parametrize('agg_factor', [1, 4, 8])
-def test_topo_extraction_h5(agg_factor, plot=False):
-    """Test the spatial enhancement of a test grid and then the lookup of the
-    elevation data to a reference WTK file (also the same file for the test)"""
-    te = TopoExtractH5(FP_WTK, FP_WTK, s_enhance=2, t_enhance=1,
-                       t_agg_factor=1, s_agg_factor=agg_factor,
-                       target=TARGET, shape=SHAPE)
-    hr_elev = te.data
+def get_lat_lon_range_h5(fp):
+    """Get the min/max lat/lon from an h5 file"""
+    with Resource(fp) as wtk:
+        lat_range = (wtk.meta['latitude'].min(), wtk.meta['latitude'].max())
+        lon_range = (wtk.meta['longitude'].min(), wtk.meta['longitude'].max())
+    return lat_range, lon_range
+
+
+def get_lat_lon_range_nc(fp):
+    """Get the min/max lat/lon from a netcdf file"""
+    import xarray as xr
+    dset = xr.open_dataset(fp)
+    lat_range = (dset['lat'].values.min(), dset['lat'].values.max())
+    lon_range = (dset['lon'].values.min(), dset['lon'].values.max())
+    return lat_range, lon_range

-    tree = KDTree(te.source_lat_lon)
-
-    # bottom left
-    _, i = tree.query(TARGET, k=agg_factor)
-    elev = te.source_data[i].mean()
-    assert np.allclose(elev, hr_elev[-1, 0])
+def make_topo_file(fp, td, N=100, offset=0.1):
+    """Make a dummy h5 file with high-res topo for testing"""

-    # top right
-    _, i = tree.query((39.35, -105.2), k=agg_factor)
-    elev = te.source_data[i].mean()
-    assert np.allclose(elev, hr_elev[0, 0])
+    if fp.endswith('.h5'):
+        lat_range, lon_range = get_lat_lon_range_h5(fp)
+    else:
+        lat_range, lon_range = get_lat_lon_range_nc(fp)

-    for idy in range(10, 20):
-        for idx in range(10, 20):
-            lat, lon = te.hr_lat_lon[idy, idx, :]
-            _, i = tree.query((lat, lon), k=agg_factor)
-            elev = te.source_data[i].mean()
-            assert np.allclose(elev, hr_elev[idy, idx])
+    lat = np.linspace(lat_range[0] - offset, lat_range[1] + offset, N)
+    lon = np.linspace(lon_range[0] - offset, lon_range[1] + offset, N)
+    idy, idx = np.meshgrid(np.arange(len(lon)), np.arange(len(lat)))
+    lon, lat = np.meshgrid(lon, lat)
+    lon, lat = lon.flatten(), lat.flatten()
+    idy, idx = idy.flatten(), idx.flatten()
+    scale = 30
+    elevation = np.sin(scale * np.deg2rad(idy) + scale * np.deg2rad(idx))
+    meta = pd.DataFrame({'latitude': lat, 'longitude': lon,
+                         'elevation': elevation})

-    if plot:
-        a = plt.scatter(te.source_lat_lon[:, 1], te.source_lat_lon[:, 0],
-                        c=te.source_data, marker='s', s=150)
-        plt.colorbar(a)
-        plt.savefig('./source_elevation.png')
-        plt.close()
+    fp_temp = os.path.join(td, 'elevation.h5')
+    with Outputs(fp_temp, mode='w') as out:
+        out.meta = meta

-        a = plt.imshow(hr_elev)
-        plt.colorbar(a)
-        plt.savefig('./hr_elev.png')
-        plt.close()
+    return fp_temp

-@pytest.mark.parametrize('agg_factor', [1, 4, 8])
-def test_topo_extraction_nc(agg_factor, plot=False):
+@pytest.mark.parametrize('s_enhance', [1, 2])
+def test_topo_extraction_h5(s_enhance, plot=False):
     """Test the spatial enhancement of a test grid and then the lookup of the
-    elevation data to a reference WRF file (also the same file for the test)"""
-    te = TopoExtractNC(FP_WRF, FP_WRF, s_enhance=2, t_enhance=1,
-                       s_agg_factor=agg_factor, t_agg_factor=1,
-                       target=WRF_TARGET, shape=WRF_SHAPE)
-    hr_elev = te.data
+    elevation data to a reference WTK file (also the same file for the test)"""
+    with tempfile.TemporaryDirectory() as td:
+        fp_exo_topo = make_topo_file(FP_WTK, td)
+
+        te = TopoExtractH5(FP_WTK, fp_exo_topo, s_enhance=s_enhance,
+                           t_enhance=1, t_agg_factor=1,
+                           target=TARGET, shape=SHAPE)
+
+        hr_elev = te.data
+
+        lat = te.hr_lat_lon[..., 0].flatten()
+        lon = te.hr_lat_lon[..., 1].flatten()
+        hr_wtk_meta = np.vstack((lat, lon)).T
+        hr_wtk_ind = np.arange(len(lat)).reshape(te.hr_shape[:-1])
+        assert te.nn.max() == len(hr_wtk_meta)
+
+        for gid in np.random.choice(len(hr_wtk_meta), 50, replace=False):
+            idy, idx = np.where(hr_wtk_ind == gid)
+            iloc = np.where(te.nn == gid)[0]
+            exo_coords = te.source_lat_lon[iloc]
+
+            # make sure all mapped high-res exo coordinates are closest to gid
+            # pylint: disable=consider-using-enumerate
+            for i in range(len(exo_coords)):
+                dist = hr_wtk_meta - exo_coords[i]
+                dist = np.hypot(dist[:, 0], dist[:, 1])
+                assert np.argmin(dist) == gid
+
+            # make sure the mean elevation makes sense
+            test_out = hr_elev[idy, idx, 0, 0]
+            true_out = te.source_data[iloc].mean()
+            assert np.allclose(test_out, true_out)
+
+        shutil.rmtree('./exo_cache/', ignore_errors=True)
+
+        if plot:
+            a = plt.scatter(te.source_lat_lon[:, 1], te.source_lat_lon[:, 0],
+                            c=te.source_data, marker='s', s=5)
+            plt.colorbar(a)
+            plt.savefig(f'./source_elevation_{s_enhance}.png')
+            plt.close()
+
+            a = plt.imshow(hr_elev[:, :, 0, 0])
+            plt.colorbar(a)
+            plt.savefig(f'./hr_elev_{s_enhance}.png')
+            plt.close()
+
+
+def test_bad_s_enhance(s_enhance=10):
+    """Test a large s_enhance factor that results in a bad mapping with
+    enhanced grid pixels not having source exo data points"""
+    with tempfile.TemporaryDirectory() as td:
+        fp_exo_topo = make_topo_file(FP_WTK, td)
+
+        with pytest.warns(UserWarning) as warnings:
+            te = TopoExtractH5(FP_WTK, fp_exo_topo, s_enhance=s_enhance,
+                               t_enhance=1, t_agg_factor=1,
+                               target=TARGET, shape=SHAPE,
+                               cache_data=False)
+            _ = te.data
+
+        good = ['target pixels did not have unique' in str(w.message)
+                for w in warnings.list]
+        assert any(good)
+
+
+def test_topo_extraction_nc():
+    """Test the spatial enhancement of a test grid and then the lookup of the
+    elevation data to a reference WRF file (also the same file for the test)

-    tree = KDTree(te.source_lat_lon)
-
-    # bottom left
-    _, i = tree.query(WRF_TARGET, k=agg_factor)
-    elev = te.source_data[i].mean()
-    assert np.allclose(elev, hr_elev[-1, 0])
-
-    # top right
-    _, i = tree.query((19.4, -123.6), k=agg_factor)
-    elev = te.source_data[i].mean()
-    assert np.allclose(elev, hr_elev[0, 0])
-
-    for idy in range(4, 8):
-        for idx in range(4, 8):
-            lat, lon = te.hr_lat_lon[idy, idx, :]
-            _, i = tree.query((lat, lon), k=agg_factor)
-            elev = te.source_data[i].mean()
-            assert np.allclose(elev, hr_elev[idy, idx])
-
-    if plot:
-        a = plt.scatter(te.source_lat_lon[:, 1], te.source_lat_lon[:, 0],
-                        c=te.source_data, marker='s', s=150)
-        plt.colorbar(a)
-        plt.savefig('./source_elevation.png')
-        plt.close()
-
-        a = plt.imshow(hr_elev)
-        plt.colorbar(a)
-        plt.savefig('./hr_elev.png')
-        plt.close()
+    We already test proper topo mapping and aggregation in the h5 test so this
+    just makes sure that the data can be extracted from a WRF file.
+    """
+    te = TopoExtractNC(FP_WRF, FP_WRF, s_enhance=1, t_enhance=1,
+                       t_agg_factor=1, target=None, shape=None)
+    hr_elev = te.data
+    assert np.allclose(te.source_data.flatten(), hr_elev.flatten())
diff --git a/tests/forward_pass/test_forward_pass_exo.py b/tests/forward_pass/test_forward_pass_exo.py
index c6bfd4d32..747eeae36 100644
--- a/tests/forward_pass/test_forward_pass_exo.py
+++ b/tests/forward_pass/test_forward_pass_exo.py
@@ -2,6 +2,7 @@
 """pytests for data handling"""
 import json
 import os
+import shutil
 import tempfile

 import matplotlib.pyplot as plt
@@ -963,6 +964,8 @@ def test_fwp_multi_step_model_multi_exo():
         'U_100m', 'V_100m', 'topography'
     ]

+    shutil.rmtree('./exo_cache', ignore_errors=True)
+

 def test_fwp_multi_step_exo_hi_res_topo_and_sza():
     """Test the forward pass with multiple ExoGan models requiring
@@ -1197,3 +1200,5 @@ def test_fwp_multi_step_exo_hi_res_topo_and_sza():

         for fp in handler.out_files:
             assert os.path.exists(fp)
+
+    shutil.rmtree('./exo_cache', ignore_errors=True)