Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bnb/dev #98

Merged
merged 10 commits into from
Oct 6, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
test fix
  • Loading branch information
bnb32 committed Oct 5, 2022
commit fa436b140de0716ae418210b6733bbf10a9067a4
113 changes: 85 additions & 28 deletions sup3r/pipeline/forward_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,10 +590,47 @@ def __init__(self, file_paths, model_kwargs, fwp_chunk_shape,
passes for subsequent temporal stitching. This overlap will pad
both sides of the fwp_chunk_shape. Note that the first and last
chunks in the temporal dimension will not be padded.
temporal_slice : slice | tuple | list
Slice defining size of full temporal domain. e.g. If we have 5
files each with 5 time steps then temporal_slice = slice(None) will
select all 25 time steps. This can also be a tuple / list with
length 3 that will be interpreted as slice(*temporal_slice)
model_class : str
Name of the sup3r model class for the GAN model to load. The
default is the basic spatial / spatiotemporal Sup3rGan model. This
will be loaded from sup3r.models
target : tuple
(lat, lon) lower left corner of raster. You should provide
target+shape or raster_file, or if all three are None the full
source domain will be used.
shape : tuple
(rows, cols) grid size. You should provide target+shape or
raster_file, or if all three are None the full source domain will
be used.
raster_file : str | None
File for raster_index array for the corresponding target and
shape. If specified the raster_index will be loaded from the file
if it exists or written to the file if it does not yet exist.
If None raster_index will be calculated directly. You should
provide target+shape or raster_file, or if all three are None the
full source domain will be used.
time_chunk_size : int
Size of chunks to split time dimension into for parallel data
extraction. If running in serial this can be set to the size
of the full time index for best performance.
cache_pattern : str | None
Pattern for files for saving feature data. e.g.
file_path_{feature}.pkl Each feature will be saved to a file with
the feature name replaced in cache_pattern. If not None
feature arrays will be saved here and not stored in self.data until
load_cached_data is called. The cache_pattern can also include
{shape}, {target}, {times} which will help ensure unique cache
files for complex problems.
overwrite_cache : bool
Whether to overwrite cache files storing the computed/extracted
feature data
overwrite_ti_cache : bool
Whether to overwrite time index cache files
out_pattern : str
Output file pattern. Must be of form <path>/<name>_{file_id}.<ext>.
e.g. /tmp/sup3r_job_{file_id}.h5
Expand All @@ -606,10 +643,9 @@ def __init__(self, file_paths, model_kwargs, fwp_chunk_shape,
match a class in data_handling.py. If None the correct handler will
be guessed based on file type and time series properties.
input_handler_kwargs : dict | None
kwargs for initializing the input_handler class
:class:`sup3r.preprocessing.data_handling.DataHandler`. These
kwargs include temporal_slice, target, shape used to define the
spatiotemporal domain sent through the forward passes.
Optional kwargs for initializing the input_handler class. For
example, this could be {'hr_spatial_coarsen': 2} if you wanted to
artificially coarsen the input data for testing.
incremental : bool
Allow the forward pass iteration to skip spatiotemporal chunks that
already have an output file (True, default) or iterate through all
Expand All @@ -620,24 +656,38 @@ def __init__(self, file_paths, model_kwargs, fwp_chunk_shape,
If max_workers == 1 then all processes will be serialized. If None
extract_workers, compute_workers, load_workers, output_workers will
use their own provided values.
extract_workers : int | None
max number of workers to use for extracting features from source
data.
compute_workers : int | None
max number of workers to use for computing derived features from
raw features in source data.
load_workers : int | None
max number of workers to use for loading cached feature data.
output_workers : int | None
max number of workers to use for writing forward pass output.
pass_workers : int | None
max number of workers to use for performing forward passes on a
single node. If 1 then all forward passes on chunks distributed to
a single node will be run in serial.
ti_workers : int | None
max number of workers to use to get full time index. Useful when
there are many input files each with a single time step. If this is
greater than one, time indices for input files will be extracted in
parallel and then concatenated to get the full time index. If input
files do not all have time indices or if there are few input files
this should be set to one.
exo_kwargs : dict | None
Dictionary of args to pass to :class:
`sup3r.preprocessing.exogeneous_data_handling.ExogenousDataHandler`
for extracting exogenous features such as topography for future
multistep foward pass
Dictionary of args to pass to ExogenousDataHandler for extracting
exogenous features such as topography for future multistep foward
pass
bias_correct_method : str | None
Optional bias correction function name that can be imported from
the :mod:`sup3r.bias.bias_transforms` module. This will transform
the source data according to some predefined bias correction
transformation along with the bias_correct_kwargs. As the first
argument, this method must receive a generic numpy array of data
to be bias corrected
argument, this method must receive a generic numpy array of data to
be bias corrected
bias_correct_kwargs : dict | None
Optional namespace of kwargs to provide to bias_correct_method.
If this is provided, it must be a dictionary where each key is a
Expand Down Expand Up @@ -676,25 +726,26 @@ def __init__(self, file_paths, model_kwargs, fwp_chunk_shape,
self.bias_correct_method = bias_correct_method
self.bias_correct_kwargs = bias_correct_kwargs or {}

self._target = input_handler_kwargs.get('target', None)
self._grid_shape = input_handler_kwargs.get('shape', None)
self.raster_file = input_handler_kwargs.get('raster_file', None)
self.temporal_slice = input_handler_kwargs.get('temporal_slice',
slice(None))
self.time_chunk_size = input_handler_kwargs.get('time_chunk_size',
None)
self.overwrite_cache = input_handler_kwargs.get('overwrite_cache',
False)
self.overwrite_ti_cache = input_handler_kwargs.get(
self._target = self._input_handler_kwargs.get('target', None)
self._grid_shape = self._input_handler_kwargs.get('shape', None)
self.raster_file = self._input_handler_kwargs.get('raster_file', None)
self.temporal_slice = self._input_handler_kwargs.get('temporal_slice',
slice(None))
self.time_chunk_size = self._input_handler_kwargs.get(
'time_chunk_size', None)
self.overwrite_cache = self._input_handler_kwargs.get(
'overwrite_cache', False)
self.overwrite_ti_cache = self._input_handler_kwargs.get(
'overwrite_ti_cache', False)
self.extract_workers = input_handler_kwargs.get('extract_workers',
None)
self.compute_workers = input_handler_kwargs.get('compute_workers',
None)
self.load_workers = input_handler_kwargs.get('load_workers', None)
self.ti_workers = input_handler_kwargs.get('ti_workers', None)
self._cache_pattern = input_handler_kwargs.get('cache_pattern', None)

self.extract_workers = self._input_handler_kwargs.get(
'extract_workers', None)
self.compute_workers = self._input_handler_kwargs.get(
'compute_workers', None)
self.load_workers = self._input_handler_kwargs.get('load_workers',
None)
self.ti_workers = self._input_handler_kwargs.get('ti_workers', None)
self._cache_pattern = self._input_handler_kwargs.get('cache_pattern',
None)
self.cap_worker_args(max_workers)

model_class = getattr(sup3r.models, self.model_class, None)
Expand Down Expand Up @@ -1079,7 +1130,13 @@ def __init__(self, strategy, chunk_index=0, node_index=0):
temporal_slice=self.temporal_pad_slice,
raster_file=self.raster_file,
cache_pattern=self.cache_pattern,
time_chunk_size=self.strategy.time_chunk_size,
overwrite_cache=self.strategy.overwrite_cache,
max_workers=self.max_workers,
extract_workers=strategy.extract_workers,
compute_workers=strategy.compute_workers,
load_workers=strategy.load_workers,
ti_workers=strategy.ti_workers,
handle_features=self.strategy.handle_features,
val_split=0.0)
input_handler_kwargs.update(fwp_input_handler_kwargs)
Expand Down
11 changes: 5 additions & 6 deletions tests/test_forward_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,15 +300,14 @@ def test_fwp_chunking(log=False, plot=False):
temporal_pad = 20
cache_pattern = os.path.join(td, 'cache')
fwp_shape = (4, 4, len(input_files) // 2)
input_handler_kwargs = dict(target=target, shape=shape,
temporal_slice=temporal_slice,
cache_pattern=cache_pattern,
overwrite_cache=True)
handler = ForwardPassStrategy(
input_files, model_kwargs={'model_dir': out_dir},
fwp_chunk_shape=fwp_shape,
spatial_pad=1, temporal_pad=1,
input_handler_kwargs=input_handler_kwargs,
spatial_pad=spatial_pad, temporal_pad=temporal_pad,
input_handler_kwargs=dict(target=target, shape=shape,
temporal_slice=temporal_slice,
cache_pattern=cache_pattern,
overwrite_cache=True, ti_workers=1),
max_workers=1)
data_chunked = np.zeros((shape[0] * s_enhance, shape[1] * s_enhance,
len(input_files) * t_enhance,
Expand Down