
Commit

Merge pull request #98 from NREL/bnb/dev
Bnb/dev
bnb32 committed Oct 6, 2022
2 parents c1b023e + a8d3995 commit ec5f8b5
Showing 13 changed files with 348 additions and 379 deletions.
302 changes: 155 additions & 147 deletions sup3r/pipeline/forward_pass.py

Large diffs are not rendered by default.

83 changes: 1 addition & 82 deletions sup3r/pipeline/pipeline.py
@@ -3,16 +3,12 @@
Sup3r data pipeline architecture.
"""
import logging
import os
import json

from reV.pipeline.pipeline import Pipeline
from rex.utilities.loggers import init_logger, create_dirs
from rex.utilities.loggers import init_logger

from sup3r.pipeline.config import Sup3rPipelineConfig
from sup3r.utilities import ModuleName
from sup3r.models.base import Sup3rGan
from sup3r.postprocessing.file_handling import OutputHandlerH5

logger = logging.getLogger(__name__)

@@ -48,80 +44,3 @@ def __init__(self, pipeline, monitor=True, verbose=False):
if 'logging' in self._config:
init_logger('sup3r.pipeline', **self._config.logging)
init_logger('reV.pipeline', **self._config.logging)

@classmethod
def init_pass_collect(cls, out_dir, file_paths, model_path,
fwp_kwargs=None, dc_kwargs=None):
"""Generate config files for forward pass and collection
Parameters
----------
out_dir : str
Parent directory for pipeline run.
file_paths : str | list
A single source h5 wind file to extract raster data from or a list
of netcdf files with identical grid. The string can be a unix-style
file path which will be passed through glob.glob
model_path : str
Path to gan model for forward pass
fwp_kwargs : dict
Dictionary of keyword args passed to the ForwardPassStrategy class.
dc_kwargs : dict
Dictionary of keyword args passed to the Collection.collect()
method.
"""
fwp_kwargs = fwp_kwargs or {}
dc_kwargs = dc_kwargs or {}
logger.info('Generating config files for forward pass and data '
'collection')
log_dir = os.path.join(out_dir, 'logs/')
output_dir = os.path.join(out_dir, 'output/')
cache_dir = os.path.join(out_dir, 'cache/')
std_out_dir = os.path.join(out_dir, 'stdout/')
all_dirs = [out_dir, log_dir, cache_dir, output_dir, std_out_dir]
for d in all_dirs:
create_dirs(d)
out_pattern = os.path.join(output_dir, 'fwp_out_{file_id}.h5')
cache_pattern = os.path.join(cache_dir, 'cache_{feature}.pkl')
log_pattern = os.path.join(log_dir, 'log_{node_index}.log')
model_params = Sup3rGan.load_saved_params(model_path,
verbose=False)['meta']
features = model_params['output_features']
features = OutputHandlerH5.get_renamed_features(features)
fwp_config = {'file_paths': file_paths,
'model_args': model_path,
'out_pattern': out_pattern,
'cache_pattern': cache_pattern,
'log_pattern': log_pattern}
fwp_config.update(fwp_kwargs)
fwp_config_file = os.path.join(out_dir, 'config_fwp.json')
with open(fwp_config_file, 'w') as f:
json.dump(fwp_config, f, sort_keys=True, indent=4)
logger.info(f'Saved forward-pass config file: {fwp_config_file}.')

collect_file = os.path.join(output_dir, 'out_collection.h5')
log_file = os.path.join(log_dir, 'collect.log')
input_files = os.path.join(output_dir, 'fwp_out_*.h5')
col_config = {'file_paths': input_files,
'out_file': collect_file,
'features': features,
'log_file': log_file}
col_config.update(dc_kwargs)
col_config_file = os.path.join(out_dir, 'config_collect.json')
with open(col_config_file, 'w') as f:
json.dump(col_config, f, sort_keys=True, indent=4)
logger.info(f'Saved data-collect config file: {col_config_file}.')

pipe_config = {'logging': {'log_level': 'DEBUG'},
'pipeline': [{'forward-pass': fwp_config_file},
{'data-collect': col_config_file}]}
pipeline_file = os.path.join(out_dir, 'config_pipeline.json')
with open(pipeline_file, 'w') as f:
json.dump(pipe_config, f, sort_keys=True, indent=4)
logger.info(f'Saved pipeline config file: {pipeline_file}.')

script_file = os.path.join(out_dir, 'run.sh')
with open(script_file, 'w') as f:
cmd = 'python -m sup3r.cli -c config_pipeline.json pipeline'
f.write(cmd)
logger.info(f'Saved script file: {script_file}.')
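
With init_pass_collect removed, the forward-pass, data-collect, and pipeline config files it used to generate have to be supplied directly. A minimal sketch of the top-level pipeline config it emitted, reconstructed from the removed code above (file names are illustrative; config_fwp.json and config_collect.json are assumed to exist alongside it):

    import json

    # Top-level pipeline config formerly written to config_pipeline.json
    pipe_config = {'logging': {'log_level': 'DEBUG'},
                   'pipeline': [{'forward-pass': 'config_fwp.json'},
                                {'data-collect': 'config_collect.json'}]}
    with open('config_pipeline.json', 'w') as f:
        json.dump(pipe_config, f, sort_keys=True, indent=4)

    # Run the pipeline the same way the removed run.sh script did:
    # python -m sup3r.cli -c config_pipeline.json pipeline
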
77 changes: 38 additions & 39 deletions sup3r/postprocessing/collection.py
@@ -326,9 +326,7 @@ def _get_collection_attrs(cls, file_paths, feature, sort=True,
----------
file_paths : list | str
Explicit list of str file paths that will be sorted and collected
or a single string with unix-style /search/patt*ern.h5. Files
should have non-overlapping time_index dataset and fully
overlapping meta dataset.
or a single string with unix-style /search/patt*ern.h5.
feature : str
Dataset name to collect.
sort : bool
@@ -341,7 +339,8 @@ def _get_collection_attrs(cls, file_paths, feature, sort=True,
None will use all available workers.
target_final_meta_file : str
Path to target final meta containing coordinates to keep from the
full file list collected meta
full list of coordinates present in the collected meta for the full
file list.
Returns
-------
@@ -466,19 +465,22 @@ def _ensure_dset_in_output(out_file, dset, data=None):
attrs=attrs, data=data,
chunks=attrs.get('chunks', None))

def _collect_flist(self, feature, masked_meta, time_index, shape,
file_paths, out_file, target_final_meta,
masked_target_meta, max_workers=None):
def _collect_flist(self, feature, subset_masked_meta, time_index, shape,
file_paths, out_file, target_masked_meta,
max_workers=None):
"""Collect a dataset from a file list without getting attributes first.
This file list can be a subset of a full file list to be collected.
Parameters
----------
feature : str
Dataset name to collect.
masked_meta : pd.DataFrame
Concatenated meta data for the given file paths. This masked
against the target_final_meta.
subset_masked_meta : pd.DataFrame
Meta data containing the list of coordinates present in both the
given file paths and the target_final_meta. This can be a subset of
the coordinates present in the full file list. The coordinates
contained in this dataframe have the same gids as those present in
the meta for the full file list.
time_index : pd.datetimeindex
Concatenated datetime index for the given file paths.
shape : tuple
@@ -488,17 +490,14 @@ def _collect_flist(self, feature, masked_meta, time_index, shape,
to be collected.
out_file : str
File path of final output file.
target_final_meta : str
Target final meta containing coordinates to keep from the
full file list collected meta
masked_target_meta : pd.DataFrame
Collected meta data with mask applied from target_final_meta so
original gids are preserved.
target_masked_meta : pd.DataFrame
Same as subset_masked_meta but instead for the entire list of files
to be collected.
max_workers : int | None
Number of workers to use in parallel. 1 runs serial,
None uses all available.
"""
if len(masked_meta) > 0:
if len(subset_masked_meta) > 0:
attrs, final_dtype = get_dset_attrs(feature)
scale_factor = attrs.get('scale_factor', 1)

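As a reading aid for the subset_masked_meta / target_masked_meta parameters documented above, a hypothetical pandas sketch of the masking idea: keep only coordinates present in the target meta while preserving the gids (row indices) of the full collected meta. The column names and values are made up for illustration and do not reflect the actual implementation.

    import pandas as pd

    # Hypothetical full collected meta (gids 0..2) and a target meta that
    # contains a subset of its coordinates
    full_meta = pd.DataFrame({'latitude': [40.0, 40.1, 40.2],
                              'longitude': [-105.0, -105.1, -105.2]})
    target_meta = pd.DataFrame({'latitude': [40.0, 40.2],
                                'longitude': [-105.0, -105.2]})

    # Mask the full meta against the target coordinates; the surviving rows
    # keep their original gids (0 and 2)
    merged = full_meta.merge(target_meta, on=['latitude', 'longitude'],
                             how='left', indicator=True)
    masked_meta = full_meta[(merged['_merge'] == 'both').values]
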
@@ -517,8 +516,9 @@ def _collect_flist(self, feature, masked_meta, time_index, shape,
for i, fname in enumerate(file_paths):
logger.debug('Collecting data from file {} out of {}.'
.format(i + 1, len(file_paths)))
self.get_data(fname, feature, time_index, masked_meta,
scale_factor, final_dtype)
self.get_data(fname, feature, time_index,
subset_masked_meta, scale_factor,
final_dtype)
else:
logger.info('Running parallel collection on {} workers.'
.format(max_workers))
@@ -528,7 +528,7 @@ def _collect_flist(self, feature, masked_meta, time_index, shape,
with ThreadPoolExecutor(max_workers=max_workers) as exe:
for fname in file_paths:
future = exe.submit(self.get_data, fname, feature,
time_index, masked_meta,
time_index, subset_masked_meta,
scale_factor, final_dtype)
futures[future] = fname
for future in as_completed(futures):
@@ -547,15 +547,11 @@ def _collect_flist(self, feature, masked_meta, time_index, shape,
msg += f'{futures[future]}'
logger.exception(msg)
raise RuntimeError(msg) from e
if not os.path.exists(out_file):
Collector._init_collected_h5(out_file, time_index,
target_final_meta)
x_write_slice, y_write_slice = slice(None), slice(None)
else:
with RexOutputs(out_file, 'r') as f:
target_ti = f.time_index
y_write_slice, x_write_slice = Collector.get_slices(
target_ti, masked_target_meta, time_index, masked_meta)
with RexOutputs(out_file, 'r') as f:
target_ti = f.time_index
y_write_slice, x_write_slice = Collector.get_slices(
target_ti, target_masked_meta, time_index,
subset_masked_meta)
Collector._ensure_dset_in_output(out_file, feature)
with RexOutputs(out_file, mode='a') as f:
f[feature, y_write_slice, x_write_slice] = self.data
@@ -641,9 +637,7 @@ def collect(cls, file_paths, out_file, features, max_workers=None,
----------
file_paths : list | str
Explicit list of str file paths that will be sorted and collected
or a single string with unix-style /search/patt*ern.h5. Files
should have non-overlapping time_index dataset and fully
overlapping meta dataset.
or a single string with unix-style /search/patt*ern.h5.
out_file : str
File path of final output file.
features : list
@@ -666,7 +660,13 @@ def collect(cls, file_paths, out_file, features, max_workers=None,
a suffix format _{temporal_chunk_index}_{spatial_chunk_index}.h5
target_final_meta_file : str
Path to target final meta containing coordinates to keep from the
full file list collected meta
full file list collected meta. This can be but is not necessarily a
subset of the full list of coordinates for all files in the file
list. This is used to remove coordinates from the full file list
which are not present in the target_final_meta. Either this full
meta or a subset, depending on which coordinates are present in
the data to be collected, will be the final meta for the collected
output files.
n_writes : int | None
Number of writes to split full file list into. Must be less than
or equal to the number of temporal chunks.
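
A minimal usage sketch of the collection entry point described above, assuming the Collector class from sup3r.postprocessing.collection; the file pattern, feature names, and meta path are hypothetical:

    from sup3r.postprocessing.collection import Collector

    # Collect chunked forward-pass outputs into a single h5 file, keeping
    # only coordinates listed in the (hypothetical) target final meta file
    Collector.collect(file_paths='./output/fwp_out_*.h5',
                      out_file='./output/collected.h5',
                      features=['windspeed_100m', 'winddirection_100m'],
                      max_workers=None,
                      target_final_meta_file='./target_final_meta.csv')
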
@@ -696,17 +696,17 @@ def collect(cls, file_paths, out_file, features, max_workers=None,
out = collector._get_collection_attrs(
collector.flist, dset, max_workers=max_workers,
target_final_meta_file=target_final_meta_file)
time_index, final_target_meta, masked_target_meta = out[:3]
time_index, target_final_meta, target_masked_meta = out[:3]
shape, _, global_attrs = out[3:]

if not os.path.exists(out_file):
collector._init_collected_h5(out_file, time_index,
final_target_meta, global_attrs)
target_final_meta, global_attrs)

if len(flist_chunks) == 1:
collector._collect_flist(dset, masked_target_meta, time_index,
collector._collect_flist(dset, target_masked_meta, time_index,
shape, flist_chunks[0], out_file,
final_target_meta, masked_target_meta,
target_masked_meta,
max_workers=max_workers)

else:
@@ -719,8 +719,7 @@ def collect(cls, file_paths, out_file, features, max_workers=None,
target_final_meta_file=target_final_meta_file)
collector._collect_flist(dset, masked_meta, time_index,
shape, flist, out_file,
target_final_meta,
masked_target_meta,
target_masked_meta,
max_workers=max_workers)

if write_status and job_name is not None:
2 changes: 1 addition & 1 deletion sup3r/postprocessing/file_handling.py
@@ -532,5 +532,5 @@ def _write_output(cls, data, features, lat_lon, times, out_file,

if meta_data is not None:
fh.run_attrs = {'gan_meta': json.dumps(meta_data)}
os.rename(tmp_file, out_file)
os.replace(tmp_file, out_file)
logger.info(f'Saved output of size {data.shape} to: {out_file}')
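
The switch from os.rename to os.replace here (and in data_handling.py below) makes the final move of the temp file an atomic overwrite that also succeeds on Windows when the destination already exists. A small standalone sketch with hypothetical file names:

    import os

    tmp_file, out_file = 'output.h5.tmp', 'output.h5'
    with open(tmp_file, 'wb') as fh:
        fh.write(b'...')  # write results to a temp file first
    # os.rename raises on Windows if out_file already exists; os.replace
    # overwrites it atomically on both POSIX and Windows
    os.replace(tmp_file, out_file)
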
2 changes: 1 addition & 1 deletion sup3r/preprocessing/data_handling.py
@@ -1285,7 +1285,7 @@ def cache_data(self, cache_file_paths):
tmp_file = fp.replace('.pkl', '.pkl.tmp')
with open(tmp_file, 'wb') as fh:
pickle.dump(self.data[..., i], fh, protocol=4)
os.rename(tmp_file, fp)
os.replace(tmp_file, fp)
else:
msg = (f'Called cache_data but {fp} already exists. Set to '
'overwrite_cache to True to overwrite.')
20 changes: 10 additions & 10 deletions sup3r/qa/stats.py
@@ -138,11 +138,11 @@ def __init__(self, source_file_paths, out_file_path, s_enhance, t_enhance,
get_lr : bool
Whether to include low resolution stats in output
include_stats : list | None
List of stats to include in output. e.g. ['ws_ramp_rate',
List of stats to include in output. e.g. ['ramp_rate',
'velocity_grad', 'vorticity', 'tke_avg_k', 'tke_avg_f', 'tke_ts']
max_values : dict | None
Dictionary of max values to keep for stats. e.g.
{'ws_ramp_rate_max': 10, 'v_grad_max': 14, 'vorticity_max': 7}
{'ramp_rate_max': 10, 'v_grad_max': 14, 'vorticity_max': 7}
ramp_rate_t_step : int | list
Number of time steps to use for ramp rate calculation. If low res
data is hourly then t_step=1 will calculate the hourly ramp rate.
@@ -168,16 +168,16 @@ def __init__(self, source_file_paths, out_file_path, s_enhance, t_enhance,
ti_workers = max_workers

self.max_values = max_values or {}
self.ws_ramp_rate_max = self.max_values.get('ws_ramp_rate_max', 10)
self.ramp_rate_max = self.max_values.get('ramp_rate_max', 10)
self.v_grad_max = self.max_values.get('v_grad_max', 7)
self.vorticity_max = self.max_values.get('vorticity_max', 14)
self.ramp_rate_t_step = (ramp_rate_t_step
if isinstance(ramp_rate_t_step, list)
else [ramp_rate_t_step])
self.include_stats = include_stats or ['ws_ramp_rate', 'velocity_grad',
self.include_stats = include_stats or ['ramp_rate', 'velocity_grad',
'vorticity', 'tke_avg_k',
'tke_avg_f', 'tke_ts',
'mean_ws_ramp_rate']
'mean_ramp_rate']

self.s_enhance = s_enhance
self.t_enhance = t_enhance
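
A short sketch of the renamed QA-stats options after this change ('ws_ramp_rate' becomes 'ramp_rate', 'ws_ramp_rate_max' becomes 'ramp_rate_max'); the values below mirror the defaults above, and the enclosing constructor is not shown in this diff:

    # Keyword values using the renamed keys (illustrative only)
    include_stats = ['ramp_rate', 'velocity_grad', 'vorticity',
                     'tke_avg_k', 'tke_avg_f', 'tke_ts', 'mean_ramp_rate']
    max_values = {'ramp_rate_max': 10, 'v_grad_max': 7, 'vorticity_max': 14}
    ramp_rate_t_step = [1, 3]  # e.g. hourly and 3-hourly ramp rates
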
@@ -761,18 +761,18 @@ def get_ramp_rate_stats(self, u, v, scale=1):
"""

stats_dict = {}
if 'ws_ramp_rate' in self.include_stats:
if 'ramp_rate' in self.include_stats:
for i, time in enumerate(self.ramp_rate_t_step):
logger.info('Computing ramp rate pdf.')
out = ws_ramp_rate_dist(u, v, diff_max=self.ws_ramp_rate_max,
out = ws_ramp_rate_dist(u, v, diff_max=self.ramp_rate_max,
t_steps=time, scale=scale)
stats_dict[f'ramp_rate_{self.ramp_rate_t_step[i]}'] = out
if 'mean_ws_ramp_rate' in self.include_stats:
if 'mean_ramp_rate' in self.include_stats:
for i, time in enumerate(self.ramp_rate_t_step):
logger.info('Computing mean ramp rate pdf.')
out = ws_ramp_rate_dist(np.mean(u, axis=(0, 1)),
np.mean(v, axis=(0, 1)),
diff_max=self.ws_ramp_rate_max,
diff_max=self.ramp_rate_max,
t_steps=time, scale=scale)
stats_dict[f'mean_ramp_rate_{self.ramp_rate_t_step[i]}'] = out
return stats_dict
@@ -879,7 +879,7 @@ def run(self):
stats : dict
Dictionary of statistics, where keys are lr/hr/interp appended with
the height of the corresponding wind field. Values are dictionaries
of statistics, such as velocity_gradient, vorticity, ws_ramp_rate,
of statistics, such as velocity_gradient, vorticity, ramp_rate,
etc
"""

2 changes: 1 addition & 1 deletion sup3r/qa/utilities.py
@@ -205,7 +205,7 @@ def ws_ramp_rate_dist(u, v, bins=50, range=None, diff_max=10, t_steps=1,
msg = (f'Received t_steps={t_steps} for ramp rate calculation but data '
f'only has {u.shape[-1]} time steps')
assert t_steps < u.shape[-1], msg
ws = np.sqrt(u**2 + v**2)
ws = np.hypot(u, v)
diffs = (ws[..., t_steps:] - ws[..., :-t_steps]).flatten()
diffs /= scale
diffs = diffs[(np.abs(diffs) < diff_max)]
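
np.hypot(u, v) computes the same windspeed as np.sqrt(u**2 + v**2) but avoids overflow and underflow in the intermediate squares; a quick check:

    import numpy as np

    u = np.array([3.0, 1e200])
    v = np.array([4.0, 1e200])
    print(np.hypot(u, v))        # [5.00000000e+000 1.41421356e+200]
    print(np.sqrt(u**2 + v**2))  # [5. inf] -- squaring 1e200 overflows
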