Skip to content

Commit

Permalink
Merge pull request #141 from NREL/bnb/dev
Browse files Browse the repository at this point in the history
Bnb/dev
  • Loading branch information
bnb32 committed Jan 11, 2023
2 parents 93c56f9 + 6bbe02e commit 0e83ada
Show file tree
Hide file tree
Showing 8 changed files with 177 additions and 365 deletions.
14 changes: 7 additions & 7 deletions sup3r/models/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -909,9 +909,9 @@ class AbstractWindInterface(ABC):
Abstract class to define the required training interface
for Sup3r wind model subclasses
"""

# pylint: disable=E0211
@staticmethod
def set_model_params_wind(**kwargs):
def set_model_params(**kwargs):
"""Set parameters used for training the model
Parameters
Expand Down Expand Up @@ -1002,8 +1002,8 @@ def _reshape_norm_topo(self, hi_res, hi_res_topo, norm_in=True):

return hi_res_topo

def generate_wind(self, low_res, norm_in=True, un_norm_out=True,
exogenous_data=None):
def generate(self, low_res, norm_in=True, un_norm_out=True,
exogenous_data=None):
"""Use the generator model to generate high res data from low res
input. This is the public generate function.
Expand Down Expand Up @@ -1076,7 +1076,7 @@ def generate_wind(self, low_res, norm_in=True, un_norm_out=True,
return hi_res

@tf.function
def _tf_generate_wind(self, low_res, hi_res_topo):
def _tf_generate(self, low_res, hi_res_topo):
"""Use the generator model to generate high res data from los res input
Parameters
Expand Down Expand Up @@ -1115,8 +1115,8 @@ def _tf_generate_wind(self, low_res, hi_res_topo):
return hi_res

