Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bnb/dev #98

Merged
merged 10 commits into from
Oct 6, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Prev Previous commit
Next Next commit
refactored get_chunk_kwargs into methods and properties
  • Loading branch information
bnb32 committed Oct 5, 2022
commit 3c01caa2b5348e741c125a9affd1f2128714b54d
284 changes: 144 additions & 140 deletions sup3r/pipeline/forward_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,7 @@ def __init__(self, file_paths, model_kwargs, fwp_chunk_shape,
multistep forward pass
bias_correct_method : str | None
Optional bias correction function name that can be imported from
the :meth:`sup3r.bias.bias_transforms` module. This will transform
the :mod:`sup3r.bias.bias_transforms` module. This will transform
the source data according to some predefined bias correction
transformation along with the bias_correct_kwargs. As the first
argument, this method must receive a generic numpy array of data
Expand Down Expand Up @@ -676,7 +676,25 @@ def __init__(self, file_paths, model_kwargs, fwp_chunk_shape,
self.bias_correct_method = bias_correct_method
self.bias_correct_kwargs = bias_correct_kwargs or {}

self.get_input_handler_kwargs(self._input_handler_kwargs)
self._target = input_handler_kwargs.get('target', None)
self._grid_shape = input_handler_kwargs.get('shape', None)
self.raster_file = input_handler_kwargs.get('raster_file', None)
self.temporal_slice = input_handler_kwargs.get('temporal_slice',
slice(None))
self.time_chunk_size = input_handler_kwargs.get('time_chunk_size',
None)
self.overwrite_cache = input_handler_kwargs.get('overwrite_cache',
False)
self.overwrite_ti_cache = input_handler_kwargs.get(
'overwrite_ti_cache', False)
self.extract_workers = input_handler_kwargs.get('extract_workers',
None)
self.compute_workers = input_handler_kwargs.get('compute_workers',
None)
self.load_workers = input_handler_kwargs.get('load_workers', None)
self.ti_workers = input_handler_kwargs.get('ti_workers', None)
self._cache_pattern = input_handler_kwargs.get('cache_pattern', None)

self.cap_worker_args(max_workers)

model_class = getattr(sup3r.models, self.model_class, None)
Expand Down Expand Up @@ -722,34 +740,6 @@ def __init__(self, file_paths, model_kwargs, fwp_chunk_shape,

self.preflight()

def get_input_handler_kwargs(self, input_handler_kwargs):
"""Get input handler args from input_handler_kwargs dict

Parameters
----------
input_handler_kwargs : dict
Dictionary of args to pass to the data handler
:class:`sup3r.preprocessing.data_handling.DataHandler`
"""
self._target = input_handler_kwargs.get('target', None)
self._grid_shape = input_handler_kwargs.get('shape', None)
self.raster_file = input_handler_kwargs.get('raster_file', None)
self.temporal_slice = input_handler_kwargs.get('temporal_slice',
slice(None))
self.time_chunk_size = input_handler_kwargs.get('time_chunk_size',
None)
self.overwrite_cache = input_handler_kwargs.get('overwrite_cache',
False)
self.overwrite_ti_cache = input_handler_kwargs.get(
'overwrite_ti_cache', False)
self.extract_workers = input_handler_kwargs.get('extract_workers',
None)
self.compute_workers = input_handler_kwargs.get('compute_workers',
None)
self.load_workers = input_handler_kwargs.get('load_workers', None)
self.ti_workers = input_handler_kwargs.get('ti_workers', None)
self._cache_pattern = input_handler_kwargs.get('cache_pattern', None)

@property
def worker_attrs(self):
"""Get all worker args defined in init"""
Expand Down Expand Up @@ -1044,7 +1034,12 @@ def __init__(self, strategy, chunk_index=0, node_index=0):
self.output_features = self.model.output_features
self.meta_data = self.model.meta

self.get_chunk_kwargs(strategy, chunk_index)
self._file_paths = strategy.file_paths
self.max_workers = strategy.max_workers
self.pass_workers = strategy.pass_workers
self.output_workers = strategy.output_workers
self.exo_kwargs = strategy.exo_kwargs
self.single_time_step_files = strategy.single_time_step_files

self.exogenous_handler = None
self.exogenous_data = None
Expand Down Expand Up @@ -1095,10 +1090,8 @@ def __init__(self, strategy, chunk_index=0, node_index=0):
self.input_data = self.bias_correct_source_data(self.input_data)

exo_s_en = self.exo_kwargs.get('s_enhancements', None)
out = self.pad_source_data(self.input_data,
self.pad_s1_start, self.pad_s1_end,
self.pad_s2_start, self.pad_s2_end,
self.pad_t_start, self.pad_t_end,
pad_width = self.get_padding()
out = self.pad_source_data(self.input_data, pad_width,
self.exogenous_data, exo_s_en)
self.input_data, self.exogenous_data = out

Expand All @@ -1108,14 +1101,14 @@ def file_paths(self):
reduced if there are single timesteps per file."""
file_paths = self._file_paths
if self.single_time_step_files:
file_paths = self._file_paths[self._ti_pad_slice]
file_paths = self._file_paths[self.ti_pad_slice]

return file_paths

@property
def temporal_pad_slice(self):
"""Get the low resolution temporal slice including padding."""
ti_pad_slice = self._ti_pad_slice
ti_pad_slice = self.ti_pad_slice
if self.single_time_step_files:
ti_pad_slice = slice(None)

Expand Down Expand Up @@ -1162,96 +1155,120 @@ def temporal_chunk_index(self):
"""Temporal index for the current chunk going through forward pass"""
return self.chunk_index // self.strategy.fwp_slicer.n_spatial_chunks

def get_chunk_kwargs(self, strategy, chunk_index):
"""Get node specific variables given an associated index
@property
def out_file(self):
"""Get output file name for the current chunk"""
return self.strategy.out_files[self.chunk_index]

Parameters
----------
strategy : ForwardPassStrategy
ForwardPassStrategy instance with information on data chunks to run
forward passes on.
chunk_index : int
Index to select chunk specific variables. This index selects the
corresponding file set, cropped_file_slice, padded_file_slice,
and padded/overlapping/cropped spatial slice for a spatiotemporal
chunk
"""
@property
def ti_slice(self):
"""Get ti slice for the current chunk"""
return self.strategy.ti_slices[self.temporal_chunk_index]

if chunk_index >= len(strategy):
msg = ('Index is out of bounds. There are '
f'{len(strategy.file_ids)} file chunks and the index '
f'requested was {chunk_index}.')
raise ValueError(msg)
@property
def ti_pad_slice(self):
"""Get padded ti slice for the current chunk"""
return self.strategy.ti_pad_slices[self.temporal_chunk_index]

s_chunk_index = self.spatial_chunk_index
t_chunk_index = self.temporal_chunk_index
@property
def lr_slice(self):
"""Get lr slice for the current chunk"""
return self.strategy.lr_slices[self.spatial_chunk_index]

self.out_file = strategy.out_files[chunk_index]
self._ti_pad_slice = strategy.ti_pad_slices[t_chunk_index]
self.ti_slice = strategy.ti_slices[t_chunk_index]
hr_crop_slices = strategy.fwp_slicer.hr_crop_slices[t_chunk_index]
@property
def lr_pad_slice(self):
"""Get padded lr slice for the current chunk"""
return self.strategy.lr_pad_slices[self.spatial_chunk_index]

self.cache_pattern = strategy.cache_pattern
if self.cache_pattern is not None:
self.cache_pattern = self.cache_pattern.replace(
'{temporal_chunk_index}', str(t_chunk_index))
self.cache_pattern = self.cache_pattern.replace(
'{spatial_chunk_index}', str(s_chunk_index))
@property
def hr_slice(self):
"""Get hr slice for the current chunk"""
return self.strategy.hr_slices[self.spatial_chunk_index]

self.raster_file = strategy.raster_file
if self.raster_file is not None:
self.raster_file = self.raster_file.replace(
'{spatial_chunk_index}', str(s_chunk_index))

self.ti_start = self.ti_slice.start or 0
self.ti_stop = self.ti_slice.stop or len(strategy.raw_time_index)
self.pad_t_start = int(np.maximum(0, (strategy.temporal_pad
- self.ti_start)))
self.pad_t_end = int(np.maximum(0, (strategy.temporal_pad
+ self.ti_stop
- len(strategy.raw_time_index))))

self.lr_slice = strategy.lr_slices[s_chunk_index]
self.lr_pad_slice = strategy.lr_pad_slices[s_chunk_index]
self.hr_slice = strategy.hr_slices[s_chunk_index]
self.hr_crop_slice = hr_crop_slices[s_chunk_index]
lr_crop_slices = strategy.fwp_slicer.s_lr_crop_slices
self.lr_crop_slice = lr_crop_slices[s_chunk_index]

self.s1_start = self.lr_slice[0].start or 0
self.s1_stop = self.lr_slice[0].stop or strategy.grid_shape[0]
self.pad_s1_start = int(np.maximum(0, (strategy.spatial_pad
- self.s1_start)))
self.pad_s1_end = int(np.maximum(0, (strategy.spatial_pad
+ self.s1_stop
- strategy.grid_shape[0])))

self.s2_start = self.lr_slice[1].start or 0
self.s2_stop = self.lr_slice[1].stop or strategy.grid_shape[1]
self.pad_s2_start = int(np.maximum(0, (strategy.spatial_pad
- self.s2_start)))
self.pad_s2_end = int(np.maximum(0, (strategy.spatial_pad
+ self.s2_stop
- strategy.grid_shape[1])))

self.data_shape = (*strategy.grid_shape,
len(strategy.raw_time_index[self._ti_pad_slice]))

self.chunk_shape = (
self.lr_pad_slice[0].stop - self.lr_pad_slice[0].start,
self.lr_pad_slice[1].stop - self.lr_pad_slice[1].start,
self.data_shape[2])
@property
def hr_crop_slice(self):
"""Get hr cropping slice for the current chunk"""
hr_crop_slices = self.strategy.fwp_slicer.hr_crop_slices[
self.temporal_chunk_index]
return hr_crop_slices[self.spatial_chunk_index]

self._file_paths = strategy.file_paths
self.max_workers = strategy.max_workers
self.pass_workers = strategy.pass_workers
self.output_workers = strategy.output_workers
self.exo_kwargs = strategy.exo_kwargs
self.single_time_step_files = strategy.single_time_step_files
@property
def lr_crop_slice(self):
"""Get lr cropping slice for the current chunk"""
lr_crop_slices = self.strategy.fwp_slicer.s_lr_crop_slices
return lr_crop_slices[self.spatial_chunk_index]

@property
def data_shape(self):
"""Get data shape for the current padded temporal chunk"""
return (*self.strategy.grid_shape,
len(self.strategy.raw_time_index[self.ti_pad_slice]))

@property
def chunk_shape(self):
"""Get shape for the current padded spatiotemporal chunk"""
return (self.lr_pad_slice[0].stop - self.lr_pad_slice[0].start,
self.lr_pad_slice[1].stop - self.lr_pad_slice[1].start,
self.data_shape[2])

@property
def cache_pattern(self):
"""Get cache pattern for the current chunk"""
cache_pattern = self.strategy.cache_pattern
if cache_pattern is not None:
cache_pattern = cache_pattern.replace(
'{temporal_chunk_index}', str(self.temporal_chunk_index))
cache_pattern = cache_pattern.replace(
'{spatial_chunk_index}', str(self.spatial_chunk_index))
return cache_pattern

@property
def raster_file(self):
"""Get raster file for the current spatial chunk"""
raster_file = self.strategy.raster_file
if raster_file is not None:
raster_file = raster_file.replace(
'{spatial_chunk_index}', str(self.spatial_chunk_index))
return raster_file

def get_padding(self):
"""Get padding for the current spatiotemporal chunk

Returns
-------
padding : tuple
Tuple of tuples with padding width for spatial and temporal
dimensions. Each tuple includes the start and end of padding for
that dimension. Ordering is spatial_1, spatial_2, temporal.
"""
ti_start = self.ti_slice.start or 0
ti_stop = self.ti_slice.stop or len(self.strategy.raw_time_index)
pad_t_start = int(np.maximum(0, (self.strategy.temporal_pad
- ti_start)))
pad_t_end = int(np.maximum(0, (self.strategy.temporal_pad
+ ti_stop
- len(self.strategy.raw_time_index))))

s1_start = self.lr_slice[0].start or 0
s1_stop = self.lr_slice[0].stop or self.strategy.grid_shape[0]
pad_s1_start = int(np.maximum(0, (self.strategy.spatial_pad
- s1_start)))
pad_s1_end = int(np.maximum(0, (self.strategy.spatial_pad
+ s1_stop
- self.strategy.grid_shape[0])))

s2_start = self.lr_slice[1].start or 0
s2_stop = self.lr_slice[1].stop or self.strategy.grid_shape[1]
pad_s2_start = int(np.maximum(0, (self.strategy.spatial_pad
- s2_start)))
pad_s2_end = int(np.maximum(0, (self.strategy.spatial_pad
+ s2_stop
- self.strategy.grid_shape[1])))
return ((pad_s1_start, pad_s1_end), (pad_s2_start, pad_s2_end),
(pad_t_start, pad_t_end))

@staticmethod
def pad_source_data(input_data, pad_s1_start, pad_s1_end, pad_s2_start,
pad_s2_end, pad_t_start, pad_t_end, exo_data,
def pad_source_data(input_data, pad_width, exo_data,
exo_s_enhancements, mode='reflect'):
"""Pad the edges of the source data from the data handler.

Expand All @@ -1265,20 +1282,10 @@ def pad_source_data(input_data, pad_s1_start, pad_s1_end, pad_s2_start,
passes for subsequent spatial stitching. This overlap will pad both
sides of the fwp_chunk_shape. Note that the first and last chunks
in any of the spatial dimension will not be padded.
pad_s1_start : int
How much padding to add to the beginning of the first spatial
dimension.
pad_s1_end : bool
How much padding to add to the end of the first spatial dimension.
pad_s2_start : int
How much padding to add to the beginning of the second spatial
dimension.
pad_s2_end : bool
How much padding to add to the end of the second spatial dimension.
pad_t_start : int
How much padding to add to the beginning of the temporal axis.
pad_t_end : bool
How much padding to add to the end of the temporal axis.
pad_width : tuple
Tuple of tuples with padding width for spatial and temporal
dimensions. Each tuple includes the start and end of padding for
that dimension. Ordering is spatial_1, spatial_2, temporal.
exo_data : None | list
List of exogenous data arrays for each step of the sup3r resolution
model. List entries can be None if no exo data is requested for a
Expand All @@ -1298,10 +1305,7 @@ def pad_source_data(input_data, pad_s1_start, pad_s1_end, pad_s2_start,
exo_data : list | None
Padded copy of exo_data input.
"""

pad_width = ((pad_s1_start, pad_s1_end), (pad_s2_start, pad_s2_end),
(pad_t_start, pad_t_end), (0, 0))
out = np.pad(input_data, pad_width, mode=mode)
out = np.pad(input_data, (*pad_width, (0, 0)), mode=mode)

logger.info('Padded input data shape from {} to {} using mode "{}" '
'with padding argument: {}'
Expand All @@ -1314,10 +1318,10 @@ def pad_source_data(input_data, pad_s1_start, pad_s1_end, pad_s2_start,
total_s_enhance = [s for s in total_s_enhance
if s is not None]
total_s_enhance = np.product(total_s_enhance)
pad_width = ((total_s_enhance * pad_s1_start,
total_s_enhance * pad_s1_end),
(total_s_enhance * pad_s2_start,
total_s_enhance * pad_s2_end), (0, 0))
pad_width = ((total_s_enhance * pad_width[0][0],
total_s_enhance * pad_width[0][1]),
(total_s_enhance * pad_width[1][0],
total_s_enhance * pad_width[1][1]), (0, 0))
exo_data[i] = np.pad(i_exo_data, pad_width, mode=mode)

return out, exo_data
Expand Down