Skip to content

Commit

Permalink
Add expand_paths call to base data handler
Browse files Browse the repository at this point in the history
  • Loading branch information
bnb32 committed May 1, 2024
1 parent 7bb248a commit f430ff8
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 53 deletions.
43 changes: 7 additions & 36 deletions sup3r/bias/bias_calc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,17 @@
import os
from abc import abstractmethod
from concurrent.futures import ProcessPoolExecutor, as_completed
from glob import glob
from pathlib import Path

import h5py
import numpy as np
import pandas as pd
import rex
from rex.utilities.fun_utils import get_fun_call_str
from rex.utilities.bc_utils import (
sample_q_invlog,
sample_q_linear,
sample_q_log,
sample_q_invlog,
)
from rex.utilities.fun_utils import get_fun_call_str
from scipy import stats
from scipy.ndimage import gaussian_filter
from scipy.spatial import KDTree
Expand All @@ -29,39 +27,11 @@
from sup3r.preprocessing.data_handling.base import DataHandler
from sup3r.utilities import VERSION_RECORD, ModuleName
from sup3r.utilities.cli import BaseCLI
from sup3r.utilities.utilities import nn_fill_array
from sup3r.utilities.utilities import expand_paths, nn_fill_array

logger = logging.getLogger(__name__)


def _expand_paths(fps):
"""Expand path(s)
Parameter
---------
fps : str or pathlib.Path or any Sequence of those
One or multiple paths to file
Returns
-------
list[str]
A list of expanded unique and sorted paths as str
Examples
--------
>>> _expand_paths("myfile.h5")
>>> _expand_paths(["myfile.h5", "*.hdf"])
"""
if isinstance(fps, (str, Path)):
fps = (fps, )

out = []
for f in fps:
out.extend(glob(f))
return sorted(set(out))


class DataRetrievalBase:
"""Base class to handle data retrieval for the biased data and the
baseline data
Expand Down Expand Up @@ -163,8 +133,8 @@ class to be retrieved from the rex/sup3r library. If a
self._distance_upper_bound = distance_upper_bound
self.match_zero_rate = match_zero_rate

self.base_fps = _expand_paths(self.base_fps)
self.bias_fps = _expand_paths(self.bias_fps)
self.base_fps = expand_paths(self.base_fps)
self.bias_fps = expand_paths(self.bias_fps)

base_sup3r_handler = getattr(sup3r.preprocessing.data_handling,
base_handler, None)
Expand Down Expand Up @@ -1224,6 +1194,7 @@ class QuantileDeltaMappingCorrection(DataRetrievalBase):
:func:`~sup3r.bias.bias_transforms.local_qdm_bc` to actually correct
a dataset.
"""

def __init__(self,
base_fps,
bias_fps,
Expand Down Expand Up @@ -1308,7 +1279,7 @@ def __init__(self,

self.bias_fut_fps = bias_fut_fps

self.bias_fut_fps = _expand_paths(self.bias_fut_fps)
self.bias_fut_fps = expand_paths(self.bias_fut_fps)

self.bias_fut_dh = self.bias_handler(self.bias_fut_fps,
[self.bias_feature],
Expand Down
17 changes: 5 additions & 12 deletions sup3r/preprocessing/data_handling/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
@author: bbenton
"""

import glob
import logging
import os
import pickle
Expand All @@ -18,6 +17,7 @@

from sup3r.utilities.utilities import (
estimate_max_workers,
expand_paths,
get_source_type,
ignore_case_path_fetch,
uniform_box_sampler,
Expand Down Expand Up @@ -644,22 +644,15 @@ def file_paths(self, file_paths):
----------
file_paths : str | list
A list of files to extract raster data from. Each file must have
the same number of timesteps. Can also pass a string with a
unix-style file path which will be passed through glob.glob
the same number of timesteps. Can also pass a string or list of
strings with a unix-style file path which will be passed through
glob.glob
"""
self._file_paths = file_paths
if isinstance(self._file_paths, str):
if '*' in file_paths:
self._file_paths = glob.glob(self._file_paths)
else:
self._file_paths = [self._file_paths]

self._file_paths = expand_paths(file_paths)
msg = ('No valid files provided to DataHandler. '
f'Received file_paths={file_paths}. Aborting.')
assert file_paths is not None and len(self._file_paths) > 0, msg

self._file_paths = sorted(self._file_paths)

@property
def ti_workers(self):
"""Get max number of workers for computing time index"""
Expand Down
13 changes: 8 additions & 5 deletions sup3r/utilities/era_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,7 +731,7 @@ def run_year(cls,
area,
levels,
combined_out_pattern,
combined_yearly_file,
combined_yearly_file=None,
interp_out_pattern=None,
interp_yearly_file=None,
run_interp=True,
Expand Down Expand Up @@ -834,10 +834,13 @@ def run_year(cls,
logger.info(f'Finished future for year {v["year"]} and month '
f'{v["month"]}.')

cls.make_yearly_file(year, combined_out_pattern, combined_yearly_file)
if combined_yearly_file is not None:
cls.make_yearly_file(year, combined_out_pattern,
combined_yearly_file)

if run_interp:
cls.make_yearly_file(year, interp_out_pattern, interp_yearly_file)
if run_interp and interp_yearly_file is not None:
cls.make_yearly_file(year, interp_out_pattern,
interp_yearly_file)

@classmethod
def make_yearly_file(cls, year, file_pattern, yearly_file):
Expand All @@ -863,7 +866,7 @@ def make_yearly_file(cls, year, file_pattern, yearly_file):
]

if not os.path.exists(yearly_file):
with xr.open_mfdataset(files) as res:
with xr.open_mfdataset(files, parallel=True) as res:
logger.info(f'Combining {files}')
os.makedirs(os.path.dirname(yearly_file), exist_ok=True)
res.to_netcdf(yearly_file)
Expand Down
1 change: 1 addition & 0 deletions sup3r/utilities/regridder.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ def get_all_queries(self, max_workers=None):
else:
logger.info('Querying all coordinates in parallel.')
self._parallel_queries(max_workers=max_workers)
logger.info('Finished querying all coordinates.')

def _serial_queries(self):
"""Get indices and distances for all points in target_meta, in
Expand Down
29 changes: 29 additions & 0 deletions sup3r/utilities/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import string
import time
from fnmatch import fnmatch
from pathlib import Path
from warnings import warn

import numpy as np
Expand Down Expand Up @@ -55,6 +56,34 @@ def __call__(self, fun, *args, **kwargs):
return out


def expand_paths(fps):
    """Expand path(s) and glob pattern(s) into a sorted list of files.

    Parameters
    ----------
    fps : str | pathlib.Path | Sequence of those
        One or multiple paths to file, each optionally containing
        unix-style wildcards.

    Returns
    -------
    list[str]
        A list of expanded unique and sorted paths as str.

    Examples
    --------
    >>> expand_paths("myfile.h5")
    >>> expand_paths(["myfile.h5", "*.hdf"])
    """
    # Treat a bare path exactly like a one-element sequence of paths.
    if isinstance(fps, (str, Path)):
        fps = (fps, )

    # Collect matches into a set so duplicates across patterns collapse.
    matched = set()
    for pattern in fps:
        matched.update(glob(pattern))
    return sorted(matched)


def generate_random_string(length):
"""Generate random string with given length. Used for naming temporary
files to avoid collisions."""
Expand Down

0 comments on commit f430ff8

Please sign in to comment.