Skip to content

Commit

Permalink
PR updates
Browse files Browse the repository at this point in the history
  • Loading branch information
bnb32 committed Oct 12, 2023
1 parent 40155ae commit 53d8bc7
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 34 deletions.
16 changes: 12 additions & 4 deletions sup3r/models/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def _combine_fwp_input(self, low_res, exogenous_data=None):
Low-resolution input data, usually a 4D or 5D array of shape:
(n_obs, spatial_1, spatial_2, n_features)
(n_obs, spatial_1, spatial_2, n_temporal, n_features)
exogenous_data : ExoData | None
exogenous_data : dict | ExoData | None
Special dictionary (class:`ExoData`) of exogenous feature data with
entries describing whether features should be combined at input, a
mid network layer, or with output. This doesn't have to include
Expand All @@ -241,6 +241,10 @@ def _combine_fwp_input(self, low_res, exogenous_data=None):
if exogenous_data is None:
return low_res

if (not isinstance(exogenous_data, ExoData)
and exogenous_data is not None):
exogenous_data = ExoData(exogenous_data)

training_features = ([] if self.training_features is None
else self.training_features)
fnum_diff = len(training_features) - low_res.shape[-1]
Expand All @@ -267,7 +271,7 @@ def _combine_fwp_output(self, hi_res, exogenous_data=None):
High-resolution output data, usually a 4D or 5D array of shape:
(n_obs, spatial_1, spatial_2, n_features)
(n_obs, spatial_1, spatial_2, n_temporal, n_features)
exogenous_data : dict | None
exogenous_data : dict | ExoData | None
Special dictionary (class:`ExoData`) of exogenous feature data with
entries describing whether features should be combined at input, a
mid network layer, or with output. This doesn't have to include
Expand All @@ -284,6 +288,10 @@ def _combine_fwp_output(self, hi_res, exogenous_data=None):
if exogenous_data is None:
return hi_res

if (not isinstance(exogenous_data, ExoData)
and exogenous_data is not None):
exogenous_data = ExoData(exogenous_data)

output_features = ([] if self.output_features is None
else self.output_features)
fnum_diff = len(output_features) - hi_res.shape[-1]
Expand Down Expand Up @@ -1260,8 +1268,8 @@ def generate(self,
(n_obs, spatial_1, spatial_2, n_features)
(n_obs, spatial_1, spatial_2, n_temporal, n_features)
"""
if (isinstance(exogenous_data, dict)
and not isinstance(exogenous_data, ExoData)):
if (not isinstance(exogenous_data, ExoData)
and exogenous_data is not None):
exogenous_data = ExoData(exogenous_data)

low_res = self._combine_fwp_input(low_res, exogenous_data)
Expand Down
4 changes: 2 additions & 2 deletions sup3r/pipeline/forward_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -1128,7 +1128,7 @@ def load_exo_data(self):
class:`ExoData` object composed of multiple
class:`SingleExoDataStep` objects.
"""
data = []
data = {}
exo_data = None
if self.exo_kwargs:
self.features = [f for f in self.features
Expand All @@ -1144,7 +1144,7 @@ def load_exo_data(self):
sig = signature(ExogenousDataHandler)
exo_kwargs = {k: v for k, v in exo_kwargs.items()
if k in sig.parameters}
data += ExogenousDataHandler(**exo_kwargs).data
data.update(ExogenousDataHandler(**exo_kwargs).data)
exo_data = ExoData(data)
return exo_data

Expand Down
45 changes: 19 additions & 26 deletions sup3r/preprocessing/data_handling/exogenous_data_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,19 +60,24 @@ def __init__(self, steps):
Parameters
----------
steps : list | dict
List of SingleExoDataStep objects or a feature dictionary with list
of steps for each feature
steps : dict
Dictionary with feature keys each with entries describing whether
features should be combined at input, a mid network layer, or with
output. e.g.
{'topography': {'steps': [
{'combine_type': 'input', 'model': 0, 'data': ...,
'resolution': ...},
{'combine_type': 'layer', 'model': 0, 'data': ...,
'resolution': ...}]}}
Each array in in 'data' key has 3D or 4D shape:
(spatial_1, spatial_2, 1)
(spatial_1, spatial_2, n_temporal, 1)
"""
if isinstance(steps, list):
for step in steps:
self.append(step.feature, step)
elif isinstance(steps, dict):
if isinstance(steps, dict):
for k, v in steps.items():
self.__setitem__(k, v)
else:
msg = ('ExoData must be initialized with a dictionary of features '
'or list of SingleExoDataStep objects.')
msg = 'ExoData must be initialized with a dictionary of features.'
logger.error(msg)
raise ValueError(msg)

Expand Down Expand Up @@ -117,18 +122,6 @@ def split_exo_dict(self, split_step):
spatial models and temporal models split_step should be
len(spatial_models). If this is for a TemporalThenSpatial model
split_step should be len(temporal_models).
exogenous_data : dict
Dictionary of exogenous feature data with entries describing
whether features should be combined at input, a mid network layer,
or with output. e.g.
{'topography': {'steps': [
{'combine_type': 'input', 'model': 0, 'data': ...,
'resolution': ...},
{'combine_type': 'layer', 'model': 0, 'data': ...,
'resolution': ...}]}}
Each array in in 'data' key has 3D or 4D shape:
(spatial_1, spatial_2, 1)
(spatial_1, spatial_2, n_temporal, 1)
Returns
-------
Expand Down Expand Up @@ -306,7 +299,7 @@ def __init__(self,
self.input_handler = input_handler
self.cache_data = cache_data
self.cache_dir = cache_dir
self.data = []
self.data = {feature: {'steps': []}}

self.input_check()
agg_enhance = self._get_all_agg_and_enhancement()
Expand Down Expand Up @@ -341,16 +334,16 @@ def __init__(self,
t_agg_factor=t_agg_factor)
step = SingleExoDataStep(feature, steps[i]['combine_type'],
steps[i]['model'], data)
self.data.append(step)
self.data[feature]['steps'].append(step)
else:
msg = (f"Can only extract {list(self.AVAILABLE_HANDLERS)}."
f" Received {feature}.")
raise NotImplementedError(msg)
shapes = [None if d is None else d['data'].shape
for d in self.data]
shapes = [None if step is None else step.shape
for step in self.data[feature]['steps']]
logger.info(
'Got exogenous_data of length {} with shapes: {}'.format(
len(self.data), shapes))
len(self.data[feature]['steps']), shapes))

def input_check(self):
"""Make sure agg factors are provided or exo_resolution and models are
Expand Down
5 changes: 3 additions & 2 deletions tests/data_handling/test_exo_data_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def test_exo_cache(feature):
target=TARGET, shape=SHAPE,
input_handler='DataHandlerNCforCC',
cache_dir=os.path.join(td, 'exo_cache'))
for i, arr in enumerate(base.data):
for i, arr in enumerate(base.data[feature]['steps']):
assert arr.shape[0] == SHAPE[0] * S_ENHANCE[i]
assert arr.shape[1] == SHAPE[1] * S_ENHANCE[i]

Expand All @@ -58,5 +58,6 @@ def test_exo_cache(feature):
cache_dir=os.path.join(td, 'exo_cache'))
assert len(os.listdir(f'{td}/exo_cache')) == 2

for arr1, arr2 in zip(base.data, cache.data):
for arr1, arr2 in zip(base.data[feature]['steps'],
cache.data[feature]['steps']):
assert np.allclose(arr1['data'], arr2['data'])

0 comments on commit 53d8bc7

Please sign in to comment.