From 86df1b4385cce20fb406505cdafafab6c70261cf Mon Sep 17 00:00:00 2001
From: bnb32
Date: Wed, 11 Jan 2023 06:57:49 -0700
Subject: [PATCH 1/4] Refactored forward pass strategy and regridding to
 remove duplicate code. Changed inheritance ordering for wind models to
 remove tf_generate_wind/generate_wind

---
 sup3r/models/abstract.py                 |  11 +-
 sup3r/models/wind.py                     |  68 +-------------
 sup3r/models/wind_conditional_moments.py |  68 +-------------
 sup3r/pipeline/forward_pass.py           | 115 +++-------------------
 sup3r/utilities/regridder.py             |  58 +++--------
 sup3r/utilities/utilities.py             | 123 +++++++++++++++++++++++
 tests/test_utilities.py                  |   2 +-
 7 files changed, 163 insertions(+), 282 deletions(-)

diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py
index a2546b7ff..e54d5f5a5 100644
--- a/sup3r/models/abstract.py
+++ b/sup3r/models/abstract.py
@@ -909,9 +908,8 @@ class AbstractWindInterface(ABC):
     Abstract class to define the required training interface
     for Sup3r wind model subclasses
     """
-
-    @staticmethod
-    def set_model_params_wind(**kwargs):
+    # pylint: disable=E0211
+    def set_model_params(**kwargs):
         """Set parameters used for training the model
 
         Parameters
@@ -1002,8 +1001,8 @@ def _reshape_norm_topo(self, hi_res, hi_res_topo, norm_in=True):
 
         return hi_res_topo
 
-    def generate_wind(self, low_res, norm_in=True, un_norm_out=True,
-                      exogenous_data=None):
+    def generate(self, low_res, norm_in=True, un_norm_out=True,
+                 exogenous_data=None):
         """Use the generator model to generate high res data from low res
         input. This is the public generate function.
 
@@ -1076,7 +1075,7 @@ def generate_wind(self, low_res, norm_in=True, un_norm_out=True,
         return hi_res
 
     @tf.function
-    def _tf_generate_wind(self, low_res, hi_res_topo):
+    def _tf_generate(self, low_res, hi_res_topo):
        """Use the generator model to generate high res data from low res
        input

        Parameters
diff --git a/sup3r/models/wind.py b/sup3r/models/wind.py
index e5ac8b6d7..5a545025d 100644
--- a/sup3r/models/wind.py
+++ b/sup3r/models/wind.py
@@ -12,7 +12,7 @@
 logger = logging.getLogger(__name__)
 
-class WindGan(Sup3rGan, AbstractWindInterface):
+class WindGan(AbstractWindInterface, Sup3rGan):
     """Wind super resolution GAN with handling of low and
     high res topography inputs.
 
@@ -69,70 +69,8 @@ def set_model_params(self, **kwargs):
         Keyword arguments including 'training_features', 'output_features',
         'smoothed_features', 's_enhance', 't_enhance', 'smoothing'
         """
-        kwargs = self.set_model_params_wind(**kwargs)
-        super().set_model_params(**kwargs)
-
-    def generate(self, low_res, norm_in=True, un_norm_out=True,
-                 exogenous_data=None):
-        """Use the generator model to generate high res data from low res
-        input. This is the public generate function.
-
-        Parameters
-        ----------
-        low_res : np.ndarray
-            Low-resolution input data, usually a 4D or 5D array of shape:
-            (n_obs, spatial_1, spatial_2, n_features)
-            (n_obs, spatial_1, spatial_2, n_temporal, n_features)
-        norm_in : bool
-            Flag to normalize low_res input data if the self._means,
-            self._stdevs attributes are available. The generator should always
-            received normalized data with mean=0 stdev=1. This also normalizes
-            hi_res_topo.
-        un_norm_out : bool
-            Flag to un-normalize synthetically generated output data to
-            physical units
-        exogenous_data : ndarray | list | None
-            Exogenous data for topography inputs. The first entry in this list
-            (or only entry) is a low-resolution topography array that can be
-            concatenated to the low_res input array.
The second entry is - high-resolution topography (either 2D or 4D/5D depending on if - spatial or spatiotemporal super res). - - Returns - ------- - hi_res : ndarray - Synthetically generated high-resolution data, usually a 4D or 5D - array with shape: - (n_obs, spatial_1, spatial_2, n_features) - (n_obs, spatial_1, spatial_2, n_temporal, n_features) - """ - return self.generate_wind(low_res, norm_in, - un_norm_out, exogenous_data) - - @tf.function - def _tf_generate(self, low_res, hi_res_topo): - """Use the generator model to generate high res data from los res input - - Parameters - ---------- - low_res : np.ndarray - Real low-resolution data. The generator should always - received normalized data with mean=0 stdev=1. - hi_res_topo : np.ndarray - This should be a 4D array for spatial enhancement model or 5D array - for a spatiotemporal enhancement model (obs, spatial_1, spatial_2, - (temporal), features) corresponding to the high-resolution - spatial_1 and spatial_2. This data will be input to the custom - phygnn Sup3rAdder or Sup3rConcat layer if found in the generative - network. This differs from the exogenous_data input in that - exogenous_data always matches the low-res input. - - Returns - ------- - hi_res : tf.Tensor - Synthetically generated high-resolution data - """ - return self._tf_generate_wind(low_res, hi_res_topo) + AbstractWindInterface.set_model_params(**kwargs) + Sup3rGan.set_model_params(self, **kwargs) @tf.function() def get_single_grad(self, low_res, hi_res_true, training_weights, diff --git a/sup3r/models/wind_conditional_moments.py b/sup3r/models/wind_conditional_moments.py index 5d8975c27..3d770a6c2 100644 --- a/sup3r/models/wind_conditional_moments.py +++ b/sup3r/models/wind_conditional_moments.py @@ -11,7 +11,7 @@ logger = logging.getLogger(__name__) -class WindCondMom(Sup3rCondMom, AbstractWindInterface): +class WindCondMom(AbstractWindInterface, Sup3rCondMom): """Wind conditional moment estimator with handling of low and high res topography inputs. @@ -33,70 +33,8 @@ def set_model_params(self, **kwargs): Keyword arguments including 'training_features', 'output_features', 'smoothed_features', 's_enhance', 't_enhance', 'smoothing' """ - kwargs = self.set_model_params_wind(**kwargs) - super().set_model_params(**kwargs) - - def generate(self, low_res, norm_in=True, un_norm_out=True, - exogenous_data=None): - """Use the generator model to generate high res data from low res - input. This is the public generate function. - - Parameters - ---------- - low_res : np.ndarray - Low-resolution input data, usually a 4D or 5D array of shape: - (n_obs, spatial_1, spatial_2, n_features) - (n_obs, spatial_1, spatial_2, n_temporal, n_features) - norm_in : bool - Flag to normalize low_res input data if the self._means, - self._stdevs attributes are available. The generator should always - received normalized data with mean=0 stdev=1. This also normalizes - hi_res_topo. - un_norm_out : bool - Flag to un-normalize synthetically generated output data to physical - units - exogenous_data : ndarray | list | None - Exogenous data for topography inputs. The first entry in this list - (or only entry) is a low-resolution topography array that can be - concatenated to the low_res input array. The second entry is - high-resolution topography (either 2D or 4D/5D depending on if - spatial or spatiotemporal super res). 
- - Returns - ------- - hi_res : ndarray - Synthetically generated high-resolution data, usually a 4D or 5D - array with shape: - (n_obs, spatial_1, spatial_2, n_features) - (n_obs, spatial_1, spatial_2, n_temporal, n_features) - """ - return self.generate_wind(low_res, norm_in, - un_norm_out, exogenous_data) - - @tf.function - def _tf_generate(self, low_res, hi_res_topo): - """Use the generator model to generate high res data from los res input - - Parameters - ---------- - low_res : np.ndarray - Real low-resolution data. The generator should always - received normalized data with mean=0 stdev=1. - hi_res_topo : np.ndarray - This should be a 4D array for spatial enhancement model or 5D array - for a spatiotemporal enhancement model (obs, spatial_1, spatial_2, - (temporal), features) corresponding to the high-resolution - spatial_1 and spatial_2. This data will be input to the custom - phygnn Sup3rAdder or Sup3rConcat layer if found in the generative - network. This differs from the exogenous_data input in that - exogenous_data always matches the low-res input. - - Returns - ------- - hi_res : tf.Tensor - Synthetically generated high-resolution data - """ - return self._tf_generate_wind(low_res, hi_res_topo) + AbstractWindInterface.set_model_params(**kwargs) + Sup3rCondMom.set_model_params(self, **kwargs) @tf.function() def get_single_grad(self, low_res, hi_res_true, training_weights, diff --git a/sup3r/pipeline/forward_pass.py b/sup3r/pipeline/forward_pass.py index 10192700b..622dd2bc4 100644 --- a/sup3r/pipeline/forward_pass.py +++ b/sup3r/pipeline/forward_pass.py @@ -24,7 +24,8 @@ from sup3r.postprocessing.file_handling import (OutputHandlerH5, OutputHandlerNC, OutputHandler) -from sup3r.utilities.utilities import (get_chunk_slices, +from sup3r.utilities.utilities import (DistributedProcess, + get_chunk_slices, get_source_type, get_input_handler_class) from sup3r.utilities import ModuleName @@ -549,7 +550,7 @@ def get_cropped_slices(unpadded_slices, padded_slices, enhancement): return cropped_slices -class ForwardPassStrategy(InputMixIn): +class ForwardPassStrategy(InputMixIn, DistributedProcess): """Class to prepare data for forward passes through generator. A full file list of contiguous times is provided. 
The corresponding data is @@ -678,8 +679,10 @@ def __init__(self, file_paths, model_kwargs, fwp_chunk_shape, raster_index = self._input_handler_kwargs.get('raster_index', None) temporal_slice = self._input_handler_kwargs.get('temporal_slice', slice(None, None, 1)) - InputMixIn.__init__(self, target=target, shape=grid_shape, - raster_file=raster_file, raster_index=raster_index, + InputMixIn.__init__(self, target=target, + shape=grid_shape, + raster_file=raster_file, + raster_index=raster_index, temporal_slice=temporal_slice) self.file_paths = file_paths @@ -694,13 +697,9 @@ def __init__(self, file_paths, model_kwargs, fwp_chunk_shape, self.incremental = incremental self.bias_correct_method = bias_correct_method self.bias_correct_kwargs = bias_correct_kwargs or {} - self._failed_chunks = False self._input_handler_class = None self._input_handler_name = input_handler - self._max_nodes = max_nodes - self._out_files = None self._file_ids = None - self._node_chunks = None self._hr_lat_lon = None self._lr_lat_lon = None self._init_handler = None @@ -745,61 +744,11 @@ def __init__(self, file_paths, model_kwargs, fwp_chunk_shape, self.spatial_pad, self.temporal_pad) - self.preflight() - - @property - def failed_chunks(self): - """Check whether any forward passes have generated constant output.""" - return self._failed_chunks - - @failed_chunks.setter - def failed_chunks(self, failed): - """Set failed_chunks value. Will be set to True by a ForwardPass object - if there is a failed chunk""" - self._failed_chunks = failed - - def node_finished(self, node_index): - """Check if all out files for a given node have been saved - - Parameters - ---------- - node_index : int - Index of node to check for completed forward passes - - Returns - ------- - bool - Whether all forward passes for the given node have finished - """ - node_files = [self.out_files[i] for i in self.node_chunks[node_index]] - return all(os.path.exists(out_file) for out_file in node_files) - - @property - def all_finished(self): - """Check if all out files have been saved""" - return all(os.path.exists(out_file) for out_file in self.out_files) - - def chunk_finished(self, chunk_index): - """Check if forward pass for given chunk_index has already been run. + DistributedProcess.__init__(self, max_nodes=max_nodes, + max_chunks=self.fwp_slicer.n_chunks, + incremental=self.incremental) - Parameters - ---------- - chunk_index : int - Index of the chunk to check for a finished forward pass. Considered - finished if there is already an output file and incremental is - False. 
-
-        Returns
-        -------
-        bool
-            Whether the forward pass for the given chunk has finished
-        """
-        out_file = self.out_files[chunk_index]
-        if os.path.exists(out_file) and self.incremental:
-            logger.info('Not running chunk index {}, output file '
-                        'exists: {}'.format(chunk_index, out_file))
-            return True
-        return False
+        DistributedProcess.__init__(self, max_nodes=max_nodes,
+                                    max_chunks=self.fwp_slicer.n_chunks,
+                                    incremental=self.incremental)
-
+        self.preflight()
 
     def preflight(self):
         """Preflight path name formatting and sanity checks"""
@@ -972,47 +921,14 @@ def input_handler_class(self):
             self.file_paths, self._input_handler_name)
         return self._input_handler_class
 
-    def __len__(self):
-        """Get the number of nodes that this strategy is distributing to"""
-        return self.fwp_slicer.n_chunks
-
     @property
     def max_nodes(self):
         """Get the maximum number of nodes that this strategy should
         distribute work to, equal to either the specified max number of nodes
         or total number of temporal chunks"""
-        nodes = (self._max_nodes if self._max_nodes is not None
-                 else self.fwp_slicer.n_temporal_chunks)
-        nodes = np.min((nodes, self.chunks))
-        return nodes
-
-    @property
-    def nodes(self):
-        """Get the number of nodes that this strategy should distribute work
-        to, equal to either the specified max number of nodes or total number
-        of temporal chunks"""
-        return len(self.node_chunks)
-
-    @property
-    def chunks(self):
-        """Get the number of spatiotemporal chunks going through forward pass,
-        calculated as the source time index divided by the temporal part of the
-        fwp_chunk_shape times the number of spatial chunks"""
-        return self.fwp_slicer.n_chunks
-
-    @property
-    def node_chunks(self):
-        """Get chunked list of spatiotemporal chunk indices that will be
-        used to distribute sets of spatiotemporal chunks across nodes. For
-        example, if we want to distribute 10 spatiotemporal chunks across 2
-        nodes this will return [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]. So the first
-        node will be used to run forward passes on the first 5 spatiotemporal
-        chunks and the second node will be used for the last 5"""
-        if self._node_chunks is None:
-            n_chunks = np.min((self.max_nodes, self.chunks))
-            self._node_chunks = np.array_split(np.arange(self.chunks),
-                                               n_chunks)
-        return self._node_chunks
+        self._max_nodes = (self._max_nodes if self._max_nodes is not None
+                           else self.fwp_slicer.n_temporal_chunks)
+        return self._max_nodes
 
     @staticmethod
     def get_output_file_names(out_files, file_ids):
@@ -1786,8 +1702,7 @@ def run(cls, strategy, node_index):
             Index of node on which the forward passes for spatiotemporal
             chunks will be run.
         """
-        if strategy.node_finished(node_index) and strategy.incremental:
-            logger.info(f'All jobs for node_index={node_index} already done.')
+        if strategy.node_finished(node_index):
             return
 
         if strategy.pass_workers == 1:
diff --git a/sup3r/utilities/regridder.py b/sup3r/utilities/regridder.py
index 7aec26cc4..e115993bf 100644
--- a/sup3r/utilities/regridder.py
+++ b/sup3r/utilities/regridder.py
@@ -16,6 +16,7 @@
 from sup3r.postprocessing.file_handling import OutputMixIn, RexOutputs
 from sup3r.postprocessing.collection import Collector
 from sup3r.utilities import ModuleName
+from sup3r.utilities.utilities import DistributedProcess
 from sup3r.utilities.cli import BaseCLI
 
 logger = logging.getLogger(__name__)
@@ -426,7 +427,7 @@ def regrid_coordinates(cls, index_chunk, distance_chunk, height,
         return ws, wd
 
-class RegridOutput(OutputMixIn):
+class RegridOutput(OutputMixIn, DistributedProcess):
     """Output regridded data as it is interpolated.
Takes source data from windspeed and winddirection h5 files and uses this data to interpolate onto a new target grid. The interpolated data is then written to new files, with @@ -474,10 +475,6 @@ def __init__(self, source_files, out_pattern, target_meta, heights, self.query_workers = worker_kwargs.get('query_workers', None) self.source_files = (source_files if isinstance(source_files, list) else glob(source_files)) - self._n_chunks = n_chunks - self._n_nodes = max_nodes - self._node_chunks = None - self.out_pattern = out_pattern self.target_meta_path = target_meta self.target_meta = pd.read_csv(self.target_meta_path) self.target_meta['gid'] = np.arange(len(self.target_meta)) @@ -485,6 +482,7 @@ def __init__(self, source_files, out_pattern, target_meta, heights, ['latitude', 'longitude'], ascending=[False, True]) self.heights = heights self.incremental = incremental + self.out_pattern = out_pattern os.makedirs(os.path.dirname(self.out_pattern), exist_ok=True) with MultiFileResource(source_files) as res: @@ -499,6 +497,10 @@ def __init__(self, source_files, out_pattern, target_meta, heights, cache_pattern=cache_pattern, n_chunks=n_chunks, max_workers=self.query_workers) + DistributedProcess.__init__(self, max_nodes=max_nodes, + n_chunks=n_chunks, + max_chunks=len(self.regridder.indices), + incremental=incremental) logger.info('Initializing RegridOutput with ' f'source_files={self.source_files}, ' @@ -509,25 +511,6 @@ def __init__(self, source_files, out_pattern, target_meta, heights, f'n_chunks={n_chunks}.') logger.info(f'Max memory usage: {self.max_memory:.3f} GB.') - @property - def chunks(self): - """Get the number of chunks to split the target meta into """ - return min(self._n_chunks, len(self.regridder.indices)) - - @property - def nodes(self): - """Get the max number of nodes to distribute chunks across""" - return min(self._n_nodes, self.chunks) - - @property - def node_chunks(self): - """Get the chunk indices for different nodes""" - if self._node_chunks is None: - n_chunks = min(self.nodes, self.chunks) - self._node_chunks = np.array_split(np.arange(self.chunks), - n_chunks) - return self._node_chunks - @property def spatial_slices(self): """Get the list of slices which select index and distance chunks""" @@ -565,7 +548,7 @@ def meta_chunks(self): return [self.regridder.target_meta[s] for s in self.spatial_slices] @property - def output_files(self): + def out_files(self): """Get list of output files for each spatial chunk""" return [self.out_pattern.format(file_id=str(i).zfill(6)) for i in range(self.chunks)] @@ -626,6 +609,9 @@ def run(self, node_index): Node index to run. e.g. if node_index=0 then only the chunks for node_chunks[0] will be run. """ + if self.node_finished(node_index): + return + if self.regrid_workers == 1: self._run_serial(source_files=self.source_files, node_index=node_index) @@ -634,21 +620,6 @@ def run(self, node_index): node_index=node_index, max_workers=self.regrid_workers) - def collect(self, out_pattern, max_workers=None): - """Collect output chunks - - Parameters - ---------- - out_pattern : str - Output pattern for collected output files. Needs to include - {feature} key. e.g. ./collected_{feature}.h5 - """ - for feature in self.output_features: - out_file = out_pattern.format(feature=feature) - Collector.collect(self.output_files, out_file, [feature], - target_final_meta_file=self.target_meta_path, - max_workers=max_workers) - def _run_serial(self, source_files, node_index): """Regrid data and write to output file, in serial. 
@@ -732,12 +703,9 @@ def write_coordinates(self, source_files, chunk_index): index_chunk = self.index_chunks[chunk_index] distance_chunk = self.distance_chunks[chunk_index] s_slice = self.spatial_slices[chunk_index] - out_file = self.output_files[chunk_index] + out_file = self.out_files[chunk_index] meta = self.meta_chunks[chunk_index] - if os.path.exists(out_file) and not self.incremental: - msg = (f'{out_file} already exists and incremental=True. Skipping' - ' this chunk.') - logger.info(msg) + if self.chunk_finished(chunk_index): return tmp_file = out_file.replace('.h5', '.h5.tmp') diff --git a/sup3r/utilities/utilities.py b/sup3r/utilities/utilities.py index ab237a4ef..fefa227f8 100644 --- a/sup3r/utilities/utilities.py +++ b/sup3r/utilities/utilities.py @@ -26,6 +26,129 @@ logger = logging.getLogger(__name__) +class DistributedProcess: + """High-level class with commonly used functionality for processes + distributed across multiple nodes""" + + def __init__(self, max_nodes=1, n_chunks=None, max_chunks=None, + incremental=False): + """ + Parameters + ---------- + max_nodes : int, optional + Max number of nodes to distribute processes across + n_chunks : int, optional + Number of chunks to split all processes into. These process + chunks will be distributed across nodes. + max_chunks : int, optional + Max number of chunks processes can be split into. + incremental : bool + Whether to skip previously run process chunks or to overwrite. + """ + msg = ('For a distributed process either max_chunks or ' + 'max_chunks + n_chunks must be specified. Received ' + f'max_chunks={max_chunks}, n_chunks={n_chunks}.') + assert max_chunks is not None, msg + self._node_chunks = None + self._n_chunks = n_chunks + self._max_nodes = max_nodes + self._max_chunks = max_chunks + self._out_files = None + self._failed_chunks = False + self.incremental = incremental + + def __len__(self): + """Get total number of process chunks""" + return self.chunks + + def node_finished(self, node_index): + """Check if all out files for a given node have been saved + + Parameters + ---------- + node_index : int + Index of node to check for completed processes + + Returns + ------- + bool + Whether all processes for the given node have finished + """ + return all([self.chunk_finished(i) + for i in self.node_chunks[node_index]]) + + def chunk_finished(self, chunk_index): + """Check if process for given chunk_index has already been run. + + Parameters + ---------- + chunk_index : int + Index of the process chunk to check for completion. Considered + finished if there is already an output file and incremental is + False. 
+ + Returns + ------- + bool + Whether the process for the given chunk has finished + """ + out_file = self.out_files[chunk_index] + if os.path.exists(out_file) and self.incremental: + logger.info('Not running chunk index {}, output file ' + 'exists: {}'.format(chunk_index, out_file)) + return True + return False + + @property + def all_finished(self): + """Check if all out files have been saved""" + return all([self.node_finished(i) for i in range(self.nodes)]) + + @property + def out_files(self): + """Get list of out files to write process output to""" + return self._out_files + + @property + def max_nodes(self): + """Get uncapped max number of nodes to distribute processes across""" + return self._max_nodes + + @property + def chunks(self): + """Get the number of processes chunks for this distributed routine.""" + if self._n_chunks is None: + return self._max_chunks + else: + return min(self._n_chunks, self._max_chunks) + + @property + def nodes(self): + """Get the max number of nodes to distribute chunks across, limited by + the number of process chunks""" + return len(self.node_chunks) + + @property + def node_chunks(self): + """Get the chunk indices for different nodes""" + if self._node_chunks is None: + n_chunks = min(self.max_nodes, self.chunks) + self._node_chunks = np.array_split(np.arange(self.chunks), + n_chunks) + return self._node_chunks + + @property + def failed_chunks(self): + """Check whether any processes have failed.""" + return self._failed_chunks + + @failed_chunks.setter + def failed_chunks(self, failed): + """Set failed_chunks value. Should be set to True if there is a failed + chunk""" + self._failed_chunks = failed + + def correct_path(path): """If running on windows we need to replace backslashes with double backslashes so paths can be parsed correctly with safe_open_json""" diff --git a/tests/test_utilities.py b/tests/test_utilities.py index e5c583bb0..052daae38 100644 --- a/tests/test_utilities.py +++ b/tests/test_utilities.py @@ -53,7 +53,7 @@ def test_regridding(): for node_index in range(regrid_output.nodes): regrid_output.run(node_index=node_index) - Collector.collect(regrid_output.output_files, + Collector.collect(regrid_output.out_files, collect_file, regrid_output.output_features, target_final_meta_file=meta_path, From 86f1cfd5e8ce0d5f3f04e7f6a84a998dd06fed26 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Wed, 11 Jan 2023 07:14:57 -0700 Subject: [PATCH 2/4] linting and _wind method removal --- sup3r/models/abstract.py | 6 ++-- sup3r/models/wind.py | 42 +----------------------- sup3r/models/wind_conditional_moments.py | 42 +----------------------- sup3r/utilities/regridder.py | 1 - sup3r/utilities/utilities.py | 8 ++--- 5 files changed, 9 insertions(+), 90 deletions(-) diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index e54d5f5a5..76a80912e 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -910,7 +910,7 @@ class AbstractWindInterface(ABC): for Sup3r wind model subclasses """ # pylint: disable=E0211 - def set_model_params(**kwargs): + def set_model_params(self, **kwargs): """Set parameters used for training the model Parameters @@ -1114,8 +1114,8 @@ def _tf_generate(self, low_res, hi_res_topo): return hi_res @tf.function() - def get_single_grad_wind(self, low_res, hi_res_true, training_weights, - device_name=None, **calc_loss_kwargs): + def get_single_grad(self, low_res, hi_res_true, training_weights, + device_name=None, **calc_loss_kwargs): """Run gradient descent for one mini-batch of (low_res, hi_res_true), do not 
update weights, just return gradient details. diff --git a/sup3r/models/wind.py b/sup3r/models/wind.py index 5a545025d..4fe7ce5af 100644 --- a/sup3r/models/wind.py +++ b/sup3r/models/wind.py @@ -69,49 +69,9 @@ def set_model_params(self, **kwargs): Keyword arguments including 'training_features', 'output_features', 'smoothed_features', 's_enhance', 't_enhance', 'smoothing' """ - AbstractWindInterface.set_model_params(**kwargs) + AbstractWindInterface.set_model_params(self, **kwargs) Sup3rGan.set_model_params(self, **kwargs) - @tf.function() - def get_single_grad(self, low_res, hi_res_true, training_weights, - device_name=None, **calc_loss_kwargs): - """Run gradient descent for one mini-batch of (low_res, hi_res_true), - do not update weights, just return gradient details. - - Parameters - ---------- - low_res : np.ndarray - Real low-resolution data in a 4D or 5D array: - (n_observations, spatial_1, spatial_2, features) - (n_observations, spatial_1, spatial_2, temporal, features) - hi_res_true : np.ndarray - Real high-resolution data in a 4D or 5D array: - (n_observations, spatial_1, spatial_2, features) - (n_observations, spatial_1, spatial_2, temporal, features) - training_weights : list - A list of layer weights that are to-be-trained based on the - current loss weight values. - device_name : None | str - Optional tensorflow device name for GPU placement. Note that if a - GPU is available, variables will be placed on that GPU even if - device_name=None. - calc_loss_kwargs : dict - Kwargs to pass to the self.calc_loss() method - - Returns - ------- - grad : list - a list or nested structure of Tensors (or IndexedSlices, or None, - or CompositeTensor) representing the gradients for the - training_weights - loss_details : dict - Namespace of the breakdown of loss components - """ - return self.get_single_grad_wind(low_res, hi_res_true, - training_weights, - device_name=device_name, - **calc_loss_kwargs) - @tf.function def calc_loss(self, hi_res_true, hi_res_gen, **kwargs): """Calculate the GAN loss function using generated and true high diff --git a/sup3r/models/wind_conditional_moments.py b/sup3r/models/wind_conditional_moments.py index 3d770a6c2..aed7b82a7 100644 --- a/sup3r/models/wind_conditional_moments.py +++ b/sup3r/models/wind_conditional_moments.py @@ -33,49 +33,9 @@ def set_model_params(self, **kwargs): Keyword arguments including 'training_features', 'output_features', 'smoothed_features', 's_enhance', 't_enhance', 'smoothing' """ - AbstractWindInterface.set_model_params(**kwargs) + AbstractWindInterface.set_model_params(self, **kwargs) Sup3rCondMom.set_model_params(self, **kwargs) - @tf.function() - def get_single_grad(self, low_res, hi_res_true, training_weights, - device_name=None, **calc_loss_kwargs): - """Run gradient descent for one mini-batch of (low_res, hi_res_true), - do not update weights, just return gradient details. - - Parameters - ---------- - low_res : np.ndarray - Real low-resolution data in a 4D or 5D array: - (n_observations, spatial_1, spatial_2, features) - (n_observations, spatial_1, spatial_2, temporal, features) - hi_res_true : np.ndarray - Real high-resolution data in a 4D or 5D array: - (n_observations, spatial_1, spatial_2, features) - (n_observations, spatial_1, spatial_2, temporal, features) - training_weights : list - A list of layer weights that are to-be-trained based on the - current loss weight values. - device_name : None | str - Optional tensorflow device name for GPU placement. 
Note that if a - GPU is available, variables will be placed on that GPU even if - device_name=None. - calc_loss_kwargs : dict - Kwargs to pass to the self.calc_loss() method - - Returns - ------- - grad : list - a list or nested structure of Tensors (or IndexedSlices, or None, - or CompositeTensor) representing the gradients for the - training_weights - loss_details : dict - Namespace of the breakdown of loss components - """ - return self.get_single_grad_wind(low_res, hi_res_true, - training_weights, - device_name=device_name, - **calc_loss_kwargs) - @tf.function def calc_loss(self, hi_res_true, hi_res_gen, mask, **kwargs): """Calculate the loss function using generated and true high diff --git a/sup3r/utilities/regridder.py b/sup3r/utilities/regridder.py index e115993bf..f41cc2b9e 100644 --- a/sup3r/utilities/regridder.py +++ b/sup3r/utilities/regridder.py @@ -14,7 +14,6 @@ from rex import MultiFileResource from sup3r.postprocessing.file_handling import OutputMixIn, RexOutputs -from sup3r.postprocessing.collection import Collector from sup3r.utilities import ModuleName from sup3r.utilities.utilities import DistributedProcess from sup3r.utilities.cli import BaseCLI diff --git a/sup3r/utilities/utilities.py b/sup3r/utilities/utilities.py index fefa227f8..c8c618201 100644 --- a/sup3r/utilities/utilities.py +++ b/sup3r/utilities/utilities.py @@ -74,8 +74,8 @@ def node_finished(self, node_index): bool Whether all processes for the given node have finished """ - return all([self.chunk_finished(i) - for i in self.node_chunks[node_index]]) + return all(self.chunk_finished(i) + for i in self.node_chunks[node_index]) def chunk_finished(self, chunk_index): """Check if process for given chunk_index has already been run. @@ -102,10 +102,10 @@ def chunk_finished(self, chunk_index): @property def all_finished(self): """Check if all out files have been saved""" - return all([self.node_finished(i) for i in range(self.nodes)]) + return all(self.node_finished(i) for i in range(self.nodes)) @property - def out_files(self): + def out_files(self) -> list: """Get list of out files to write process output to""" return self._out_files From 071df010e433718f283f7582dba54ebd47d5317d Mon Sep 17 00:00:00 2001 From: bnb32 Date: Wed, 11 Jan 2023 07:35:04 -0700 Subject: [PATCH 3/4] linting --- sup3r/utilities/utilities.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sup3r/utilities/utilities.py b/sup3r/utilities/utilities.py index c8c618201..5c8da69d9 100644 --- a/sup3r/utilities/utilities.py +++ b/sup3r/utilities/utilities.py @@ -77,6 +77,7 @@ def node_finished(self, node_index): return all(self.chunk_finished(i) for i in self.node_chunks[node_index]) + # pylint: disable=E1136 def chunk_finished(self, chunk_index): """Check if process for given chunk_index has already been run. 
@@ -105,7 +106,7 @@ def all_finished(self): return all(self.node_finished(i) for i in range(self.nodes)) @property - def out_files(self) -> list: + def out_files(self): """Get list of out files to write process output to""" return self._out_files From 6bbe02e02254e613c28b6a921a0d13367ddb068b Mon Sep 17 00:00:00 2001 From: bnb32 Date: Wed, 11 Jan 2023 10:42:04 -0700 Subject: [PATCH 4/4] pr review changes --- sup3r/models/abstract.py | 3 +- sup3r/models/wind.py | 2 +- sup3r/models/wind_conditional_moments.py | 2 +- sup3r/pipeline/forward_pass.py | 4 +- sup3r/utilities/execution.py | 135 +++++++++++++++++++++++ sup3r/utilities/regridder.py | 2 +- sup3r/utilities/utilities.py | 127 +-------------------- 7 files changed, 143 insertions(+), 132 deletions(-) create mode 100644 sup3r/utilities/execution.py diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index 76a80912e..0cf8cb5dd 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -910,7 +910,8 @@ class AbstractWindInterface(ABC): for Sup3r wind model subclasses """ # pylint: disable=E0211 - def set_model_params(self, **kwargs): + @staticmethod + def set_model_params(**kwargs): """Set parameters used for training the model Parameters diff --git a/sup3r/models/wind.py b/sup3r/models/wind.py index 4fe7ce5af..46631656b 100644 --- a/sup3r/models/wind.py +++ b/sup3r/models/wind.py @@ -69,7 +69,7 @@ def set_model_params(self, **kwargs): Keyword arguments including 'training_features', 'output_features', 'smoothed_features', 's_enhance', 't_enhance', 'smoothing' """ - AbstractWindInterface.set_model_params(self, **kwargs) + AbstractWindInterface.set_model_params(**kwargs) Sup3rGan.set_model_params(self, **kwargs) @tf.function diff --git a/sup3r/models/wind_conditional_moments.py b/sup3r/models/wind_conditional_moments.py index aed7b82a7..cf3a3dc2c 100644 --- a/sup3r/models/wind_conditional_moments.py +++ b/sup3r/models/wind_conditional_moments.py @@ -33,7 +33,7 @@ def set_model_params(self, **kwargs): Keyword arguments including 'training_features', 'output_features', 'smoothed_features', 's_enhance', 't_enhance', 'smoothing' """ - AbstractWindInterface.set_model_params(self, **kwargs) + AbstractWindInterface.set_model_params(**kwargs) Sup3rCondMom.set_model_params(self, **kwargs) @tf.function diff --git a/sup3r/pipeline/forward_pass.py b/sup3r/pipeline/forward_pass.py index 622dd2bc4..d593337f7 100644 --- a/sup3r/pipeline/forward_pass.py +++ b/sup3r/pipeline/forward_pass.py @@ -24,10 +24,10 @@ from sup3r.postprocessing.file_handling import (OutputHandlerH5, OutputHandlerNC, OutputHandler) -from sup3r.utilities.utilities import (DistributedProcess, - get_chunk_slices, +from sup3r.utilities.utilities import (get_chunk_slices, get_source_type, get_input_handler_class) +from sup3r.utilities.execution import DistributedProcess from sup3r.utilities import ModuleName from sup3r.utilities.cli import BaseCLI diff --git a/sup3r/utilities/execution.py b/sup3r/utilities/execution.py new file mode 100644 index 000000000..96e3929b6 --- /dev/null +++ b/sup3r/utilities/execution.py @@ -0,0 +1,135 @@ +# -*- coding: utf-8 -*- +"""Execution methods for running some cli routines + +@author: bbenton +""" + +import numpy as np +import logging +import os + +logger = logging.getLogger(__name__) + + +class DistributedProcess: + """High-level class with commonly used functionality for processes + distributed across multiple nodes""" + + def __init__(self, max_nodes=1, n_chunks=None, max_chunks=None, + incremental=False): + """ + Parameters + 
----------
+        max_nodes : int, optional
+            Max number of nodes to distribute processes across
+        n_chunks : int, optional
+            Number of chunks to split all processes into. These process
+            chunks will be distributed across nodes.
+        max_chunks : int, optional
+            Max number of chunks processes can be split into.
+        incremental : bool
+            Whether to skip previously run process chunks or to overwrite.
+        """
+        msg = ('For a distributed process either max_chunks alone or both '
+               'max_chunks and n_chunks must be specified. Received '
+               f'max_chunks={max_chunks}, n_chunks={n_chunks}.')
+        assert max_chunks is not None, msg
+        self._node_chunks = None
+        self._n_chunks = n_chunks
+        self._max_nodes = max_nodes
+        self._max_chunks = max_chunks
+        self._out_files = None
+        self._failed_chunks = False
+        self.incremental = incremental
+
+    def __len__(self):
+        """Get total number of process chunks"""
+        return self.chunks
+
+    def node_finished(self, node_index):
+        """Check if all out files for a given node have been saved
+
+        Parameters
+        ----------
+        node_index : int
+            Index of node to check for completed processes
+
+        Returns
+        -------
+        bool
+            Whether all processes for the given node have finished
+        """
+        return all(self.chunk_finished(i)
+                   for i in self.node_chunks[node_index])
+
+    # pylint: disable=E1136
+    def chunk_finished(self, chunk_index):
+        """Check if process for given chunk_index has already been run.
+
+        Parameters
+        ----------
+        chunk_index : int
+            Index of the process chunk to check for completion. Considered
+            finished if there is already an output file and incremental is
+            True.
+
+        Returns
+        -------
+        bool
+            Whether the process for the given chunk has finished
+        """
+        out_file = self.out_files[chunk_index]
+        if os.path.exists(out_file) and self.incremental:
+            logger.info('Not running chunk index {}, output file '
+                        'exists: {}'.format(chunk_index, out_file))
+            return True
+        return False
+
+    @property
+    def all_finished(self):
+        """Check if all out files have been saved"""
+        return all(self.node_finished(i) for i in range(self.nodes))
+
+    @property
+    def out_files(self):
+        """Get list of out files to write process output to"""
+        return self._out_files
+
+    @property
+    def max_nodes(self):
+        """Get uncapped max number of nodes to distribute processes across"""
+        return self._max_nodes
+
+    @property
+    def chunks(self):
+        """Get the number of process chunks for this distributed routine."""
+        if self._n_chunks is None:
+            return self._max_chunks
+        else:
+            return min(self._n_chunks, self._max_chunks)
+
+    @property
+    def nodes(self):
+        """Get the max number of nodes to distribute chunks across, limited by
+        the number of process chunks"""
+        return len(self.node_chunks)
+
+    @property
+    def node_chunks(self):
+        """Get the chunk indices for different nodes"""
+        if self._node_chunks is None:
+            n_chunks = min(self.max_nodes, self.chunks)
+            self._node_chunks = np.array_split(np.arange(self.chunks),
+                                               n_chunks)
+        return self._node_chunks
+
+    @property
+    def failed_chunks(self):
+        """Check whether any processes have failed."""
+        return self._failed_chunks
+
+    @failed_chunks.setter
+    def failed_chunks(self, failed):
+        """Set failed_chunks value.
Should be set to True if there is a failed + chunk""" + self._failed_chunks = failed diff --git a/sup3r/utilities/regridder.py b/sup3r/utilities/regridder.py index f41cc2b9e..5accf39f0 100644 --- a/sup3r/utilities/regridder.py +++ b/sup3r/utilities/regridder.py @@ -15,7 +15,7 @@ from sup3r.postprocessing.file_handling import OutputMixIn, RexOutputs from sup3r.utilities import ModuleName -from sup3r.utilities.utilities import DistributedProcess +from sup3r.utilities.execution import DistributedProcess from sup3r.utilities.cli import BaseCLI logger = logging.getLogger(__name__) diff --git a/sup3r/utilities/utilities.py b/sup3r/utilities/utilities.py index 5c8da69d9..69046c0f9 100644 --- a/sup3r/utilities/utilities.py +++ b/sup3r/utilities/utilities.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- -"""Utilities module for preparing -training data +"""Utilities module for preparing training data @author: bbenton """ @@ -26,130 +25,6 @@ logger = logging.getLogger(__name__) -class DistributedProcess: - """High-level class with commonly used functionality for processes - distributed across multiple nodes""" - - def __init__(self, max_nodes=1, n_chunks=None, max_chunks=None, - incremental=False): - """ - Parameters - ---------- - max_nodes : int, optional - Max number of nodes to distribute processes across - n_chunks : int, optional - Number of chunks to split all processes into. These process - chunks will be distributed across nodes. - max_chunks : int, optional - Max number of chunks processes can be split into. - incremental : bool - Whether to skip previously run process chunks or to overwrite. - """ - msg = ('For a distributed process either max_chunks or ' - 'max_chunks + n_chunks must be specified. Received ' - f'max_chunks={max_chunks}, n_chunks={n_chunks}.') - assert max_chunks is not None, msg - self._node_chunks = None - self._n_chunks = n_chunks - self._max_nodes = max_nodes - self._max_chunks = max_chunks - self._out_files = None - self._failed_chunks = False - self.incremental = incremental - - def __len__(self): - """Get total number of process chunks""" - return self.chunks - - def node_finished(self, node_index): - """Check if all out files for a given node have been saved - - Parameters - ---------- - node_index : int - Index of node to check for completed processes - - Returns - ------- - bool - Whether all processes for the given node have finished - """ - return all(self.chunk_finished(i) - for i in self.node_chunks[node_index]) - - # pylint: disable=E1136 - def chunk_finished(self, chunk_index): - """Check if process for given chunk_index has already been run. - - Parameters - ---------- - chunk_index : int - Index of the process chunk to check for completion. Considered - finished if there is already an output file and incremental is - False. 
- - Returns - ------- - bool - Whether the process for the given chunk has finished - """ - out_file = self.out_files[chunk_index] - if os.path.exists(out_file) and self.incremental: - logger.info('Not running chunk index {}, output file ' - 'exists: {}'.format(chunk_index, out_file)) - return True - return False - - @property - def all_finished(self): - """Check if all out files have been saved""" - return all(self.node_finished(i) for i in range(self.nodes)) - - @property - def out_files(self): - """Get list of out files to write process output to""" - return self._out_files - - @property - def max_nodes(self): - """Get uncapped max number of nodes to distribute processes across""" - return self._max_nodes - - @property - def chunks(self): - """Get the number of processes chunks for this distributed routine.""" - if self._n_chunks is None: - return self._max_chunks - else: - return min(self._n_chunks, self._max_chunks) - - @property - def nodes(self): - """Get the max number of nodes to distribute chunks across, limited by - the number of process chunks""" - return len(self.node_chunks) - - @property - def node_chunks(self): - """Get the chunk indices for different nodes""" - if self._node_chunks is None: - n_chunks = min(self.max_nodes, self.chunks) - self._node_chunks = np.array_split(np.arange(self.chunks), - n_chunks) - return self._node_chunks - - @property - def failed_chunks(self): - """Check whether any processes have failed.""" - return self._failed_chunks - - @failed_chunks.setter - def failed_chunks(self, failed): - """Set failed_chunks value. Should be set to True if there is a failed - chunk""" - self._failed_chunks = failed - - def correct_path(path): """If running on windows we need to replace backslashes with double backslashes so paths can be parsed correctly with safe_open_json"""
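
Usage note: after this series, ForwardPassStrategy and RegridOutput both get
their node/chunk bookkeeping (chunks, nodes, node_chunks, chunk_finished,
node_finished, all_finished, failed_chunks) from the single DistributedProcess
base class in sup3r/utilities/execution.py. A minimal sketch of how another
pipeline step could reuse it follows; the ToyProcess class, its out_pattern
argument, and its run() body are hypothetical illustrations, not sup3r API.

from sup3r.utilities.execution import DistributedProcess


class ToyProcess(DistributedProcess):
    """Hypothetical pipeline step split into chunks across nodes."""

    def __init__(self, out_pattern, max_nodes=2, n_chunks=10,
                 incremental=True):
        # max_chunks caps n_chunks; assume the work divides into at most
        # 10 chunks for this sketch.
        super().__init__(max_nodes=max_nodes, n_chunks=n_chunks,
                         max_chunks=10, incremental=incremental)
        # The base out_files property just returns _out_files, so the
        # subclass populates it with one output path per chunk, following
        # the file_id pattern RegridOutput.out_files uses above.
        self._out_files = [out_pattern.format(file_id=str(i).zfill(6))
                           for i in range(self.chunks)]

    def run(self, node_index):
        # Skip the whole node if every chunk already has an output file,
        # mirroring ForwardPass.run and RegridOutput.run above.
        if self.node_finished(node_index):
            return
        for chunk_index in self.node_chunks[node_index]:
            if self.chunk_finished(chunk_index):
                continue
            with open(self.out_files[chunk_index], 'w') as fh:
                fh.write(f'chunk {chunk_index} done\n')


proc = ToyProcess(out_pattern='./chunk_{file_id}.txt')
# With 10 chunks on 2 nodes, node_chunks == [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]
for node_index in range(proc.nodes):
    proc.run(node_index)
assert proc.all_finished

One quirk worth noting: chunk_finished only short-circuits when an output file
exists and incremental=True, so with incremental=False every chunk is rerun
and all_finished stays False even after all output files have been written.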