@tf.function()
def get_single_grad_wind(self, low_res, hi_res_true, training_weights,
device_name=None, **calc_loss_kwargs):
def get_single_grad(self, low_res, hi_res_true, training_weights,
device_name=None, **calc_loss_kwargs):
"""Run gradient descent for one mini-batch of (low_res, hi_res_true),
do not update weights, just return gradient details.
Expand Down
108 changes: 3 additions & 105 deletions sup3r/models/wind.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
logger = logging.getLogger(__name__)


class WindGan(Sup3rGan, AbstractWindInterface):
class WindGan(AbstractWindInterface, Sup3rGan):
"""Wind super resolution GAN with handling of low and high res topography
inputs.
Expand Down Expand Up @@ -69,110 +69,8 @@ def set_model_params(self, **kwargs):
Keyword arguments including 'training_features', 'output_features',
'smoothed_features', 's_enhance', 't_enhance', 'smoothing'
"""
kwargs = self.set_model_params_wind(**kwargs)
super().set_model_params(**kwargs)

def generate(self, low_res, norm_in=True, un_norm_out=True,
exogenous_data=None):
"""Use the generator model to generate high res data from low res
input. This is the public generate function.
Parameters
----------
low_res : np.ndarray
Low-resolution input data, usually a 4D or 5D array of shape:
(n_obs, spatial_1, spatial_2, n_features)
(n_obs, spatial_1, spatial_2, n_temporal, n_features)
norm_in : bool
Flag to normalize low_res input data if the self._means,
self._stdevs attributes are available. The generator should always
received normalized data with mean=0 stdev=1. This also normalizes
hi_res_topo.
un_norm_out : bool
Flag to un-normalize synthetically generated output data to physical
units
exogenous_data : ndarray | list | None
Exogenous data for topography inputs. The first entry in this list
(or only entry) is a low-resolution topography array that can be
concatenated to the low_res input array. The second entry is
high-resolution topography (either 2D or 4D/5D depending on if
spatial or spatiotemporal super res).
Returns
-------
hi_res : ndarray
Synthetically generated high-resolution data, usually a 4D or 5D
array with shape:
(n_obs, spatial_1, spatial_2, n_features)
(n_obs, spatial_1, spatial_2, n_temporal, n_features)
"""
return self.generate_wind(low_res, norm_in,
un_norm_out, exogenous_data)

@tf.function
def _tf_generate(self, low_res, hi_res_topo):
"""Use the generator model to generate high res data from los res input
Parameters
----------
low_res : np.ndarray
Real low-resolution data. The generator should always
received normalized data with mean=0 stdev=1.
hi_res_topo : np.ndarray
This should be a 4D array for spatial enhancement model or 5D array
for a spatiotemporal enhancement model (obs, spatial_1, spatial_2,
(temporal), features) corresponding to the high-resolution
spatial_1 and spatial_2. This data will be input to the custom
phygnn Sup3rAdder or Sup3rConcat layer if found in the generative
network. This differs from the exogenous_data input in that
exogenous_data always matches the low-res input.
Returns
-------
hi_res : tf.Tensor
Synthetically generated high-resolution data
"""
return self._tf_generate_wind(low_res, hi_res_topo)

@tf.function()
def get_single_grad(self, low_res, hi_res_true, training_weights,
device_name=None, **calc_loss_kwargs):
"""Run gradient descent for one mini-batch of (low_res, hi_res_true),
do not update weights, just return gradient details.
Parameters
----------
low_res : np.ndarray
Real low-resolution data in a 4D or 5D array:
(n_observations, spatial_1, spatial_2, features)
(n_observations, spatial_1, spatial_2, temporal, features)
hi_res_true : np.ndarray
Real high-resolution data in a 4D or 5D array:
(n_observations, spatial_1, spatial_2, features)
(n_observations, spatial_1, spatial_2, temporal, features)
training_weights : list
A list of layer weights that are to-be-trained based on the
current loss weight values.
device_name : None | str
Optional tensorflow device name for GPU placement. Note that if a
GPU is available, variables will be placed on that GPU even if
device_name=None.
calc_loss_kwargs : dict
Kwargs to pass to the self.calc_loss() method
Returns
-------
grad : list
a list or nested structure of Tensors (or IndexedSlices, or None,
or CompositeTensor) representing the gradients for the
training_weights
loss_details : dict
Namespace of the breakdown of loss components
"""
return self.get_single_grad_wind(low_res, hi_res_true,
training_weights,
device_name=device_name,
**calc_loss_kwargs)
AbstractWindInterface.set_model_params(**kwargs)
Sup3rGan.set_model_params(self, **kwargs)

@tf.function
def calc_loss(self, hi_res_true, hi_res_gen, **kwargs):
Expand Down
108 changes: 3 additions & 105 deletions sup3r/models/wind_conditional_moments.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
logger = logging.getLogger(__name__)


class WindCondMom(Sup3rCondMom, AbstractWindInterface):
class WindCondMom(AbstractWindInterface, Sup3rCondMom):
"""Wind conditional moment estimator with handling of low and
high res topography inputs.
Expand All @@ -33,110 +33,8 @@ def set_model_params(self, **kwargs):
Keyword arguments including 'training_features', 'output_features',
'smoothed_features', 's_enhance', 't_enhance', 'smoothing'
"""
kwargs = self.set_model_params_wind(**kwargs)
super().set_model_params(**kwargs)

def generate(self, low_res, norm_in=True, un_norm_out=True,
exogenous_data=None):
"""Use the generator model to generate high res data from low res
input. This is the public generate function.
Parameters
----------
low_res : np.ndarray
Low-resolution input data, usually a 4D or 5D array of shape:
(n_obs, spatial_1, spatial_2, n_features)
(n_obs, spatial_1, spatial_2, n_temporal, n_features)
norm_in : bool
Flag to normalize low_res input data if the self._means,
self._stdevs attributes are available. The generator should always
received normalized data with mean=0 stdev=1. This also normalizes
hi_res_topo.
un_norm_out : bool
Flag to un-normalize synthetically generated output data to physical
units
exogenous_data : ndarray | list | None
Exogenous data for topography inputs. The first entry in this list
(or only entry) is a low-resolution topography array that can be
concatenated to the low_res input array. The second entry is
high-resolution topography (either 2D or 4D/5D depending on if
spatial or spatiotemporal super res).
Returns
-------
hi_res : ndarray
Synthetically generated high-resolution data, usually a 4D or 5D
array with shape:
(n_obs, spatial_1, spatial_2, n_features)
(n_obs, spatial_1, spatial_2, n_temporal, n_features)
"""
return self.generate_wind(low_res, norm_in,
un_norm_out, exogenous_data)

@tf.function
def _tf_generate(self, low_res, hi_res_topo):
"""Use the generator model to generate high res data from los res input
Parameters
----------
low_res : np.ndarray
Real low-resolution data. The generator should always
received normalized data with mean=0 stdev=1.
hi_res_topo : np.ndarray
This should be a 4D array for spatial enhancement model or 5D array
for a spatiotemporal enhancement model (obs, spatial_1, spatial_2,
(temporal), features) corresponding to the high-resolution
spatial_1 and spatial_2. This data will be input to the custom
phygnn Sup3rAdder or Sup3rConcat layer if found in the generative
network. This differs from the exogenous_data input in that
exogenous_data always matches the low-res input.
Returns
-------
hi_res : tf.Tensor
Synthetically generated high-resolution data
"""
return self._tf_generate_wind(low_res, hi_res_topo)

@tf.function()
def get_single_grad(self, low_res, hi_res_true, training_weights,
device_name=None, **calc_loss_kwargs):
"""Run gradient descent for one mini-batch of (low_res, hi_res_true),
do not update weights, just return gradient details.
Parameters
----------
low_res : np.ndarray
Real low-resolution data in a 4D or 5D array:
(n_observations, spatial_1, spatial_2, features)
(n_observations, spatial_1, spatial_2, temporal, features)
hi_res_true : np.ndarray
Real high-resolution data in a 4D or 5D array:
(n_observations, spatial_1, spatial_2, features)
(n_observations, spatial_1, spatial_2, temporal, features)
training_weights : list
A list of layer weights that are to-be-trained based on the
current loss weight values.
device_name : None | str
Optional tensorflow device name for GPU placement. Note that if a
GPU is available, variables will be placed on that GPU even if
device_name=None.
calc_loss_kwargs : dict
Kwargs to pass to the self.calc_loss() method
Returns
-------
grad : list
a list or nested structure of Tensors (or IndexedSlices, or None,
or CompositeTensor) representing the gradients for the
training_weights
loss_details : dict
Namespace of the breakdown of loss components
"""
return self.get_single_grad_wind(low_res, hi_res_true,
training_weights,
device_name=device_name,
**calc_loss_kwargs)
AbstractWindInterface.set_model_params(**kwargs)
Sup3rCondMom.set_model_params(self, **kwargs)

@tf.function
def calc_loss(self, hi_res_true, hi_res_gen, mask, **kwargs):
Expand Down

0 comments on commit 0e83ada

Please sign in to comment.