
Commit

Merge pull request #128 from NREL/bnb/dev
Combined separate worker args into a worker_kwargs dict and refactored dependent classes. Vectorized a few for loops.
bnb32 committed Dec 21, 2022
2 parents f7605af + 97c2772 commit 0ff1d11
Showing 25 changed files with 776 additions and 621 deletions.
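
The gist of the interface change: worker counts that were previously separate keyword arguments are now collected into a single worker_kwargs dict. A schematic sketch of the before and after call shapes, with hypothetical file paths, model location, padding, and worker counts (not code from this commit):

    from sup3r.pipeline.forward_pass import ForwardPassStrategy

    file_paths = ['./wtk_2015.nc']               # hypothetical input files
    model_kwargs = {'model_dir': './gan_model'}  # hypothetical model location

    # Before this commit: separate worker keyword arguments, e.g.
    #   ForwardPassStrategy(file_paths, model_kwargs, (20, 20, 24),
    #                       spatial_pad=1, temporal_pad=1,
    #                       max_workers=None, output_workers=2, pass_workers=4)

    # After this commit: one worker_kwargs dict.
    strategy = ForwardPassStrategy(
        file_paths, model_kwargs, fwp_chunk_shape=(20, 20, 24),
        spatial_pad=1, temporal_pad=1,
        worker_kwargs=dict(max_workers=None, output_workers=2,
                           pass_workers=4, ti_workers=8))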
13 changes: 0 additions & 13 deletions .pylintrc
@@ -61,15 +61,10 @@ disable=
unpacking-in-except,
old-raise-syntax,
backtick,
long-suffix,
old-ne-operator,
old-octal-literal,
import-star-module-level,
non-ascii-bytes-literal,
raw-checker-failed,
bad-inline-option,
locally-disabled,
locally-enabled,
file-ignored,
suppressed-message,
useless-suppression,
@@ -113,7 +108,6 @@ disable=
range-builtin-not-iterating,
filter-builtin-not-iterating,
using-cmp-argument,
eq-without-hash,
div-method,
idiv-method,
rdiv-method,
@@ -335,13 +329,6 @@ max-line-length=79
# Maximum number of lines in a module
max-module-lines=1000

# List of optional constructs for which whitespace checking is disabled. `dict-
# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}.
# `trailing-comma` allows a space between comma and closing bracket: (a, ).
# `empty-line` allows space-only lines.
no-space-check=trailing-comma,
dict-separator

# Allow the body of a class to be on the same line as the declaration if body
# contains single statement.
single-line-class-stmt=no
48 changes: 32 additions & 16 deletions sup3r/models/abstract.py
@@ -266,6 +266,29 @@ def stdevs(self):
"""
return self._stdevs

@property
def output_stdevs(self):
"""Get the data normalization standard deviation values for only the
output features

Returns
-------
np.ndarray
"""
indices = [self.training_features.index(f)
for f in self.output_features]
return self._stdevs[indices]

@property
def output_means(self):
"""Get the data normalization mean values for only the output features

Returns
-------
np.ndarray
"""
indices = [self.training_features.index(f)
for f in self.output_features]
return self._means[indices]
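
The two properties above pick out per-feature statistics by position in training_features. A minimal standalone sketch of that indexing pattern, with made-up feature names and values:

    import numpy as np

    # Made-up stats to illustrate the indexing in output_means/output_stdevs.
    training_features = ['U_100m', 'V_100m', 'topography']
    output_features = ['U_100m', 'V_100m']  # the model outputs a subset
    means = np.array([1.5, -0.3, 250.0])
    stdevs = np.array([2.1, 2.0, 180.0])

    indices = [training_features.index(f) for f in output_features]
    print(means[indices])   # [ 1.5 -0.3]
    print(stdevs[indices])  # [2.1 2. ]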

def set_norm_stats(self, new_means, new_stdevs):
"""Set the normalization statistics associated with a data batch
handler to model attributes.
@@ -321,17 +344,15 @@ def norm_input(self, low_res):
if isinstance(low_res, tf.Tensor):
low_res = low_res.numpy()

low_res = low_res.copy()
for idf in range(low_res.shape[-1]):
low_res[..., idf] -= self._means[idf]
if any(self._stdevs == 0):
stdevs = np.where(self._stdevs == 0, 1, self._stdevs)
msg = ('Some standard deviations are zero.')
logger.warning(msg)
warn(msg)
else:
stdevs = self._stdevs

if self._stdevs[idf] != 0:
low_res[..., idf] /= self._stdevs[idf]
else:
msg = ('Standard deviation is zero for '
f'{self.training_features[idf]}')
logger.warning(msg)
warn(msg)
low_res = (low_res.copy() - self._means) / stdevs

return low_res
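
The rewritten norm_input trades the per-feature loop for a single broadcast expression, guarding against zero standard deviations up front. A minimal numpy sketch of that guard, with made-up stats:

    import numpy as np

    means = np.array([1.5, 0.0])
    stdevs = np.array([2.0, 0.0])     # second feature is constant in training
    low_res = np.zeros((1, 4, 4, 2))  # (obs, spatial, spatial, features)

    # Swap zero stdevs for 1 so the division is a no-op for those features.
    safe_stdevs = np.where(stdevs == 0, 1, stdevs)
    normed = (low_res - means) / safe_stdevs  # broadcasts over the last axis
    print(normed[..., 0].mean(), normed[..., 1].mean())  # -0.75 0.0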

@@ -352,12 +373,7 @@ def un_norm_output(self, output):
if isinstance(output, tf.Tensor):
output = output.numpy()

for idf in range(output.shape[-1]):
feature_name = self.output_features[idf]
i = self.training_features.index(feature_name)
mean = self._means[i]
stdev = self._stdevs[i]
output[..., idf] = (output[..., idf] * stdev) + mean
output = (output * self.output_stdevs) + self.output_means

return output

4 changes: 1 addition & 3 deletions sup3r/models/base.py
@@ -286,9 +286,7 @@ def discriminate(self, hi_res, norm_in=False):

if norm_in and self._means is not None:
hi_res = hi_res if isinstance(hi_res, tf.Tensor) else hi_res.copy()
for i, (m, s) in enumerate(zip(self._means, self._stdevs)):
islice = tuple([slice(None)] * (len(hi_res.shape) - 1) + [i])
hi_res[islice] = (hi_res[islice] - m) / s
hi_res = (hi_res - self._means) / self._stdevs

out = self.discriminator.layers[0](hi_res)
for i, layer in enumerate(self.discriminator.layers[1:]):
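
Same vectorization pattern as in abstract.py: the removed per-channel slice loop and the new broadcast expression are numerically equivalent, as this standalone check with made-up data illustrates:

    import numpy as np

    means = np.array([0.5, -1.0, 3.0])
    stdevs = np.array([1.0, 2.0, 0.5])
    hi_res = np.random.rand(2, 8, 8, 3)  # (obs, spatial, spatial, features)

    # Old style: normalize one trailing-axis channel at a time.
    looped = hi_res.copy()
    for i, (m, s) in enumerate(zip(means, stdevs)):
        islice = tuple([slice(None)] * (hi_res.ndim - 1) + [i])
        looped[islice] = (looped[islice] - m) / s

    # New style: one broadcast expression over the feature axis.
    vectorized = (hi_res - means) / stdevs

    assert np.allclose(looped, vectorized)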
169 changes: 86 additions & 83 deletions sup3r/pipeline/forward_pass.py
@@ -567,10 +567,8 @@ def __init__(self, file_paths, model_kwargs, fwp_chunk_shape,
input_handler=None,
input_handler_kwargs=None,
incremental=True,
max_workers=None,
output_workers=None,
worker_kwargs=None,
exo_kwargs=None,
pass_workers=1,
bias_correct_method=None,
bias_correct_kwargs=None,
max_nodes=None):
@@ -633,21 +631,25 @@ def __init__(self, file_paths, model_kwargs, fwp_chunk_shape,
Allow the forward pass iteration to skip spatiotemporal chunks that
already have an output file (True, default) or iterate through all
chunks and overwrite any pre-existing outputs (False).
max_workers : int | None
Providing a value for max workers will be used to set the value of
extract_workers, compute_workers, output_workers, load_workers,
ti_workers, pass_workers. If max_workers == 1 then all processes
will be serialized. If None extract_workers, compute_workers,
load_workers, output_workers, etc will use their own provided
values.
output_workers : int | None
max number of workers to use for writing forward pass output.
pass_workers : int | None
max number of workers to use for performing forward passes on a
single node. If 1 then all forward passes on chunks distributed to
a single node will be run in serial. pass_workers=2 is the minimum
number of workers required to run the ForwardPass initialization
and ForwardPass.run_chunk() methods concurrently.
worker_kwargs : dict | None
Dictionary of worker values. Can include max_workers,
pass_workers, output_workers, and ti_workers. Each argument needs
to be an integer or None.
The value of `max_workers` will set the value of all other worker
args. If max_workers == 1 then all processes will be serialized. If
max_workers == None then other worker args will use their own
provided values.
`output_workers` is the max number of workers to use for writing
forward pass output. `pass_workers` is the max number of workers to
use for performing forward passes on a single node. If 1 then all
forward passes on chunks distributed to a single node will be run
in serial. pass_workers=2 is the minimum number of workers required
to run the ForwardPass initialization and ForwardPass.run_chunk()
methods concurrently. `ti_workers` is the max number of workers
used to get the full time index. Doing this in parallel can be
helpful when there are a large number of input files.
exo_kwargs : dict | None
Dictionary of args to pass to ExogenousDataHandler for extracting
exogenous features such as topography for future multistep forward
@@ -669,64 +671,51 @@ def __init__(self, file_paths, model_kwargs, fwp_chunk_shape,
Maximum number of nodes to distribute spatiotemporal chunks across.
If None then a node will be used for each temporal chunk.
"""
self._input_handler_kwargs = input_handler_kwargs or {}
target = self._input_handler_kwargs.get('target', None)
grid_shape = self._input_handler_kwargs.get('shape', None)
raster_file = self._input_handler_kwargs.get('raster_file', None)
raster_index = self._input_handler_kwargs.get('raster_index', None)
temporal_slice = self._input_handler_kwargs.get('temporal_slice',
slice(None, None, 1))
InputMixIn.__init__(self, target=target, shape=grid_shape,
raster_file=raster_file, raster_index=raster_index,
temporal_slice=temporal_slice)

self.file_paths = file_paths
self.model_kwargs = model_kwargs
self.fwp_chunk_shape = fwp_chunk_shape
self.spatial_pad = spatial_pad
self.temporal_pad = temporal_pad
self.model_class = model_class
self.out_pattern = out_pattern
self.max_workers = max_workers
self.output_workers = output_workers
self.pass_workers = pass_workers
self.worker_kwargs = worker_kwargs or {}
self.exo_kwargs = exo_kwargs or {}
self.incremental = incremental
self.bias_correct_method = bias_correct_method
self.bias_correct_kwargs = bias_correct_kwargs or {}
self._failed_chunks = False
self._input_handler_class = None
self._input_handler_name = input_handler
self._max_nodes = max_nodes
self._input_handler_kwargs = input_handler_kwargs or {}
self._time_index = None
self._raw_time_index = None
self._raw_tsteps = None
self._out_files = None
self._file_ids = None
self._time_index_file = None
self._node_chunks = None
self._hr_lat_lon = None
self._lr_lat_lon = None
self._init_handler = None
self._handle_features = None
self.bias_correct_method = bias_correct_method
self.bias_correct_kwargs = bias_correct_kwargs or {}

self._single_ts_files = self._input_handler_kwargs.get(
'single_ts_files', None)
self._target = self._input_handler_kwargs.get('target', None)
self._grid_shape = self._input_handler_kwargs.get('shape', None)
self.raster_file = self._input_handler_kwargs.get('raster_file', None)
self.temporal_slice = self._input_handler_kwargs.get('temporal_slice',
slice(None))
self.time_chunk_size = self._input_handler_kwargs.get(
'time_chunk_size', None)
self.overwrite_cache = self._input_handler_kwargs.get(
'overwrite_cache', False)
self.overwrite_ti_cache = self._input_handler_kwargs.get(
'overwrite_ti_cache', False)
self.extract_workers = self._input_handler_kwargs.get(
'extract_workers', None)
self.compute_workers = self._input_handler_kwargs.get(
'compute_workers', None)
self.load_workers = self._input_handler_kwargs.get('load_workers',
None)
self.ti_workers = self._input_handler_kwargs.get('ti_workers', None)
self.res_kwargs = self._input_handler_kwargs.get('res_kwargs', {})
self._cache_pattern = self._input_handler_kwargs.get('cache_pattern',
None)
self._worker_attrs = ['ti_workers', 'compute_workers', 'pass_workers',
'load_workers', 'output_workers',
'extract_workers']
self.cap_worker_args(max_workers)
self.cache_pattern = self._input_handler_kwargs.get('cache_pattern',
None)
self.max_workers = self.worker_kwargs.get('max_workers', None)
self.output_workers = self.worker_kwargs.get('output_workers', None)
self.pass_workers = self.worker_kwargs.get('pass_workers', None)
self.ti_workers = self.worker_kwargs.get('ti_workers', None)
self._worker_attrs = ['pass_workers', 'output_workers', 'ti_workers']
self.cap_worker_args(self.max_workers)
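
Per the worker_kwargs docstring above, a non-None max_workers overrides every other worker count. A minimal standalone sketch of that capping rule (an assumption about the effect of cap_worker_args, not its actual implementation):

    # Hypothetical illustration of the capping rule, not sup3r's code.
    worker_kwargs = dict(max_workers=1, pass_workers=4, output_workers=2,
                         ti_workers=None)
    max_workers = worker_kwargs.get('max_workers', None)
    if max_workers is not None:
        worker_kwargs = {k: (max_workers if k != 'max_workers' else v)
                         for k, v in worker_kwargs.items()}
    print(worker_kwargs)
    # {'max_workers': 1, 'pass_workers': 1, 'output_workers': 1, 'ti_workers': 1}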

model_class = getattr(sup3r.models, self.model_class, None)
if isinstance(self.model_kwargs, str):
@@ -747,7 +736,8 @@ def __init__(self, file_paths, model_kwargs, fwp_chunk_shape,
self.t_enhance = np.product(self.t_enhancements)
self.output_features = model.output_features

self.fwp_slicer = ForwardPassSlicer(self.grid_shape, self.raw_tsteps,
self.fwp_slicer = ForwardPassSlicer(self.grid_shape,
self.raw_tsteps,
self.temporal_slice,
self.fwp_chunk_shape,
self.s_enhancements,
@@ -821,12 +811,8 @@ def preflight(self):
f'and n_total_chunks={self.chunks}. '
f'{self.chunks / self.nodes} chunks per node on average.')
logger.info(f'Using max_workers={self.max_workers}, '
f'extract_workers={self.extract_workers}, '
f'compute_workers={self.compute_workers}, '
f'pass_workers={self.pass_workers}, '
f'load_workers={self.load_workers}, '
f'output_workers={self.output_workers}, '
f'ti_workers={self.ti_workers}')
f'output_workers={self.output_workers}')

out = self.fwp_slicer.get_temporal_slices()
self.ti_slices, self.ti_pad_slices = out
@@ -857,7 +843,7 @@ def init_handler(self):
out = self.input_handler_class(self.file_paths[0], [],
target=self.target,
shape=self.grid_shape,
ti_workers=1)
worker_kwargs=dict(ti_workers=1))
self._init_handler = out
return self._init_handler

@@ -871,7 +857,7 @@ def lr_lat_lon(self):

@property
def handle_features(self):
"""Get available handle features"""
"""Get list of features available in the source data"""
if self._handle_features is None:
if self.single_ts_files:
self._handle_features = self.init_handler.handle_features
@@ -895,8 +881,14 @@ def get_full_domain(self, file_paths):
"""Get target and grid_shape for largest possible domain"""
return self.input_handler_class.get_full_domain(file_paths)

def get_lat_lon(self, file_paths, raster_index, invert_lat=False):
"""Get lat/lon grid for requested target and shape"""
return self.input_handler_class.get_lat_lon(file_paths, raster_index,
invert_lat=invert_lat)

def get_time_index(self, file_paths, max_workers=None, **kwargs):
"""Get time index for source data
"""Get time index for source data using DataHandler.get_time_index
method
Parameters
----------
@@ -1140,27 +1132,7 @@ def __init__(self, strategy, chunk_index=0, node_index=0):
elif strategy.output_type == 'h5':
self.output_handler_class = OutputHandlerH5

input_handler_kwargs = copy.deepcopy(strategy._input_handler_kwargs)
fwp_input_handler_kwargs = dict(
file_paths=self.file_paths,
features=self.features,
target=self.target,
shape=self.shape,
temporal_slice=self.temporal_pad_slice,
raster_file=self.raster_file,
cache_pattern=self.cache_pattern,
time_chunk_size=self.strategy.time_chunk_size,
overwrite_cache=self.strategy.overwrite_cache,
max_workers=self.max_workers,
extract_workers=strategy.extract_workers,
compute_workers=strategy.compute_workers,
load_workers=strategy.load_workers,
ti_workers=strategy.ti_workers,
handle_features=strategy.handle_features,
res_kwargs=strategy.res_kwargs,
single_ts_files=strategy.single_ts_files,
val_split=0.0)
input_handler_kwargs.update(fwp_input_handler_kwargs)
input_handler_kwargs = self.update_input_handler_kwargs(strategy)

logger.info(f'Getting input data for chunk_index={chunk_index}.')
self.data_handler = self.input_handler_class(**input_handler_kwargs)
@@ -1174,6 +1146,37 @@ def __init__(self, strategy, chunk_index=0, node_index=0):
self.exogenous_data, exo_s_en)
self.input_data, self.exogenous_data = out

def update_input_handler_kwargs(self, strategy):
"""Update the kwargs for the input handler for the current forward pass
chunk

Parameters
----------
strategy : ForwardPassStrategy
ForwardPassStrategy instance with information on data chunks to run
forward passes on.

Returns
-------
dict
Updated dictionary of input handler arguments to pass to the
data handler for the current forward pass chunk
"""
input_handler_kwargs = copy.deepcopy(strategy._input_handler_kwargs)
fwp_input_handler_kwargs = dict(
file_paths=self.file_paths,
features=self.features,
target=self.target,
shape=self.shape,
temporal_slice=self.temporal_pad_slice,
raster_file=self.raster_file,
cache_pattern=self.cache_pattern,
single_ts_files=self.single_ts_files,
handle_features=strategy.handle_features,
val_split=0.0)
input_handler_kwargs.update(fwp_input_handler_kwargs)
return input_handler_kwargs
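
A small sketch of the precedence this method establishes: the strategy's user-supplied kwargs are deep-copied, then per-chunk values win any key collisions via dict.update (made-up keys and values):

    import copy

    strategy_kwargs = {'temporal_slice': slice(None), 'val_split': 0.1}
    chunk_kwargs = {'temporal_slice': slice(0, 24), 'val_split': 0.0}

    merged = copy.deepcopy(strategy_kwargs)
    merged.update(chunk_kwargs)  # chunk-level values override strategy-level
    print(merged)  # {'temporal_slice': slice(0, 24, None), 'val_split': 0.0}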

@property
def single_ts_files(self):
"""Get whether input files are single time step or not"""