Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bnb/dev #98

Merged
merged 10 commits into from
Oct 6, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
cleaning up fwp strat args
  • Loading branch information
bnb32 committed Oct 5, 2022
commit fcba27d661c8695c1ece5e52f3f73f7c24f84c04
126 changes: 41 additions & 85 deletions sup3r/pipeline/forward_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,24 +541,13 @@ class ForwardPassStrategy(InputMixIn):

def __init__(self, file_paths, model_kwargs, fwp_chunk_shape,
spatial_pad, temporal_pad,
temporal_slice=slice(None),
model_class='Sup3rGan',
target=None, shape=None,
raster_file=None,
time_chunk_size=None,
cache_pattern=None,
out_pattern=None,
overwrite_cache=False,
overwrite_ti_cache=False,
input_handler=None,
input_handler_kwargs=None,
incremental=True,
max_workers=None,
extract_workers=None,
compute_workers=None,
load_workers=None,
output_workers=None,
ti_workers=None,
exo_kwargs=None,
pass_workers=1,
bias_correct_method=None,
Expand Down Expand Up @@ -601,47 +590,10 @@ def __init__(self, file_paths, model_kwargs, fwp_chunk_shape,
passes for subsequent temporal stitching. This overlap will pad
both sides of the fwp_chunk_shape. Note that the first and last
chunks in the temporal dimension will not be padded.
temporal_slice : slice | tuple | list
Slice defining size of full temporal domain. e.g. If we have 5
files each with 5 time steps then temporal_slice = slice(None) will
select all 25 time steps. This can also be a tuple / list with
length 3 that will be interpreted as slice(*temporal_slice)
model_class : str
Name of the sup3r model class for the GAN model to load. The
default is the basic spatial / spatiotemporal Sup3rGan model. This
will be loaded from sup3r.models
target : tuple
(lat, lon) lower left corner of raster. You should provide
target+shape or raster_file, or if all three are None the full
source domain will be used.
shape : tuple
(rows, cols) grid size. You should provide target+shape or
raster_file, or if all three are None the full source domain will
be used.
raster_file : str | None
File for raster_index array for the corresponding target and
shape. If specified the raster_index will be loaded from the file
if it exists or written to the file if it does not yet exist.
If None raster_index will be calculated directly. You should
provide target+shape or raster_file, or if all three are None the
full source domain will be used.
time_chunk_size : int
Size of chunks to split time dimension into for parallel data
extraction. If running in serial this can be set to the size
of the full time index for best performance.
cache_pattern : str | None
Pattern for files for saving feature data. e.g.
file_path_{feature}.pkl Each feature will be saved to a file with
the feature name replaced in cache_pattern. If not None
feature arrays will be saved here and not stored in self.data until
load_cached_data is called. The cache_pattern can also include
{shape}, {target}, {times} which will help ensure unique cache
files for complex problems.
overwrite_cache : bool
Whether to overwrite cache files storing the computed/extracted
feature data
overwrite_ti_cache : bool
Whether to overwrite time index cache files
out_pattern : str
Output file pattern. Must be of form <path>/<name>_{file_id}.<ext>.
e.g. /tmp/sup3r_job_{file_id}.h5
Expand All @@ -654,9 +606,10 @@ def __init__(self, file_paths, model_kwargs, fwp_chunk_shape,
match a class in data_handling.py. If None the correct handler will
be guessed based on file type and time series properties.
input_handler_kwargs : dict | None
Optional kwargs for initializing the input_handler class. For
example, this could be {'hr_spatial_coarsen': 2} if you wanted to
artificially coarsen the input data for testing.
kwargs for initializing the input_handler class
:class:`sup3r.preprocessing.data_handling.DataHandler`. These
kwargs include temporal_slice, target, shape used to define the
spatiotemporal domain sent through the forward passes.
incremental : bool
Allow the forward pass iteration to skip spatiotemporal chunks that
already have an output file (True, default) or iterate through all
Expand All @@ -667,38 +620,24 @@ def __init__(self, file_paths, model_kwargs, fwp_chunk_shape,
If max_workers == 1 then all processes will be serialized. If None
extract_workers, compute_workers, load_workers, output_workers will
use their own provided values.
extract_workers : int | None
max number of workers to use for extracting features from source
data.
compute_workers : int | None
max number of workers to use for computing derived features from
raw features in source data.
load_workers : int | None
max number of workers to use for loading cached feature data.
output_workers : int | None
max number of workers to use for writing forward pass output.
pass_workers : int | None
max number of workers to use for performing forward passes on a
single node. If 1 then all forward passes on chunks distributed to
a single node will be run in serial.
ti_workers : int | None
max number of workers to use to get full time index. Useful when
there are many input files each with a single time step. If this is
greater than one, time indices for input files will be extracted in
parallel and then concatenated to get the full time index. If input
files do not all have time indices or if there are few input files
this should be set to one.
exo_kwargs : dict | None
Dictionary of args to pass to ExogenousDataHandler for extracting
exogenous features such as topography for future multistep foward
pass
Dictionary of args to pass to
:class:`sup3r.preprocessing.exogeneous_data_handling.ExogenousDataHandler`
for extracting exogenous features such as topography for future
multistep forward pass
bias_correct_method : str | None
Optional bias correction function name that can be imported from
the sup3r.bias.bias_transforms module. This will transform the
source data according to some predefined bias correction
the :mod:`sup3r.bias.bias_transforms` module. This will transform
bnb32 marked this conversation as resolved.
Show resolved Hide resolved
the source data according to some predefined bias correction
transformation along with the bias_correct_kwargs. As the first
argument, this method must receive a generic numpy array of data to
be bias corrected
argument, this method must receive a generic numpy array of data
to be bias corrected
bias_correct_kwargs : dict | None
Optional namespace of kwargs to provide to bias_correct_method.
If this is provided, it must be a dictionary where each key is a
Expand All @@ -717,28 +656,16 @@ def __init__(self, file_paths, model_kwargs, fwp_chunk_shape,
self.temporal_pad = temporal_pad
self.model_class = model_class
self.out_pattern = out_pattern
self.raster_file = raster_file
self.temporal_slice = temporal_slice
self.time_chunk_size = time_chunk_size
self.overwrite_cache = overwrite_cache
self.overwrite_ti_cache = overwrite_ti_cache
self.max_workers = max_workers
self.extract_workers = extract_workers
self.compute_workers = compute_workers
self.load_workers = load_workers
self.output_workers = output_workers
self.pass_workers = pass_workers
self.exo_kwargs = exo_kwargs or {}
self.incremental = incremental
self.ti_workers = ti_workers
self._single_time_step_files = None
self._cache_pattern = cache_pattern
self._input_handler_class = None
self._input_handler_name = input_handler
self._max_nodes = max_nodes
self._input_handler_kwargs = input_handler_kwargs or {}
self._grid_shape = shape
self._target = target
self._time_index = None
self._raw_time_index = None
self._out_files = None
Expand All @@ -749,6 +676,7 @@ def __init__(self, file_paths, model_kwargs, fwp_chunk_shape,
self.bias_correct_method = bias_correct_method
self.bias_correct_kwargs = bias_correct_kwargs or {}

self.get_input_handler_kwargs(self._input_handler_kwargs)
self.cap_worker_args(max_workers)

model_class = getattr(sup3r.models, self.model_class, None)
Expand Down Expand Up @@ -794,6 +722,34 @@ def __init__(self, file_paths, model_kwargs, fwp_chunk_shape,

self.preflight()

def get_input_handler_kwargs(self, input_handler_kwargs):
    """Unpack input-handler args from input_handler_kwargs onto this
    strategy instance as attributes.

    Parameters
    ----------
    input_handler_kwargs : dict | None
        Dictionary of kwargs to pass to the
        :class:`sup3r.preprocessing.data_handling.DataHandler`.
        Recognized keys: target, shape, raster_file, temporal_slice,
        time_chunk_size, overwrite_cache, overwrite_ti_cache,
        extract_workers, compute_workers, load_workers, ti_workers,
        cache_pattern. Missing keys fall back to the defaults listed
        below; None is treated as an empty dict.
    """
    input_handler_kwargs = input_handler_kwargs or {}
    # (attribute name, kwargs key, default) triples. A declarative
    # mapping avoids a long run of near-identical .get() assignments
    # and keeps key/default pairs easy to audit in one place.
    spec = (
        ('_target', 'target', None),
        ('_grid_shape', 'shape', None),
        ('raster_file', 'raster_file', None),
        ('temporal_slice', 'temporal_slice', slice(None)),
        ('time_chunk_size', 'time_chunk_size', None),
        ('overwrite_cache', 'overwrite_cache', False),
        ('overwrite_ti_cache', 'overwrite_ti_cache', False),
        ('extract_workers', 'extract_workers', None),
        ('compute_workers', 'compute_workers', None),
        ('load_workers', 'load_workers', None),
        ('ti_workers', 'ti_workers', None),
        ('_cache_pattern', 'cache_pattern', None),
    )
    for attr, key, default in spec:
        setattr(self, attr, input_handler_kwargs.get(key, default))
bnb32 marked this conversation as resolved.
Show resolved Hide resolved

@property
def worker_attrs(self):
"""Get all worker args defined in init"""
Expand Down
83 changes: 1 addition & 82 deletions sup3r/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,12 @@
Sup3r data pipeline architecture.
"""
import logging
import os
import json

from reV.pipeline.pipeline import Pipeline
from rex.utilities.loggers import init_logger, create_dirs
from rex.utilities.loggers import init_logger

from sup3r.pipeline.config import Sup3rPipelineConfig
from sup3r.utilities import ModuleName
from sup3r.models.base import Sup3rGan
from sup3r.postprocessing.file_handling import OutputHandlerH5

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -48,80 +44,3 @@ def __init__(self, pipeline, monitor=True, verbose=False):
if 'logging' in self._config:
init_logger('sup3r.pipeline', **self._config.logging)
init_logger('reV.pipeline', **self._config.logging)

@classmethod
def init_pass_collect(cls, out_dir, file_paths, model_path,
fwp_kwargs=None, dc_kwargs=None):
"""Generate config files for forward pass and collection
bnb32 marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
out_dir : str
Parent directory for pipeline run.
file_paths : str | list
A single source h5 wind file to extract raster data from or a list
of netcdf files with identical grid. The string can be a unix-style
file path which will be passed through glob.glob
model_path : str
Path to gan model for forward pass
fwp_kwargs : dict
Dictionary of keyword args passed to the ForwardPassStrategy class.
dc_kwargs : dict
Dictionary of keyword args passed to the Collection.collect()
method.
"""
fwp_kwargs = fwp_kwargs or {}
dc_kwargs = dc_kwargs or {}
logger.info('Generating config files for forward pass and data '
'collection')
log_dir = os.path.join(out_dir, 'logs/')
output_dir = os.path.join(out_dir, 'output/')
cache_dir = os.path.join(out_dir, 'cache/')
std_out_dir = os.path.join(out_dir, 'stdout/')
all_dirs = [out_dir, log_dir, cache_dir, output_dir, std_out_dir]
for d in all_dirs:
create_dirs(d)
out_pattern = os.path.join(output_dir, 'fwp_out_{file_id}.h5')
cache_pattern = os.path.join(cache_dir, 'cache_{feature}.pkl')
log_pattern = os.path.join(log_dir, 'log_{node_index}.log')
model_params = Sup3rGan.load_saved_params(model_path,
verbose=False)['meta']
features = model_params['output_features']
features = OutputHandlerH5.get_renamed_features(features)
fwp_config = {'file_paths': file_paths,
'model_args': model_path,
'out_pattern': out_pattern,
'cache_pattern': cache_pattern,
'log_pattern': log_pattern}
fwp_config.update(fwp_kwargs)
fwp_config_file = os.path.join(out_dir, 'config_fwp.json')
with open(fwp_config_file, 'w') as f:
json.dump(fwp_config, f, sort_keys=True, indent=4)
logger.info(f'Saved forward-pass config file: {fwp_config_file}.')

collect_file = os.path.join(output_dir, 'out_collection.h5')
log_file = os.path.join(log_dir, 'collect.log')
input_files = os.path.join(output_dir, 'fwp_out_*.h5')
col_config = {'file_paths': input_files,
'out_file': collect_file,
'features': features,
'log_file': log_file}
col_config.update(dc_kwargs)
col_config_file = os.path.join(out_dir, 'config_collect.json')
with open(col_config_file, 'w') as f:
json.dump(col_config, f, sort_keys=True, indent=4)
logger.info(f'Saved data-collect config file: {col_config_file}.')

pipe_config = {'logging': {'log_level': 'DEBUG'},
'pipeline': [{'forward-pass': fwp_config_file},
{'data-collect': col_config_file}]}
pipeline_file = os.path.join(out_dir, 'config_pipeline.json')
with open(pipeline_file, 'w') as f:
json.dump(pipe_config, f, sort_keys=True, indent=4)
logger.info(f'Saved pipeline config file: {pipeline_file}.')

script_file = os.path.join(out_dir, 'run.sh')
with open(script_file, 'w') as f:
cmd = 'python -m sup3r.cli -c config_pipeline.json pipeline'
f.write(cmd)
logger.info(f'Saved script file: {script_file}.')
18 changes: 9 additions & 9 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,18 +236,18 @@ def test_fwd_pass_cli(runner):
log_prefix = os.path.join(td, 'log.log')
config = {'ti_workers': 1,
'file_paths': input_files,
'target': (19.3, -123.5),
'model_kwargs': {'model_dir': out_dir},
'out_pattern': out_files,
'cache_pattern': cache_pattern,
'log_pattern': log_prefix,
'shape': shape,
'input_handler_kwargs': {'target': (19.3, -123.5),
'cache_pattern': cache_pattern,
'shape': shape,
'overwrite_cache': False,
'time_chunk_size': 10},
'fwp_chunk_shape': fwp_chunk_shape,
'time_chunk_size': 10,
'max_workers': 1,
'spatial_pad': 5,
'temporal_pad': 5,
'overwrite_cache': False,
'execution_control': {
"option": "local"}}

Expand Down Expand Up @@ -322,18 +322,18 @@ def test_pipeline_fwp_qa(runner):
model.save(out_dir)

fwp_config = {'file_paths': input_files,
'target': (19.3, -123.5),
'model_kwargs': {'model_dir': out_dir},
'out_pattern': os.path.join(td, 'out_{file_id}.h5'),
'log_pattern': os.path.join(td, 'fwp_log.log'),
'log_level': 'DEBUG',
'shape': (8, 8),
'input_handler_kwargs': {'target': (19.3, -123.5),
'shape': (8, 8),
'time_chunk_size': 10,
'overwrite_cache': False},
'fwp_chunk_shape': (100, 100, 100),
'time_chunk_size': 10,
'max_workers': 1,
'spatial_pad': 5,
'temporal_pad': 5,
'overwrite_cache': False,
'execution_control': {
"option": "local"}}

Expand Down
Loading