diff --git a/sup3r/bias/bias_calc.py b/sup3r/bias/bias_calc.py index 76f43fced..098c8005c 100644 --- a/sup3r/bias/bias_calc.py +++ b/sup3r/bias/bias_calc.py @@ -8,19 +8,17 @@ import os from abc import abstractmethod from concurrent.futures import ProcessPoolExecutor, as_completed -from glob import glob -from pathlib import Path import h5py import numpy as np import pandas as pd import rex -from rex.utilities.fun_utils import get_fun_call_str from rex.utilities.bc_utils import ( + sample_q_invlog, sample_q_linear, sample_q_log, - sample_q_invlog, ) +from rex.utilities.fun_utils import get_fun_call_str from scipy import stats from scipy.ndimage import gaussian_filter from scipy.spatial import KDTree @@ -29,39 +27,11 @@ from sup3r.preprocessing.data_handling.base import DataHandler from sup3r.utilities import VERSION_RECORD, ModuleName from sup3r.utilities.cli import BaseCLI -from sup3r.utilities.utilities import nn_fill_array +from sup3r.utilities.utilities import expand_paths, nn_fill_array logger = logging.getLogger(__name__) -def _expand_paths(fps): - """Expand path(s) - - Parameter - --------- - fps : str or pathlib.Path or any Sequence of those - One or multiple paths to file - - Returns - ------- - list[str] - A list of expanded unique and sorted paths as str - - Examples - -------- - >>> _expand_paths("myfile.h5") - - >>> _expand_paths(["myfile.h5", "*.hdf"]) - """ - if isinstance(fps, (str, Path)): - fps = (fps, ) - - out = [] - for f in fps: - out.extend(glob(f)) - return sorted(set(out)) - - class DataRetrievalBase: """Base class to handle data retrieval for the biased data and the baseline data @@ -163,8 +133,8 @@ class to be retrieved from the rex/sup3r library. 
If a self._distance_upper_bound = distance_upper_bound self.match_zero_rate = match_zero_rate - self.base_fps = _expand_paths(self.base_fps) - self.bias_fps = _expand_paths(self.bias_fps) + self.base_fps = expand_paths(self.base_fps) + self.bias_fps = expand_paths(self.bias_fps) base_sup3r_handler = getattr(sup3r.preprocessing.data_handling, base_handler, None) @@ -1224,6 +1194,7 @@ class QuantileDeltaMappingCorrection(DataRetrievalBase): :func:`~sup3r.bias.bias_transforms.local_qdm_bc` to actually correct a dataset. """ + def __init__(self, base_fps, bias_fps, @@ -1308,7 +1279,7 @@ def __init__(self, self.bias_fut_fps = bias_fut_fps - self.bias_fut_fps = _expand_paths(self.bias_fut_fps) + self.bias_fut_fps = expand_paths(self.bias_fut_fps) self.bias_fut_dh = self.bias_handler(self.bias_fut_fps, [self.bias_feature], diff --git a/sup3r/preprocessing/data_handling/mixin.py b/sup3r/preprocessing/data_handling/mixin.py index 5c5a69136..091897b33 100644 --- a/sup3r/preprocessing/data_handling/mixin.py +++ b/sup3r/preprocessing/data_handling/mixin.py @@ -2,7 +2,6 @@ @author: bbenton """ -import glob import logging import os import pickle @@ -18,6 +17,7 @@ from sup3r.utilities.utilities import ( estimate_max_workers, + expand_paths, get_source_type, ignore_case_path_fetch, uniform_box_sampler, @@ -644,22 +644,15 @@ def file_paths(self, file_paths): ---------- file_paths : str | list A list of files to extract raster data from. Each file must have - the same number of timesteps. Can also pass a string with a - unix-style file path which will be passed through glob.glob + the same number of timesteps. 
Can also pass a string or list of + strings with a unix-style file path which will be passed through + glob.glob """ - self._file_paths = file_paths - if isinstance(self._file_paths, str): - if '*' in file_paths: - self._file_paths = glob.glob(self._file_paths) - else: - self._file_paths = [self._file_paths] - + self._file_paths = expand_paths(file_paths) msg = ('No valid files provided to DataHandler. ' f'Received file_paths={file_paths}. Aborting.') assert file_paths is not None and len(self._file_paths) > 0, msg - self._file_paths = sorted(self._file_paths) - @property def ti_workers(self): """Get max number of workers for computing time index""" diff --git a/sup3r/utilities/era_downloader.py b/sup3r/utilities/era_downloader.py index 455da426e..8606ac769 100644 --- a/sup3r/utilities/era_downloader.py +++ b/sup3r/utilities/era_downloader.py @@ -731,7 +731,7 @@ def run_year(cls, area, levels, combined_out_pattern, - combined_yearly_file, + combined_yearly_file=None, interp_out_pattern=None, interp_yearly_file=None, run_interp=True, @@ -834,10 +834,13 @@ def run_year(cls, logger.info(f'Finished future for year {v["year"]} and month ' f'{v["month"]}.') - cls.make_yearly_file(year, combined_out_pattern, combined_yearly_file) + if combined_yearly_file is not None: + cls.make_yearly_file(year, combined_out_pattern, + combined_yearly_file) - if run_interp: - cls.make_yearly_file(year, interp_out_pattern, interp_yearly_file) + if run_interp and interp_yearly_file is not None: + cls.make_yearly_file(year, interp_out_pattern, + interp_yearly_file) @classmethod def make_yearly_file(cls, year, file_pattern, yearly_file): @@ -863,7 +866,7 @@ def make_yearly_file(cls, year, file_pattern, yearly_file): ] if not os.path.exists(yearly_file): - with xr.open_mfdataset(files) as res: + with xr.open_mfdataset(files, parallel=True) as res: logger.info(f'Combining {files}') os.makedirs(os.path.dirname(yearly_file), exist_ok=True) res.to_netcdf(yearly_file) diff --git 
a/sup3r/utilities/regridder.py b/sup3r/utilities/regridder.py index fa086a230..c50eede76 100644 --- a/sup3r/utilities/regridder.py +++ b/sup3r/utilities/regridder.py @@ -199,6 +199,7 @@ def get_all_queries(self, max_workers=None): else: logger.info('Querying all coordinates in parallel.') self._parallel_queries(max_workers=max_workers) + logger.info('Finished querying all coordinates.') def _serial_queries(self): """Get indices and distances for all points in target_meta, in diff --git a/sup3r/utilities/utilities.py b/sup3r/utilities/utilities.py index b0ac20a62..0479c88f3 100644 --- a/sup3r/utilities/utilities.py +++ b/sup3r/utilities/utilities.py @@ -10,6 +10,7 @@ import string import time from fnmatch import fnmatch +from pathlib import Path from warnings import warn import numpy as np @@ -55,6 +56,34 @@ def __call__(self, fun, *args, **kwargs): return out +def expand_paths(fps): + """Expand path(s) + + Parameters + ---------- + fps : str or pathlib.Path or any Sequence of those + One or multiple paths to file + + Returns + ------- + list[str] + A list of expanded unique and sorted paths as str + + Examples + -------- + >>> expand_paths("myfile.h5") + + >>> expand_paths(["myfile.h5", "*.hdf"]) + """ + if isinstance(fps, (str, Path)): + fps = (fps, ) + + out = [] + for f in fps: + out.extend(glob(f)) + return sorted(set(out)) + + def generate_random_string(length): """Generate random string with given length. Used for naming temporary files to avoid collisions."""