diff --git a/sup3r/pipeline/forward_pass.py b/sup3r/pipeline/forward_pass.py index 8550bcede..8e66c536e 100644 --- a/sup3r/pipeline/forward_pass.py +++ b/sup3r/pipeline/forward_pass.py @@ -541,24 +541,13 @@ class ForwardPassStrategy(InputMixIn): def __init__(self, file_paths, model_kwargs, fwp_chunk_shape, spatial_pad, temporal_pad, - temporal_slice=slice(None), model_class='Sup3rGan', - target=None, shape=None, - raster_file=None, - time_chunk_size=None, - cache_pattern=None, out_pattern=None, - overwrite_cache=False, - overwrite_ti_cache=False, input_handler=None, input_handler_kwargs=None, incremental=True, max_workers=None, - extract_workers=None, - compute_workers=None, - load_workers=None, output_workers=None, - ti_workers=None, exo_kwargs=None, pass_workers=1, bias_correct_method=None, @@ -694,8 +683,8 @@ def __init__(self, file_paths, model_kwargs, fwp_chunk_shape, pass bias_correct_method : str | None Optional bias correction function name that can be imported from - the sup3r.bias.bias_transforms module. This will transform the - source data according to some predefined bias correction + the :mod:`sup3r.bias.bias_transforms` module. This will transform + the source data according to some predefined bias correction transformation along with the bias_correct_kwargs. As the first argument, this method must receive a generic numpy array of data to be bias corrected @@ -717,28 +706,16 @@ def __init__(self, file_paths, model_kwargs, fwp_chunk_shape, self.temporal_pad = temporal_pad self.model_class = model_class self.out_pattern = out_pattern - self.raster_file = raster_file - self.temporal_slice = temporal_slice - self.time_chunk_size = time_chunk_size - self.overwrite_cache = overwrite_cache - self.overwrite_ti_cache = overwrite_ti_cache self.max_workers = max_workers - self.extract_workers = extract_workers - self.compute_workers = compute_workers - self.load_workers = load_workers self.output_workers = output_workers self.pass_workers = pass_workers self.exo_kwargs = exo_kwargs or {} self.incremental = incremental - self.ti_workers = ti_workers self._single_time_step_files = None - self._cache_pattern = cache_pattern self._input_handler_class = None self._input_handler_name = input_handler self._max_nodes = max_nodes self._input_handler_kwargs = input_handler_kwargs or {} - self._grid_shape = shape - self._target = target self._time_index = None self._raw_time_index = None self._out_files = None @@ -749,6 +726,26 @@ def __init__(self, file_paths, model_kwargs, fwp_chunk_shape, self.bias_correct_method = bias_correct_method self.bias_correct_kwargs = bias_correct_kwargs or {} + self._target = self._input_handler_kwargs.get('target', None) + self._grid_shape = self._input_handler_kwargs.get('shape', None) + self.raster_file = self._input_handler_kwargs.get('raster_file', None) + self.temporal_slice = self._input_handler_kwargs.get('temporal_slice', + slice(None)) + self.time_chunk_size = self._input_handler_kwargs.get( + 'time_chunk_size', None) + self.overwrite_cache = self._input_handler_kwargs.get( + 'overwrite_cache', False) + self.overwrite_ti_cache = self._input_handler_kwargs.get( + 'overwrite_ti_cache', False) + self.extract_workers = self._input_handler_kwargs.get( + 'extract_workers', None) + self.compute_workers = self._input_handler_kwargs.get( + 'compute_workers', None) + self.load_workers = self._input_handler_kwargs.get('load_workers', + None) + self.ti_workers = self._input_handler_kwargs.get('ti_workers', None) + self._cache_pattern = 
self._input_handler_kwargs.get('cache_pattern', + None) self.cap_worker_args(max_workers) model_class = getattr(sup3r.models, self.model_class, None) @@ -1088,7 +1085,12 @@ def __init__(self, strategy, chunk_index=0, node_index=0): self.output_features = self.model.output_features self.meta_data = self.model.meta - self.get_chunk_kwargs(strategy, chunk_index) + self._file_paths = strategy.file_paths + self.max_workers = strategy.max_workers + self.pass_workers = strategy.pass_workers + self.output_workers = strategy.output_workers + self.exo_kwargs = strategy.exo_kwargs + self.single_time_step_files = strategy.single_time_step_files self.exogenous_handler = None self.exogenous_data = None @@ -1119,7 +1121,8 @@ def __init__(self, strategy, chunk_index=0, node_index=0): elif strategy.output_type == 'h5': self.output_handler_class = OutputHandlerH5 - input_handler_kwargs = dict( + input_handler_kwargs = copy.deepcopy(strategy._input_handler_kwargs) + fwp_input_handler_kwargs = dict( file_paths=self.file_paths, features=self.features, target=self.target, @@ -1130,14 +1133,13 @@ def __init__(self, strategy, chunk_index=0, node_index=0): time_chunk_size=self.strategy.time_chunk_size, overwrite_cache=self.strategy.overwrite_cache, max_workers=self.max_workers, - extract_workers=self.extract_workers, - compute_workers=self.compute_workers, - load_workers=self.load_workers, - ti_workers=self.ti_workers, + extract_workers=strategy.extract_workers, + compute_workers=strategy.compute_workers, + load_workers=strategy.load_workers, + ti_workers=strategy.ti_workers, handle_features=self.strategy.handle_features, val_split=0.0) - - input_handler_kwargs.update(self.strategy._input_handler_kwargs) + input_handler_kwargs.update(fwp_input_handler_kwargs) self.data_handler = self.input_handler_class(**input_handler_kwargs) self.data_handler.load_cached_data() self.input_data = self.data_handler.data @@ -1145,10 +1147,8 @@ def __init__(self, strategy, chunk_index=0, node_index=0): self.input_data = self.bias_correct_source_data(self.input_data) exo_s_en = self.exo_kwargs.get('s_enhancements', None) - out = self.pad_source_data(self.input_data, - self.pad_s1_start, self.pad_s1_end, - self.pad_s2_start, self.pad_s2_end, - self.pad_t_start, self.pad_t_end, + pad_width = self.get_padding() + out = self.pad_source_data(self.input_data, pad_width, self.exogenous_data, exo_s_en) self.input_data, self.exogenous_data = out @@ -1158,14 +1158,14 @@ def file_paths(self): reduced if there are single timesteps per file.""" file_paths = self._file_paths if self.single_time_step_files: - file_paths = self._file_paths[self._ti_pad_slice] + file_paths = self._file_paths[self.ti_pad_slice] return file_paths @property def temporal_pad_slice(self): """Get the low resolution temporal slice including padding.""" - ti_pad_slice = self._ti_pad_slice + ti_pad_slice = self.ti_pad_slice if self.single_time_step_files: ti_pad_slice = slice(None) @@ -1212,100 +1212,120 @@ def temporal_chunk_index(self): """Temporal index for the current chunk going through forward pass""" return self.chunk_index // self.strategy.fwp_slicer.n_spatial_chunks - def get_chunk_kwargs(self, strategy, chunk_index): - """Get node specific variables given an associated index + @property + def out_file(self): + """Get output file name for the current chunk""" + return self.strategy.out_files[self.chunk_index] - Parameters - ---------- - strategy : ForwardPassStrategy - ForwardPassStrategy instance with information on data chunks to run - forward passes on. 
- chunk_index : int - Index to select chunk specific variables. This index selects the - corresponding file set, cropped_file_slice, padded_file_slice, - and padded/overlapping/cropped spatial slice for a spatiotemporal - chunk - """ + @property + def ti_slice(self): + """Get ti slice for the current chunk""" + return self.strategy.ti_slices[self.temporal_chunk_index] + + @property + def ti_pad_slice(self): + """Get padded ti slice for the current chunk""" + return self.strategy.ti_pad_slices[self.temporal_chunk_index] + + @property + def lr_slice(self): + """Get lr slice for the current chunk""" + return self.strategy.lr_slices[self.spatial_chunk_index] - if chunk_index >= len(strategy): - msg = ('Index is out of bounds. There are ' - f'{len(strategy.file_ids)} file chunks and the index ' - f'requested was {chunk_index}.') - raise ValueError(msg) + @property + def lr_pad_slice(self): + """Get padded lr slice for the current chunk""" + return self.strategy.lr_pad_slices[self.spatial_chunk_index] - s_chunk_index = self.spatial_chunk_index - t_chunk_index = self.temporal_chunk_index + @property + def hr_slice(self): + """Get hr slice for the current chunk""" + return self.strategy.hr_slices[self.spatial_chunk_index] - self.out_file = strategy.out_files[chunk_index] - self._ti_pad_slice = strategy.ti_pad_slices[t_chunk_index] - self.ti_slice = strategy.ti_slices[t_chunk_index] - hr_crop_slices = strategy.fwp_slicer.hr_crop_slices[t_chunk_index] + @property + def hr_crop_slice(self): + """Get hr cropping slice for the current chunk""" + hr_crop_slices = self.strategy.fwp_slicer.hr_crop_slices[ + self.temporal_chunk_index] + return hr_crop_slices[self.spatial_chunk_index] - self.cache_pattern = strategy.cache_pattern - if self.cache_pattern is not None: - self.cache_pattern = self.cache_pattern.replace( - '{temporal_chunk_index}', str(t_chunk_index)) - self.cache_pattern = self.cache_pattern.replace( - '{spatial_chunk_index}', str(s_chunk_index)) + @property + def lr_crop_slice(self): + """Get lr cropping slice for the current chunk""" + lr_crop_slices = self.strategy.fwp_slicer.s_lr_crop_slices + return lr_crop_slices[self.spatial_chunk_index] - self.raster_file = strategy.raster_file - if self.raster_file is not None: - self.raster_file = self.raster_file.replace( - '{spatial_chunk_index}', str(s_chunk_index)) - - self.ti_start = self.ti_slice.start or 0 - self.ti_stop = self.ti_slice.stop or len(strategy.raw_time_index) - self.pad_t_start = int(np.maximum(0, (strategy.temporal_pad - - self.ti_start))) - self.pad_t_end = int(np.maximum(0, (strategy.temporal_pad - + self.ti_stop - - len(strategy.raw_time_index)))) - - self.lr_slice = strategy.lr_slices[s_chunk_index] - self.lr_pad_slice = strategy.lr_pad_slices[s_chunk_index] - self.hr_slice = strategy.hr_slices[s_chunk_index] - self.hr_crop_slice = hr_crop_slices[s_chunk_index] - lr_crop_slices = strategy.fwp_slicer.s_lr_crop_slices - self.lr_crop_slice = lr_crop_slices[s_chunk_index] - - self.s1_start = self.lr_slice[0].start or 0 - self.s1_stop = self.lr_slice[0].stop or strategy.grid_shape[0] - self.pad_s1_start = int(np.maximum(0, (strategy.spatial_pad - - self.s1_start))) - self.pad_s1_end = int(np.maximum(0, (strategy.spatial_pad - + self.s1_stop - - strategy.grid_shape[0]))) - - self.s2_start = self.lr_slice[1].start or 0 - self.s2_stop = self.lr_slice[1].stop or strategy.grid_shape[1] - self.pad_s2_start = int(np.maximum(0, (strategy.spatial_pad - - self.s2_start))) - self.pad_s2_end = int(np.maximum(0, (strategy.spatial_pad - + 
self.s2_stop - - strategy.grid_shape[1]))) - - self.data_shape = (*strategy.grid_shape, - len(strategy.raw_time_index[self._ti_pad_slice])) - - self.chunk_shape = ( - self.lr_pad_slice[0].stop - self.lr_pad_slice[0].start, - self.lr_pad_slice[1].stop - self.lr_pad_slice[1].start, - self.data_shape[2]) + @property + def data_shape(self): + """Get data shape for the current padded temporal chunk""" + return (*self.strategy.grid_shape, + len(self.strategy.raw_time_index[self.ti_pad_slice])) - self._file_paths = strategy.file_paths - self.max_workers = strategy.max_workers - self.pass_workers = strategy.pass_workers - self.ti_workers = strategy.ti_workers - self.extract_workers = strategy.extract_workers - self.compute_workers = strategy.compute_workers - self.load_workers = strategy.load_workers - self.output_workers = strategy.output_workers - self.exo_kwargs = strategy.exo_kwargs - self.single_time_step_files = strategy.single_time_step_files + @property + def chunk_shape(self): + """Get shape for the current padded spatiotemporal chunk""" + return (self.lr_pad_slice[0].stop - self.lr_pad_slice[0].start, + self.lr_pad_slice[1].stop - self.lr_pad_slice[1].start, + self.data_shape[2]) + + @property + def cache_pattern(self): + """Get cache pattern for the current chunk""" + cache_pattern = self.strategy.cache_pattern + if cache_pattern is not None: + cache_pattern = cache_pattern.replace( + '{temporal_chunk_index}', str(self.temporal_chunk_index)) + cache_pattern = cache_pattern.replace( + '{spatial_chunk_index}', str(self.spatial_chunk_index)) + return cache_pattern + + @property + def raster_file(self): + """Get raster file for the current spatial chunk""" + raster_file = self.strategy.raster_file + if raster_file is not None: + raster_file = raster_file.replace( + '{spatial_chunk_index}', str(self.spatial_chunk_index)) + return raster_file + + def get_padding(self): + """Get padding for the current spatiotemporal chunk + + Returns + ------- + padding : tuple + Tuple of tuples with padding width for spatial and temporal + dimensions. Each tuple includes the start and end of padding for + that dimension. Ordering is spatial_1, spatial_2, temporal. + """ + ti_start = self.ti_slice.start or 0 + ti_stop = self.ti_slice.stop or len(self.strategy.raw_time_index) + pad_t_start = int(np.maximum(0, (self.strategy.temporal_pad + - ti_start))) + pad_t_end = int(np.maximum(0, (self.strategy.temporal_pad + + ti_stop + - len(self.strategy.raw_time_index)))) + + s1_start = self.lr_slice[0].start or 0 + s1_stop = self.lr_slice[0].stop or self.strategy.grid_shape[0] + pad_s1_start = int(np.maximum(0, (self.strategy.spatial_pad + - s1_start))) + pad_s1_end = int(np.maximum(0, (self.strategy.spatial_pad + + s1_stop + - self.strategy.grid_shape[0]))) + + s2_start = self.lr_slice[1].start or 0 + s2_stop = self.lr_slice[1].stop or self.strategy.grid_shape[1] + pad_s2_start = int(np.maximum(0, (self.strategy.spatial_pad + - s2_start))) + pad_s2_end = int(np.maximum(0, (self.strategy.spatial_pad + + s2_stop + - self.strategy.grid_shape[1]))) + return ((pad_s1_start, pad_s1_end), (pad_s2_start, pad_s2_end), + (pad_t_start, pad_t_end)) @staticmethod - def pad_source_data(input_data, pad_s1_start, pad_s1_end, pad_s2_start, - pad_s2_end, pad_t_start, pad_t_end, exo_data, + def pad_source_data(input_data, pad_width, exo_data, exo_s_enhancements, mode='reflect'): """Pad the edges of the source data from the data handler. 
@@ -1319,20 +1339,10 @@ def pad_source_data(input_data, pad_s1_start, pad_s1_end, pad_s2_start, passes for subsequent spatial stitching. This overlap will pad both sides of the fwp_chunk_shape. Note that the first and last chunks in any of the spatial dimension will not be padded. - pad_s1_start : int - How much padding to add to the beginning of the first spatial - dimension. - pad_s1_end : bool - How much padding to add to the end of the first spatial dimension. - pad_s2_start : int - How much padding to add to the beginning of the second spatial - dimension. - pad_s2_end : bool - How much padding to add to the end of the second spatial dimension. - pad_t_start : int - How much padding to add to the beginning of the temporal axis. - pad_t_end : bool - How much padding to add to the end of the temporal axis. + pad_width : tuple + Tuple of tuples with padding width for spatial and temporal + dimensions. Each tuple includes the start and end of padding for + that dimension. Ordering is spatial_1, spatial_2, temporal. exo_data : None | list List of exogenous data arrays for each step of the sup3r resolution model. List entries can be None if not exo data is requested for a @@ -1352,10 +1362,7 @@ def pad_source_data(input_data, pad_s1_start, pad_s1_end, pad_s2_start, exo_data : list | None Padded copy of exo_data input. """ - - pad_width = ((pad_s1_start, pad_s1_end), (pad_s2_start, pad_s2_end), - (pad_t_start, pad_t_end), (0, 0)) - out = np.pad(input_data, pad_width, mode=mode) + out = np.pad(input_data, (*pad_width, (0, 0)), mode=mode) logger.info('Padded input data shape from {} to {} using mode "{}" ' 'with padding argument: {}' @@ -1368,11 +1375,12 @@ def pad_source_data(input_data, pad_s1_start, pad_s1_end, pad_s2_start, total_s_enhance = [s for s in total_s_enhance if s is not None] total_s_enhance = np.product(total_s_enhance) - pad_width = ((total_s_enhance * pad_s1_start, - total_s_enhance * pad_s1_end), - (total_s_enhance * pad_s2_start, - total_s_enhance * pad_s2_end), (0, 0)) - exo_data[i] = np.pad(i_exo_data, pad_width, mode=mode) + exo_pad_width = ((total_s_enhance * pad_width[0][0], + total_s_enhance * pad_width[0][1]), + (total_s_enhance * pad_width[1][0], + total_s_enhance * pad_width[1][1]), + (0, 0)) + exo_data[i] = np.pad(i_exo_data, exo_pad_width, mode=mode) return out, exo_data diff --git a/sup3r/pipeline/pipeline.py b/sup3r/pipeline/pipeline.py index 1c350abad..037a82355 100644 --- a/sup3r/pipeline/pipeline.py +++ b/sup3r/pipeline/pipeline.py @@ -3,16 +3,12 @@ Sup3r data pipeline architecture. """ import logging -import os -import json from reV.pipeline.pipeline import Pipeline -from rex.utilities.loggers import init_logger, create_dirs +from rex.utilities.loggers import init_logger from sup3r.pipeline.config import Sup3rPipelineConfig from sup3r.utilities import ModuleName -from sup3r.models.base import Sup3rGan -from sup3r.postprocessing.file_handling import OutputHandlerH5 logger = logging.getLogger(__name__) @@ -48,80 +44,3 @@ def __init__(self, pipeline, monitor=True, verbose=False): if 'logging' in self._config: init_logger('sup3r.pipeline', **self._config.logging) init_logger('reV.pipeline', **self._config.logging) - - @classmethod - def init_pass_collect(cls, out_dir, file_paths, model_path, - fwp_kwargs=None, dc_kwargs=None): - """Generate config files for forward pass and collection - - Parameters - ---------- - out_dir : str - Parent directory for pipeline run. 
- file_paths : str | list - A single source h5 wind file to extract raster data from or a list - of netcdf files with identical grid. The string can be a unix-style - file path which will be passed through glob.glob - model_path : str - Path to gan model for forward pass - fwp_kwargs : dict - Dictionary of keyword args passed to the ForwardPassStrategy class. - dc_kwargs : dict - Dictionary of keyword args passed to the Collection.collect() - method. - """ - fwp_kwargs = fwp_kwargs or {} - dc_kwargs = dc_kwargs or {} - logger.info('Generating config files for forward pass and data ' - 'collection') - log_dir = os.path.join(out_dir, 'logs/') - output_dir = os.path.join(out_dir, 'output/') - cache_dir = os.path.join(out_dir, 'cache/') - std_out_dir = os.path.join(out_dir, 'stdout/') - all_dirs = [out_dir, log_dir, cache_dir, output_dir, std_out_dir] - for d in all_dirs: - create_dirs(d) - out_pattern = os.path.join(output_dir, 'fwp_out_{file_id}.h5') - cache_pattern = os.path.join(cache_dir, 'cache_{feature}.pkl') - log_pattern = os.path.join(log_dir, 'log_{node_index}.log') - model_params = Sup3rGan.load_saved_params(model_path, - verbose=False)['meta'] - features = model_params['output_features'] - features = OutputHandlerH5.get_renamed_features(features) - fwp_config = {'file_paths': file_paths, - 'model_args': model_path, - 'out_pattern': out_pattern, - 'cache_pattern': cache_pattern, - 'log_pattern': log_pattern} - fwp_config.update(fwp_kwargs) - fwp_config_file = os.path.join(out_dir, 'config_fwp.json') - with open(fwp_config_file, 'w') as f: - json.dump(fwp_config, f, sort_keys=True, indent=4) - logger.info(f'Saved forward-pass config file: {fwp_config_file}.') - - collect_file = os.path.join(output_dir, 'out_collection.h5') - log_file = os.path.join(log_dir, 'collect.log') - input_files = os.path.join(output_dir, 'fwp_out_*.h5') - col_config = {'file_paths': input_files, - 'out_file': collect_file, - 'features': features, - 'log_file': log_file} - col_config.update(dc_kwargs) - col_config_file = os.path.join(out_dir, 'config_collect.json') - with open(col_config_file, 'w') as f: - json.dump(col_config, f, sort_keys=True, indent=4) - logger.info(f'Saved data-collect config file: {col_config_file}.') - - pipe_config = {'logging': {'log_level': 'DEBUG'}, - 'pipeline': [{'forward-pass': fwp_config_file}, - {'data-collect': col_config_file}]} - pipeline_file = os.path.join(out_dir, 'config_pipeline.json') - with open(pipeline_file, 'w') as f: - json.dump(pipe_config, f, sort_keys=True, indent=4) - logger.info(f'Saved pipeline config file: {pipeline_file}.') - - script_file = os.path.join(out_dir, 'run.sh') - with open(script_file, 'w') as f: - cmd = 'python -m sup3r.cli -c config_pipeline.json pipeline' - f.write(cmd) - logger.info(f'Saved script file: {script_file}.') diff --git a/sup3r/postprocessing/collection.py b/sup3r/postprocessing/collection.py index 294098b64..739e88adb 100644 --- a/sup3r/postprocessing/collection.py +++ b/sup3r/postprocessing/collection.py @@ -326,9 +326,7 @@ def _get_collection_attrs(cls, file_paths, feature, sort=True, ---------- file_paths : list | str Explicit list of str file paths that will be sorted and collected - or a single string with unix-style /search/patt*ern.h5. Files - should have non-overlapping time_index dataset and fully - overlapping meta dataset. + or a single string with unix-style /search/patt*ern.h5. feature : str Dataset name to collect. 
sort : bool @@ -341,7 +339,8 @@ def _get_collection_attrs(cls, file_paths, feature, sort=True, None will use all available workers. target_final_meta_file : str Path to target final meta containing coordinates to keep from the - full file list collected meta + full list of coordinates present in the collected meta for the full + file list. Returns ------- @@ -466,9 +465,9 @@ def _ensure_dset_in_output(out_file, dset, data=None): attrs=attrs, data=data, chunks=attrs.get('chunks', None)) - def _collect_flist(self, feature, masked_meta, time_index, shape, - file_paths, out_file, target_final_meta, - masked_target_meta, max_workers=None): + def _collect_flist(self, feature, subset_masked_meta, time_index, shape, + file_paths, out_file, target_masked_meta, + max_workers=None): """Collect a dataset from a file list without getting attributes first. This file list can be a subset of a full file list to be collected. @@ -476,9 +475,12 @@ def _collect_flist(self, feature, masked_meta, time_index, shape, ---------- feature : str Dataset name to collect. - masked_meta : pd.DataFrame - Concatenated meta data for the given file paths. This masked - against the target_final_meta. + subset_masked_meta : pd.DataFrame + Meta data containing the list of coordinates present in both the + given file paths and the target_final_meta. This can be a subset of + the coordinates present in the full file list. The coordinates + contained in this dataframe have the same gids as those present in + the meta for the full file list. time_index : pd.datetimeindex Concatenated datetime index for the given file paths. shape : tuple @@ -488,17 +490,14 @@ def _collect_flist(self, feature, masked_meta, time_index, shape, to be collected. out_file : str File path of final output file. - target_final_meta : str - Target final meta containing coordinates to keep from the - full file list collected meta - masked_target_meta : pd.DataFrame - Collected meta data with mask applied from target_final_meta so - original gids are preserved. + target_masked_meta : pd.DataFrame + Same as subset_masked_meta but instead for the entire list of files + to be collected. max_workers : int | None Number of workers to use in parallel. 1 runs serial, None uses all available. """ - if len(masked_meta) > 0: + if len(subset_masked_meta) > 0: attrs, final_dtype = get_dset_attrs(feature) scale_factor = attrs.get('scale_factor', 1) @@ -517,8 +516,9 @@ def _collect_flist(self, feature, masked_meta, time_index, shape, for i, fname in enumerate(file_paths): logger.debug('Collecting data from file {} out of {}.' .format(i + 1, len(file_paths))) - self.get_data(fname, feature, time_index, masked_meta, - scale_factor, final_dtype) + self.get_data(fname, feature, time_index, + subset_masked_meta, scale_factor, + final_dtype) else: logger.info('Running parallel collection on {} workers.' 
.format(max_workers)) @@ -528,7 +528,7 @@ def _collect_flist(self, feature, masked_meta, time_index, shape, with ThreadPoolExecutor(max_workers=max_workers) as exe: for fname in file_paths: future = exe.submit(self.get_data, fname, feature, - time_index, masked_meta, + time_index, subset_masked_meta, scale_factor, final_dtype) futures[future] = fname for future in as_completed(futures): @@ -547,15 +547,11 @@ def _collect_flist(self, feature, masked_meta, time_index, shape, msg += f'{futures[future]}' logger.exception(msg) raise RuntimeError(msg) from e - if not os.path.exists(out_file): - Collector._init_collected_h5(out_file, time_index, - target_final_meta) - x_write_slice, y_write_slice = slice(None), slice(None) - else: - with RexOutputs(out_file, 'r') as f: - target_ti = f.time_index - y_write_slice, x_write_slice = Collector.get_slices( - target_ti, masked_target_meta, time_index, masked_meta) + with RexOutputs(out_file, 'r') as f: + target_ti = f.time_index + y_write_slice, x_write_slice = Collector.get_slices( + target_ti, target_masked_meta, time_index, + subset_masked_meta) Collector._ensure_dset_in_output(out_file, feature) with RexOutputs(out_file, mode='a') as f: f[feature, y_write_slice, x_write_slice] = self.data @@ -641,9 +637,7 @@ def collect(cls, file_paths, out_file, features, max_workers=None, ---------- file_paths : list | str Explicit list of str file paths that will be sorted and collected - or a single string with unix-style /search/patt*ern.h5. Files - should have non-overlapping time_index dataset and fully - overlapping meta dataset. + or a single string with unix-style /search/patt*ern.h5. out_file : str File path of final output file. features : list @@ -666,7 +660,13 @@ def collect(cls, file_paths, out_file, features, max_workers=None, a suffix format _{temporal_chunk_index}_{spatial_chunk_index}.h5 target_final_meta_file : str Path to target final meta containing coordinates to keep from the - full file list collected meta + full file list collected meta. This can be but is not necessarily a + subset of the full list of coordinates for all files in the file + list. This is used to remove coordinates from the full file list + which are not present in the target_final_meta. Either this full + meta or a subset, depending on which coordinates are present in + the data to be collected, will be the final meta for the collected + output files. n_writes : int | None Number of writes to split full file list into. Must be less than or equal to the number of temporal chunks. 
@@ -696,17 +696,17 @@ def collect(cls, file_paths, out_file, features, max_workers=None, out = collector._get_collection_attrs( collector.flist, dset, max_workers=max_workers, target_final_meta_file=target_final_meta_file) - time_index, final_target_meta, masked_target_meta = out[:3] + time_index, target_final_meta, target_masked_meta = out[:3] shape, _, global_attrs = out[3:] if not os.path.exists(out_file): collector._init_collected_h5(out_file, time_index, - final_target_meta, global_attrs) + target_final_meta, global_attrs) if len(flist_chunks) == 1: - collector._collect_flist(dset, masked_target_meta, time_index, + collector._collect_flist(dset, target_masked_meta, time_index, shape, flist_chunks[0], out_file, - final_target_meta, masked_target_meta, + target_masked_meta, max_workers=max_workers) else: @@ -719,8 +719,7 @@ def collect(cls, file_paths, out_file, features, max_workers=None, target_final_meta_file=target_final_meta_file) collector._collect_flist(dset, masked_meta, time_index, shape, flist, out_file, - target_final_meta, - masked_target_meta, + target_masked_meta, max_workers=max_workers) if write_status and job_name is not None: diff --git a/sup3r/postprocessing/file_handling.py b/sup3r/postprocessing/file_handling.py index 5075ac63c..4a9ebf98e 100644 --- a/sup3r/postprocessing/file_handling.py +++ b/sup3r/postprocessing/file_handling.py @@ -532,5 +532,5 @@ def _write_output(cls, data, features, lat_lon, times, out_file, if meta_data is not None: fh.run_attrs = {'gan_meta': json.dumps(meta_data)} - os.rename(tmp_file, out_file) + os.replace(tmp_file, out_file) logger.info(f'Saved output of size {data.shape} to: {out_file}') diff --git a/sup3r/preprocessing/data_handling.py b/sup3r/preprocessing/data_handling.py index ac728f1c2..0f0427384 100644 --- a/sup3r/preprocessing/data_handling.py +++ b/sup3r/preprocessing/data_handling.py @@ -1285,7 +1285,7 @@ def cache_data(self, cache_file_paths): tmp_file = fp.replace('.pkl', '.pkl.tmp') with open(tmp_file, 'wb') as fh: pickle.dump(self.data[..., i], fh, protocol=4) - os.rename(tmp_file, fp) + os.replace(tmp_file, fp) else: msg = (f'Called cache_data but {fp} already exists. Set to ' 'overwrite_cache to True to overwrite.') diff --git a/sup3r/qa/stats.py b/sup3r/qa/stats.py index 696560952..514b170c7 100644 --- a/sup3r/qa/stats.py +++ b/sup3r/qa/stats.py @@ -138,11 +138,11 @@ def __init__(self, source_file_paths, out_file_path, s_enhance, t_enhance, get_lr : bool Whether to include low resolution stats in output include_stats : list | None - List of stats to include in output. e.g. ['ws_ramp_rate', + List of stats to include in output. e.g. ['ramp_rate', 'velocity_grad', 'vorticity', 'tke_avg_k', 'tke_avg_f', 'tke_ts'] max_values : dict | None Dictionary of max values to keep for stats. e.g. - {'ws_ramp_rate_max': 10, 'v_grad_max': 14, 'vorticity_max': 7} + {'ramp_rate_max': 10, 'v_grad_max': 14, 'vorticity_max': 7} ramp_rate_t_step : int | list Number of time steps to use for ramp rate calculation. If low res data is hourly then t_step=1 will calculate the hourly ramp rate. 
@@ -168,16 +168,16 @@ def __init__(self, source_file_paths, out_file_path, s_enhance, t_enhance, ti_workers = max_workers self.max_values = max_values or {} - self.ws_ramp_rate_max = self.max_values.get('ws_ramp_rate_max', 10) + self.ramp_rate_max = self.max_values.get('ramp_rate_max', 10) self.v_grad_max = self.max_values.get('v_grad_max', 7) self.vorticity_max = self.max_values.get('vorticity_max', 14) self.ramp_rate_t_step = (ramp_rate_t_step if isinstance(ramp_rate_t_step, list) else [ramp_rate_t_step]) - self.include_stats = include_stats or ['ws_ramp_rate', 'velocity_grad', + self.include_stats = include_stats or ['ramp_rate', 'velocity_grad', 'vorticity', 'tke_avg_k', 'tke_avg_f', 'tke_ts', - 'mean_ws_ramp_rate'] + 'mean_ramp_rate'] self.s_enhance = s_enhance self.t_enhance = t_enhance @@ -761,18 +761,18 @@ def get_ramp_rate_stats(self, u, v, scale=1): """ stats_dict = {} - if 'ws_ramp_rate' in self.include_stats: + if 'ramp_rate' in self.include_stats: for i, time in enumerate(self.ramp_rate_t_step): logger.info('Computing ramp rate pdf.') - out = ws_ramp_rate_dist(u, v, diff_max=self.ws_ramp_rate_max, + out = ws_ramp_rate_dist(u, v, diff_max=self.ramp_rate_max, t_steps=time, scale=scale) stats_dict[f'ramp_rate_{self.ramp_rate_t_step[i]}'] = out - if 'mean_ws_ramp_rate' in self.include_stats: + if 'mean_ramp_rate' in self.include_stats: for i, time in enumerate(self.ramp_rate_t_step): logger.info('Computing mean ramp rate pdf.') out = ws_ramp_rate_dist(np.mean(u, axis=(0, 1)), np.mean(v, axis=(0, 1)), - diff_max=self.ws_ramp_rate_max, + diff_max=self.ramp_rate_max, t_steps=time, scale=scale) stats_dict[f'mean_ramp_rate_{self.ramp_rate_t_step[i]}'] = out return stats_dict @@ -879,7 +879,7 @@ def run(self): stats : dict Dictionary of statistics, where keys are lr/hr/interp appended with the height of the corresponding wind field. Values are dictionaries - of statistics, such as velocity_gradient, vorticity, ws_ramp_rate, + of statistics, such as velocity_gradient, vorticity, ramp_rate, etc """ diff --git a/sup3r/qa/utilities.py b/sup3r/qa/utilities.py index 9a94f67af..dd3f283ff 100644 --- a/sup3r/qa/utilities.py +++ b/sup3r/qa/utilities.py @@ -205,7 +205,7 @@ def ws_ramp_rate_dist(u, v, bins=50, range=None, diff_max=10, t_steps=1, msg = (f'Received t_steps={t_steps} for ramp rate calculation but data ' f'only has {u.shape[-1]} time steps') assert t_steps < u.shape[-1], msg - ws = np.sqrt(u**2 + v**2) + ws = np.hypot(u, v) diffs = (ws[..., t_steps:] - ws[..., :-t_steps]).flatten() diffs /= scale diffs = diffs[(np.abs(diffs) < diff_max)] diff --git a/sup3r/utilities/utilities.py b/sup3r/utilities/utilities.py index 3efba5283..753a84bb8 100644 --- a/sup3r/utilities/utilities.py +++ b/sup3r/utilities/utilities.py @@ -464,7 +464,6 @@ def transform_rotate_wind(ws, wd, lat_lon): # calculate the angle from the vertical theta = (np.pi / 2) - np.arctan2(dy, dx) - del dy, dx theta[0] = theta[1] # fix the roll row wd = np.radians(wd) @@ -477,7 +476,6 @@ def transform_rotate_wind(ws, wd, lat_lon): if invert_lat: u_rot = u_rot[::-1] v_rot = v_rot[::-1] - del theta, ws, wd return u_rot, v_rot @@ -507,7 +505,12 @@ def invert_uv(u, v, lat_lon): measured relative to the south_north direction. 
(spatial_1, spatial_2, temporal) """ - # get the dy/dx to the nearest vertical neighbor + invert_lat = False + if lat_lon[-1, 0, 0] > lat_lon[0, 0, 0]: + invert_lat = True + lat_lon = lat_lon[::-1] + u = u[::-1] + v = v[::-1] dy = lat_lon[:, :, 0] - np.roll(lat_lon[:, :, 0], 1, axis=0) dx = lat_lon[:, :, 1] - np.roll(lat_lon[:, :, 1], 1, axis=0) dy = (dy + 90) % 180 - 90 @@ -523,14 +526,12 @@ def invert_uv(u, v, lat_lon): v_rot = np.sin(theta)[:, :, np.newaxis] * u v_rot += np.cos(theta)[:, :, np.newaxis] * v - ws = np.sqrt(u_rot**2 + v_rot**2) + ws = np.hypot(u_rot, v_rot) wd = (np.degrees(np.arctan2(u_rot, v_rot)) + 360) % 360 - # if lats are descending then we have calculated angle relative to the - # south direction. Need to shift so it is relative to the north direction - if lat_lon[-1, 0, 0] > lat_lon[0, 0, 0]: - wd = (wd + 180) % 360 - + if invert_lat: + ws = ws[::-1] + wd = wd[::-1] return ws, wd diff --git a/tests/test_bias_correction.py b/tests/test_bias_correction.py index 3de4efe02..ce04ccf38 100644 --- a/tests/test_bias_correction.py +++ b/tests/test_bias_correction.py @@ -258,8 +258,8 @@ def test_fwp_integration(): model_kwargs={'model_dir': out_dir}, fwp_chunk_shape=fwp_chunk_shape, spatial_pad=0, temporal_pad=0, - target=target, shape=shape, - temporal_slice=temporal_slice, + input_handler_kwargs=dict(target=target, shape=shape, + temporal_slice=temporal_slice), out_pattern=os.path.join(td, 'out_{file_id}.nc'), max_workers=1, input_handler='DataHandlerNCforCC') @@ -268,8 +268,8 @@ def test_fwp_integration(): model_kwargs={'model_dir': out_dir}, fwp_chunk_shape=fwp_chunk_shape, spatial_pad=0, temporal_pad=0, - target=target, shape=shape, - temporal_slice=temporal_slice, + input_handler_kwargs=dict(target=target, shape=shape, + temporal_slice=temporal_slice), out_pattern=os.path.join(td, 'out_{file_id}.nc'), max_workers=1, input_handler='DataHandlerNCforCC', diff --git a/tests/test_cli.py b/tests/test_cli.py index 2ad351dd8..1dcc271ff 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -73,8 +73,8 @@ def test_pipeline_fwp_collect(runner): 'shape': shape, 'fwp_chunk_shape': fwp_chunk_shape, 'time_chunk_size': 10, - 'max_workers': 2, - 'pass_workers': 2, + 'max_workers': 1, + 'pass_workers': 1, 'spatial_pad': 5, 'temporal_pad': 5, 'overwrite_cache': True, @@ -236,18 +236,18 @@ def test_fwd_pass_cli(runner): log_prefix = os.path.join(td, 'log.log') config = {'ti_workers': 1, 'file_paths': input_files, - 'target': (19.3, -123.5), 'model_kwargs': {'model_dir': out_dir}, 'out_pattern': out_files, - 'cache_pattern': cache_pattern, 'log_pattern': log_prefix, - 'shape': shape, + 'input_handler_kwargs': {'target': (19.3, -123.5), + 'cache_pattern': cache_pattern, + 'shape': shape, + 'overwrite_cache': False, + 'time_chunk_size': 10}, 'fwp_chunk_shape': fwp_chunk_shape, - 'time_chunk_size': 10, 'max_workers': 1, 'spatial_pad': 5, 'temporal_pad': 5, - 'overwrite_cache': False, 'execution_control': { "option": "local"}} @@ -322,18 +322,18 @@ def test_pipeline_fwp_qa(runner): model.save(out_dir) fwp_config = {'file_paths': input_files, - 'target': (19.3, -123.5), 'model_kwargs': {'model_dir': out_dir}, 'out_pattern': os.path.join(td, 'out_{file_id}.h5'), 'log_pattern': os.path.join(td, 'fwp_log.log'), 'log_level': 'DEBUG', - 'shape': (8, 8), + 'input_handler_kwargs': {'target': (19.3, -123.5), + 'shape': (8, 8), + 'time_chunk_size': 10, + 'overwrite_cache': False}, 'fwp_chunk_shape': (100, 100, 100), - 'time_chunk_size': 10, 'max_workers': 1, 'spatial_pad': 5, 'temporal_pad': 5, - 
'overwrite_cache': False, 'execution_control': { "option": "local"}} diff --git a/tests/test_forward_pass.py b/tests/test_forward_pass.py index 5efc11e97..8a59de65e 100644 --- a/tests/test_forward_pass.py +++ b/tests/test_forward_pass.py @@ -61,14 +61,15 @@ def test_fwp_nc_cc(): out_files = os.path.join(td, 'out_{file_id}.nc') # 1st forward pass max_workers = 1 + input_handler_kwargs = dict(target=target, shape=shape, + temporal_slice=temporal_slice, + cache_pattern=cache_pattern, + overwrite_cache=True) handler = ForwardPassStrategy( input_files, model_kwargs={'model_dir': out_dir}, fwp_chunk_shape=fwp_chunk_shape, spatial_pad=1, temporal_pad=1, - target=target, shape=shape, - temporal_slice=temporal_slice, - cache_pattern=cache_pattern, - overwrite_cache=True, out_pattern=out_files, + input_handler_kwargs=input_handler_kwargs, out_pattern=out_files, max_workers=max_workers, input_handler='DataHandlerNCforCC') forward_pass = ForwardPass(handler) assert forward_pass.output_workers == max_workers @@ -111,14 +112,15 @@ def test_fwp_nc(): out_files = os.path.join(td, 'out_{file_id}.nc') max_workers = 1 + input_handler_kwargs = dict(target=target, shape=shape, + temporal_slice=temporal_slice, + cache_pattern=cache_pattern, + overwrite_cache=True) handler = ForwardPassStrategy( input_files, model_kwargs={'model_dir': out_dir}, fwp_chunk_shape=fwp_chunk_shape, spatial_pad=1, temporal_pad=1, - target=target, shape=shape, - temporal_slice=temporal_slice, - cache_pattern=cache_pattern, - overwrite_cache=True, out_pattern=out_files, + input_handler_kwargs=input_handler_kwargs, out_pattern=out_files, max_workers=max_workers) forward_pass = ForwardPass(handler) assert forward_pass.output_workers == max_workers @@ -165,14 +167,15 @@ def test_fwp_temporal_slice(): temporal_slice = slice(5, 17, 3) raw_time_index = np.arange(20) n_tsteps = len(raw_time_index[temporal_slice]) + input_handler_kwargs = dict(target=target, shape=shape, + temporal_slice=temporal_slice, + cache_pattern=cache_pattern, + overwrite_cache=True) handler = ForwardPassStrategy( input_files, model_kwargs={'model_dir': out_dir}, fwp_chunk_shape=fwp_chunk_shape, spatial_pad=1, temporal_pad=1, - target=target, shape=shape, - temporal_slice=temporal_slice, - cache_pattern=cache_pattern, - overwrite_cache=True, out_pattern=out_files, + input_handler_kwargs=input_handler_kwargs, out_pattern=out_files, max_workers=max_workers) forward_pass = ForwardPass(handler) assert forward_pass.output_workers == max_workers @@ -235,14 +238,15 @@ def test_fwp_handler(): max_workers = 1 cache_pattern = os.path.join(td, 'cache') + input_handler_kwargs = dict(target=target, shape=shape, + temporal_slice=temporal_slice, + cache_pattern=cache_pattern, + overwrite_cache=True) handler = ForwardPassStrategy( input_files, model_kwargs={'model_dir': out_dir}, fwp_chunk_shape=fwp_chunk_shape, spatial_pad=1, temporal_pad=1, - target=target, shape=shape, - temporal_slice=temporal_slice, - cache_pattern=cache_pattern, - overwrite_cache=True, + input_handler_kwargs=input_handler_kwargs, max_workers=max_workers) forward_pass = ForwardPass(handler) assert forward_pass.data_handler.compute_workers == max_workers @@ -300,13 +304,11 @@ def test_fwp_chunking(log=False, plot=False): input_files, model_kwargs={'model_dir': out_dir}, fwp_chunk_shape=fwp_shape, spatial_pad=spatial_pad, temporal_pad=temporal_pad, - target=target, shape=shape, - temporal_slice=temporal_slice, - cache_pattern=cache_pattern, - overwrite_cache=True, - ti_workers=1, + 
input_handler_kwargs=dict(target=target, shape=shape, + temporal_slice=temporal_slice, + cache_pattern=cache_pattern, + overwrite_cache=True, ti_workers=1), max_workers=1) - data_chunked = np.zeros((shape[0] * s_enhance, shape[1] * s_enhance, len(input_files) * t_enhance, len(model.output_features))) @@ -393,15 +395,15 @@ def test_fwp_nochunking(): model.save(out_dir) cache_pattern = os.path.join(td, 'cache') + input_handler_kwargs = dict(target=target, shape=shape, + temporal_slice=temporal_slice, + cache_pattern=cache_pattern, + overwrite_cache=True) handler = ForwardPassStrategy( input_files, model_kwargs={'model_dir': out_dir}, fwp_chunk_shape=(shape[0], shape[1], list_chunk_size), spatial_pad=0, temporal_pad=0, - target=target, shape=shape, - temporal_slice=temporal_slice, - cache_pattern=cache_pattern, - overwrite_cache=True, - ti_workers=1, + input_handler_kwargs=input_handler_kwargs, max_workers=1) forward_pass = ForwardPass(handler) data_chunked = forward_pass.run_chunk() @@ -484,14 +486,16 @@ def test_fwp_multi_step_model_topo_exoskip(log=False): 'temporal_model_dirs': st_out_dir} out_files = os.path.join(td, 'out_{file_id}.h5') + input_handler_kwargs = dict(target=target, shape=shape, + temporal_slice=temporal_slice, + overwrite_cache=True) handler = ForwardPassStrategy( input_files, model_kwargs=model_kwargs, model_class='SpatialThenTemporalGan', fwp_chunk_shape=fwp_chunk_shape, + input_handler_kwargs=input_handler_kwargs, spatial_pad=0, temporal_pad=0, - target=target, shape=shape, out_pattern=out_files, - temporal_slice=temporal_slice, max_workers=max_workers, exo_kwargs=exo_kwargs, max_nodes=1) @@ -586,14 +590,16 @@ def test_fwp_multi_step_model_topo_noskip(): 'temporal_model_dirs': st_out_dir} out_files = os.path.join(td, 'out_{file_id}.h5') + input_handler_kwargs = dict(target=target, shape=shape, + temporal_slice=temporal_slice, + overwrite_cache=True) handler = ForwardPassStrategy( input_files, model_kwargs=model_kwargs, model_class='SpatialThenTemporalGan', fwp_chunk_shape=fwp_chunk_shape, spatial_pad=1, temporal_pad=1, - target=target, shape=shape, + input_handler_kwargs=input_handler_kwargs, out_pattern=out_files, - temporal_slice=temporal_slice, max_workers=max_workers, exo_kwargs=exo_kwargs, max_nodes=1) @@ -666,17 +672,18 @@ def test_fwp_multi_step_model(): model_kwargs = {'spatial_model_dirs': s_out_dir, 'temporal_model_dirs': st_out_dir} + input_handler_kwargs = dict(target=target, shape=shape, + temporal_slice=temporal_slice, + overwrite_cache=True) handler = ForwardPassStrategy( input_files, model_kwargs=model_kwargs, model_class='SpatialThenTemporalGan', fwp_chunk_shape=fwp_chunk_shape, spatial_pad=0, temporal_pad=0, - target=target, shape=shape, - temporal_slice=temporal_slice, + input_handler_kwargs=input_handler_kwargs, out_pattern=out_files, max_workers=max_workers, - max_nodes=1, - ti_workers=1) + max_nodes=1) forward_pass = ForwardPass(handler) ones = np.ones((fwp_chunk_shape[2], fwp_chunk_shape[0], @@ -742,13 +749,14 @@ def test_slicing_no_pad(log=False): sample_shape=(1, 1, 1), val_split=0.0, max_workers=1) + input_handler_kwargs = dict(target=target, shape=shape, + overwrite_cache=True) strategy = ForwardPassStrategy( input_files, model_kwargs={'model_dir': st_out_dir}, model_class='Sup3rGan', fwp_chunk_shape=(3, 2, 4), spatial_pad=0, temporal_pad=0, - target=target, shape=shape, - temporal_slice=slice(None), + input_handler_kwargs=input_handler_kwargs, out_pattern=out_files, max_workers=1, max_nodes=1) @@ -757,7 +765,7 @@ def 
test_slicing_no_pad(log=False): forward_pass = ForwardPass(strategy, chunk_index=ichunk) s_slices = strategy.lr_pad_slices[forward_pass.spatial_chunk_index] lr_data_slice = (s_slices[0], s_slices[1], - forward_pass._ti_pad_slice, + forward_pass.ti_pad_slice, slice(None)) truth = handler.data[lr_data_slice] @@ -796,13 +804,14 @@ def test_slicing_pad(log=False): sample_shape=(1, 1, 1), val_split=0.0, max_workers=1) + input_handler_kwargs = dict(target=target, shape=shape, + overwrite_cache=True) strategy = ForwardPassStrategy( input_files, model_kwargs={'model_dir': st_out_dir}, model_class='Sup3rGan', fwp_chunk_shape=(2, 1, 4), + input_handler_kwargs=input_handler_kwargs, spatial_pad=2, temporal_pad=2, - target=target, shape=shape, - temporal_slice=slice(None), out_pattern=out_files, max_workers=1, max_nodes=1) @@ -826,7 +835,7 @@ def test_slicing_pad(log=False): s_slices = strategy.lr_pad_slices[forward_pass.spatial_chunk_index] lr_data_slice = (s_slices[0], s_slices[1], - forward_pass._ti_pad_slice, + forward_pass.ti_pad_slice, slice(None)) # do a manual calculation of what the padding should be. diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 7297d438e..1da9b7179 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -17,31 +17,6 @@ FEATURES = ['U_100m', 'V_100m', 'BVF2_200m'] -def test_config_gen(): - """Test configuration generation for forward pass and collect""" - fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_3x_4x_2f.json') - fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') - - Sup3rGan.seed() - model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) - _ = model.generate(np.ones((4, 10, 10, 6, 2))) - model.meta['training_features'] = ['U_100m', 'V_100m'] - model.meta['output_features'] = ['U_100m', 'V_100m'] - model.meta['s_enhance'] = 3 - model.meta['t_enhance'] = 4 - with tempfile.TemporaryDirectory() as td: - input_files = make_fake_nc_files(td, INPUT_FILE, 8) - out_dir = os.path.join(td, 'st_gan') - model.save(out_dir) - fp_config = os.path.join(td, 'config_fwp.json') - dc_config = os.path.join(td, 'config_collect.json') - pipe_config = os.path.join(td, 'config_pipeline.json') - Pipeline.init_pass_collect(td, input_files, out_dir) - assert os.path.exists(fp_config) - assert os.path.exists(dc_config) - assert os.path.exists(pipe_config) - - def test_fwp_pipeline(): """Test sup3r pipeline""" @@ -70,17 +45,20 @@ def test_fwp_pipeline(): out_files = os.path.join(td, 'fp_out_{file_id}.h5') log_prefix = os.path.join(td, 'log') t_enhance = 4 + + input_handler_kwargs = dict(target=target, shape=shape, + overwrite_cache=True, + time_chunk_size=10, + temporal_slice=[t_slice.start, + t_slice.stop]) config = {'max_workers': 1, 'file_paths': input_files, - 'target': target, 'model_kwargs': {'model_dir': out_dir}, 'out_pattern': out_files, 'cache_pattern': cache_pattern, 'log_pattern': log_prefix, - 'shape': shape, 'fwp_chunk_shape': fp_chunk_shape, - 'time_chunk_size': 10, - 'temporal_slice': [t_slice.start, t_slice.stop], + 'input_handler_kwargs': input_handler_kwargs, 'spatial_pad': 2, 'temporal_pad': 2, 'overwrite_cache': True, diff --git a/tests/test_qa.py b/tests/test_qa.py index d75c292d6..1bc9b6c6d 100644 --- a/tests/test_qa.py +++ b/tests/test_qa.py @@ -6,12 +6,14 @@ import numpy as np from rex import Resource import xarray as xr +import pickle from sup3r import TEST_DATA_DIR, CONFIG_DIR from sup3r.pipeline.forward_pass import ForwardPass, ForwardPassStrategy from sup3r.models import Sup3rGan from sup3r.utilities.pytest_utils import 
make_fake_nc_files from sup3r.qa.qa import Sup3rQa +from sup3r.qa.stats import Sup3rWindStats FP_WTK = os.path.join(TEST_DATA_DIR, 'test_wtk_co_2012.h5') @@ -51,8 +53,8 @@ def test_qa_nc(): input_files, model_kwargs={'model_dir': out_dir}, fwp_chunk_shape=FWP_CHUNK_SHAPE, spatial_pad=1, temporal_pad=1, - target=TARGET, shape=SHAPE, - temporal_slice=TEMPORAL_SLICE, + input_handler_kwargs=dict(target=TARGET, shape=SHAPE, + temporal_slice=TEMPORAL_SLICE), out_pattern=out_files, max_workers=1, max_nodes=1) @@ -129,12 +131,13 @@ def test_qa_h5(): model.save(out_dir) out_files = os.path.join(td, 'out_{file_id}.h5') + input_handler_kwargs = dict(target=TARGET, shape=SHAPE, + temporal_slice=TEMPORAL_SLICE) strategy = ForwardPassStrategy( input_files, model_kwargs={'model_dir': out_dir}, fwp_chunk_shape=FWP_CHUNK_SHAPE, spatial_pad=1, temporal_pad=1, - target=TARGET, shape=SHAPE, - temporal_slice=TEMPORAL_SLICE, + input_handler_kwargs=input_handler_kwargs, out_pattern=out_files, max_workers=1, max_nodes=1) @@ -193,3 +196,55 @@ def test_qa_h5(): assert np.allclose(qa_true, wtk_source, atol=0.01) assert np.allclose(qa_syn, fwp_data, atol=0.01) assert np.allclose(test_diff, qa_diff, atol=0.01) + + +def test_stats(): + """Test the WindStats module with forward pass output to h5 file.""" + + fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_3x_4x_2f.json') + fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') + + Sup3rGan.seed() + model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) + _ = model.generate(np.ones((4, 10, 10, 6, len(TRAIN_FEATURES)))) + model.meta['training_features'] = TRAIN_FEATURES + model.meta['output_features'] = MODEL_OUT_FEATURES + model.meta['s_enhance'] = 3 + model.meta['t_enhance'] = 4 + with tempfile.TemporaryDirectory() as td: + input_files = make_fake_nc_files(td, INPUT_FILE, 8) + out_dir = os.path.join(td, 'st_gan') + model.save(out_dir) + + out_files = os.path.join(td, 'out_{file_id}.h5') + strategy = ForwardPassStrategy( + input_files, model_kwargs={'model_dir': out_dir}, + fwp_chunk_shape=(100, 100, 100), + spatial_pad=1, temporal_pad=1, + input_handler_kwargs=dict(temporal_slice=TEMPORAL_SLICE), + out_pattern=out_files, + max_workers=1, + max_nodes=1) + + forward_pass = ForwardPass(strategy) + forward_pass.run_chunk() + + qa_fp = os.path.join(td, 'stats.pkl') + args = [input_files, strategy.out_files[0]] + kwargs = dict(heights=[100], s_enhance=S_ENHANCE, t_enhance=T_ENHANCE, + temporal_slice=TEMPORAL_SLICE, + qa_fp=qa_fp, include_stats=['ramp_rate', + 'velocity_grad', + 'tke_avg_k'], + max_workers=1, ramp_rate_t_step=1) + with Sup3rWindStats(*args, **kwargs) as qa: + qa.run() + assert os.path.exists(qa_fp) + with open(qa_fp, 'rb') as fh: + qa_out = pickle.load(fh) + assert 'lr_100m' in qa_out + assert 'hr_100m' in qa_out + for key in qa_out: + assert 'ramp_rate_1' in qa_out[key] + assert 'velocity_grad' in qa_out[key] + assert 'tke_avg_k' in qa_out[key]
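
Usage sketch (illustrative only, not part of the patch): after this refactor, handler-specific options (target, shape, temporal_slice, cache_pattern, overwrite_cache, time_chunk_size, and the extract/compute/load/ti worker counts) are passed to ForwardPassStrategy through input_handler_kwargs rather than as top-level arguments, as the updated tests above show. All file paths and values below are hypothetical placeholders.

    from sup3r.pipeline.forward_pass import ForwardPass, ForwardPassStrategy

    # handler options now travel in a single dict
    input_handler_kwargs = dict(target=(19.3, -123.5), shape=(8, 8),
                                temporal_slice=slice(None),
                                cache_pattern='./cache_{feature}.pkl',
                                overwrite_cache=True)

    strategy = ForwardPassStrategy(
        './input_*.nc',                      # placeholder source files
        model_kwargs={'model_dir': './st_gan'},
        fwp_chunk_shape=(100, 100, 100),
        spatial_pad=1, temporal_pad=1,
        input_handler_kwargs=input_handler_kwargs,
        out_pattern='./out_{file_id}.h5',
        max_workers=1)

    # chunk-specific slices, padding, and file names are now exposed as
    # ForwardPass properties (ti_pad_slice, lr_pad_slice, out_file, ...)
    ForwardPass(strategy).run_chunk()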