Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added temporal then spatial model with test #140

Merged
merged 3 commits into from
Jan 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion sup3r/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from .wind import WindGan
from .solar_cc import SolarCC
from .data_centric import Sup3rGanDC
from .multi_step import (MultiStepGan, SpatialThenTemporalGan,
from .multi_step import (MultiStepGan,
SpatialThenTemporalGan, TemporalThenSpatialGan,
MultiStepSurfaceMetGan, SolarMultiStepGan)
from .surface import SurfaceSpatialMetModel
from .linear import LinearInterp
Expand Down
257 changes: 190 additions & 67 deletions sup3r/models/multi_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,16 +246,8 @@ def model_params(self):


class SpatialThenTemporalBase(MultiStepGan):
"""A two-step model where the first step is a spatial-only enhancement on a
4D tensor and the second step is (spatio)temporal enhancement on a 5D
tensor.

NOTE: The low res input to the spatial enhancement should be a 4D tensor of
the shape (temporal, spatial_1, spatial_2, features) where temporal
(usually the observation index) is a series of sequential timesteps that
will be transposed to a 5D tensor of shape
(1, spatial_1, spatial_2, temporal, features) tensor and then fed to the
2nd-step (spatio)temporal model.
"""A base class for spatial-then-temporal or temporal-then-spatial multi
step GANs
"""

def __init__(self, spatial_models, temporal_models):
Expand All @@ -272,22 +264,6 @@ def __init__(self, spatial_models, temporal_models):
self._spatial_models = spatial_models
self._temporal_models = temporal_models

@property
def models(self):
"""Get an ordered tuple of the Sup3rGan models that are part of this
MultiStepGan
"""
if isinstance(self.spatial_models, MultiStepGan):
spatial_models = self.spatial_models.models
else:
spatial_models = [self.spatial_models]
if isinstance(self.temporal_models, MultiStepGan):
temporal_models = self.temporal_models.models
else:
temporal_models = [self.temporal_models]

return (*spatial_models, *temporal_models)

@property
def spatial_models(self):
"""Get the MultiStepGan object for the spatial-only model(s)
Expand All @@ -308,6 +284,72 @@ def temporal_models(self):
"""
return self._temporal_models

@classmethod
def load(cls, spatial_model_dirs, temporal_model_dirs, verbose=True):
"""Load the GANs with its sub-networks from a previously saved-to
output directory.

Parameters
----------
spatial_model_dirs : str | list | tuple
An ordered list/tuple of one or more directories containing trained
+ saved Sup3rGan models created using the Sup3rGan.save() method.
This must contain only spatial models that input/output 4D
tensors.
temporal_model_dirs : str | list | tuple
An ordered list/tuple of one or more directories containing trained
+ saved Sup3rGan models created using the Sup3rGan.save() method.
This must contain only (spatio)temporal models that input/output 5D
tensors.
verbose : bool
Flag to log information about the loaded model.

Returns
-------
out : MultiStepGan
Returns a pretrained gan model that was previously saved to
model_dirs
"""
if isinstance(spatial_model_dirs, str):
spatial_model_dirs = [spatial_model_dirs]
if isinstance(temporal_model_dirs, str):
temporal_model_dirs = [temporal_model_dirs]

s_models = MultiStepGan.load(spatial_model_dirs, verbose=verbose)
t_models = MultiStepGan.load(temporal_model_dirs, verbose=verbose)

return cls(s_models, t_models)


class SpatialThenTemporalGan(SpatialThenTemporalBase):
"""A two-step GAN where the first step is a spatial-only enhancement on a
4D tensor and the second step is a (spatio)temporal enhancement on a 5D
tensor.

NOTE: The low res input to the spatial enhancement should be a 4D tensor of
the shape (temporal, spatial_1, spatial_2, features) where temporal
(usually the observation index) is a series of sequential timesteps that
will be transposed to a 5D tensor of shape
(1, spatial_1, spatial_2, temporal, features) tensor and then fed to the
2nd-step (spatio)temporal GAN.
"""

@property
def models(self):
"""Get an ordered tuple of the Sup3rGan models that are part of this
MultiStepGan
"""
if isinstance(self.spatial_models, MultiStepGan):
spatial_models = self.spatial_models.models
else:
spatial_models = [self.spatial_models]
if isinstance(self.temporal_models, MultiStepGan):
temporal_models = self.temporal_models.models
else:
temporal_models = [self.temporal_models]

return (*spatial_models, *temporal_models)

@property
def meta(self):
"""Get a tuple of meta data dictionaries for all models
Expand All @@ -329,14 +371,14 @@ def meta(self):
@property
def training_features(self):
"""Get the list of input feature names that the first spatial
generative model in this SpatialThenTemporalBase model requires as
generative model in this SpatialThenTemporalGan model requires as
input."""
return self.spatial_models.training_features

@property
def output_features(self):
"""Get the list of output feature names that the last spatiotemporal
interpolation model in this SpatialThenTemporalBase model outputs."""
interpolation model in this SpatialThenTemporalGan model outputs."""
return self.temporal_models.output_features

def generate(self, low_res, norm_in=True, un_norm_out=True,
Expand Down Expand Up @@ -412,58 +454,139 @@ def generate(self, low_res, norm_in=True, un_norm_out=True,

return hi_res

@classmethod
def load(cls, spatial_model_dirs, temporal_model_dirs, verbose=True):
"""Load the GANs with its sub-networks from a previously saved-to
output directory.

class TemporalThenSpatialGan(SpatialThenTemporalBase):
"""A two-step GAN where the first step is a spatiotemporal enhancement on a
5D tensor and the second step is a spatial enhancement on a 4D tensor.
"""

@property
def models(self):
"""Get an ordered tuple of the Sup3rGan models that are part of this
MultiStepGan
"""
if isinstance(self.spatial_models, MultiStepGan):
spatial_models = self.spatial_models.models
else:
spatial_models = [self.spatial_models]
if isinstance(self.temporal_models, MultiStepGan):
temporal_models = self.temporal_models.models
else:
temporal_models = [self.temporal_models]

return (*temporal_models, *spatial_models)

@property
def meta(self):
"""Get a tuple of meta data dictionaries for all models

Returns
-------
tuple
"""
if isinstance(self.spatial_models, MultiStepGan):
spatial_models = self.spatial_models.meta
else:
spatial_models = [self.spatial_models.meta]
if isinstance(self.temporal_models, MultiStepGan):
temporal_models = self.temporal_models.meta
else:
temporal_models = [self.temporal_models.meta]

return (*temporal_models, *spatial_models)

@property
def training_features(self):
"""Get the list of input feature names that the first temporal
generative model in this TemporalThenSpatialGan model requires as
input."""
return self.temporal_models.training_features

@property
def output_features(self):
"""Get the list of output feature names that the last spatial
interpolation model in this TemporalThenSpatialGan model outputs."""
return self.spatial_models.output_features

def generate(self, low_res, norm_in=True, un_norm_out=True,
exogenous_data=None):
"""Use the generator model to generate high res data from low res
input. This is the public generate function.

Parameters
----------
spatial_model_dirs : str | list | tuple
An ordered list/tuple of one or more directories containing trained
+ saved Sup3rGan models created using the Sup3rGan.save() method.
This must contain only spatial models that input/output 4D
tensors.
temporal_model_dirs : str | list | tuple
An ordered list/tuple of one or more directories containing trained
+ saved Sup3rGan models created using the Sup3rGan.save() method.
This must contain only (spatio)temporal models that input/output 5D
tensors.
verbose : bool
Flag to log information about the loaded model.
low_res : np.ndarray
Low-resolution input data, a 5D array of shape:
(1, spatial_1, spatial_2, n_temporal, n_features)
norm_in : bool
Flag to normalize low_res input data if the self.means,
self.stdevs attributes are available. The generator should always
received normalized data with mean=0 stdev=1.
un_norm_out : bool
Flag to un-normalize synthetically generated output data to physical
units
exogenous_data : list
List of arrays of exogenous_data with length equal to the
number of model steps. e.g. If we want to include topography as
an exogenous feature in a temporal + spatial multistep model then
we need to provide a list of length=2 with topography at the low
spatial resolution and at the high resolution. If we include more
than one exogenous feature the ordering must be consistent.
Each array in the list has 3D or 4D shape:
(spatial_1, spatial_2, n_features)
(temporal, spatial_1, spatial_2, n_features)

Returns
-------
out : MultiStepGan
Returns a pretrained gan model that was previously saved to
model_dirs
hi_res : ndarray
Synthetically generated high-resolution data output from the 2nd
step (spatio)temporal GAN with a 5D array shape:
(1, spatial_1, spatial_2, n_temporal, n_features)
"""
if isinstance(spatial_model_dirs, str):
spatial_model_dirs = [spatial_model_dirs]
if isinstance(temporal_model_dirs, str):
temporal_model_dirs = [temporal_model_dirs]
logger.debug('Data input to the 1st step (spatio)temporal '
'enhancement has shape {}'.format(low_res.shape))
s_exogenous = None
if exogenous_data is not None:
s_exogenous = exogenous_data[len(self.temporal_models):]

s_models = MultiStepGan.load(spatial_model_dirs, verbose=verbose)
t_models = MultiStepGan.load(temporal_model_dirs, verbose=verbose)
assert low_res.shape[0] == 1, 'Low res input can only have 1 obs!'

return cls(s_models, t_models)
try:
hi_res = self.temporal_models.generate(
low_res, norm_in=norm_in, un_norm_out=True,
exogenous_data=exogenous_data)
except Exception as e:
msg = ('Could not run the 1st step (spatio)temporal GAN on input '
'shape {}'.format(low_res.shape))
logger.exception(msg)
raise RuntimeError(msg) from e

logger.debug('Data output from the 1st step (spatio)temporal '
'enhancement has shape {}'.format(hi_res.shape))
hi_res = np.transpose(hi_res[0], axes=(2, 0, 1, 3))
logger.debug('Data from the 1st step (spatio)temporal enhancement has '
'been reshaped to {}'.format(hi_res.shape))

class SpatialThenTemporalGan(SpatialThenTemporalBase):
"""A two-step GAN where the first step is a spatial-only enhancement on a
4D tensor and the second step is a (spatio)temporal enhancement on a 5D
tensor.
try:
hi_res = self.spatial_models.generate(
hi_res, norm_in=True, un_norm_out=un_norm_out,
exogenous_data=s_exogenous)
except Exception as e:
msg = ('Could not run the 2nd step spatial GAN on input '
'shape {}'.format(low_res.shape))
logger.exception(msg)
raise RuntimeError(msg) from e

NOTE: The low res input to the spatial enhancement should be a 4D tensor of
the shape (temporal, spatial_1, spatial_2, features) where temporal
(usually the observation index) is a series of sequential timesteps that
will be transposed to a 5D tensor of shape
(1, spatial_1, spatial_2, temporal, features) tensor and then fed to the
2nd-step (spatio)temporal GAN.
"""
hi_res = np.transpose(hi_res, axes=(1, 2, 0, 3))
hi_res = np.expand_dims(hi_res, axis=0)

logger.debug('Final multistep GAN output has shape: {}'
.format(hi_res.shape))

return hi_res


class MultiStepSurfaceMetGan(SpatialThenTemporalBase):
class MultiStepSurfaceMetGan(SpatialThenTemporalGan):
"""A two-step GAN where the first step is a spatial-only enhancement on a
4D tensor of near-surface temperature and relative humidity data, and the
second step is a (spatio)temporal enhancement on a 5D tensor.
Expand Down Expand Up @@ -612,7 +735,7 @@ def load(cls, surface_model_class='SurfaceSpatialMetModel',
return cls(s_models, t_models)


class SolarMultiStepGan(SpatialThenTemporalBase):
class SolarMultiStepGan(SpatialThenTemporalGan):
"""Special multi step model for solar clearsky ratio super resolution.

This model takes in two parallel models for wind-only and solar-only
Expand Down
6 changes: 3 additions & 3 deletions sup3r/models/wind.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,11 @@ def init_weights(self, lr_shape, hr_shape, device=None):
device = self.default_device

logger.info('Initializing model weights on device "{}"'.format(device))
low_res = np.random.uniform(0, 1, lr_shape).astype(np.float32)
hi_res = np.random.uniform(0, 1, hr_shape).astype(np.float32)
low_res = np.ones(lr_shape).astype(np.float32)
hi_res = np.ones(hr_shape).astype(np.float32)

hr_topo_shape = hr_shape[:-1] + (1,)
hr_topo = np.random.uniform(0, 1, hr_topo_shape).astype(np.float32)
hr_topo = np.ones(hr_topo_shape).astype(np.float32)

with tf.device(device):
_ = self._tf_generate(low_res, hr_topo)
Expand Down
35 changes: 34 additions & 1 deletion tests/test_multi_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import tempfile

from sup3r import CONFIG_DIR
from sup3r.models import (Sup3rGan, MultiStepGan, SpatialThenTemporalGan,
from sup3r.models import (Sup3rGan, MultiStepGan,
SpatialThenTemporalGan, TemporalThenSpatialGan,
SolarMultiStepGan, LinearInterp)

FEATURES = ['U_100m', 'V_100m']
Expand Down Expand Up @@ -129,6 +130,38 @@ def test_spatial_then_temporal_gan():
assert out.shape == (1, 60, 60, 16, 2)


def test_temporal_then_spatial_gan():
"""Test the 2-step temporal-then-spatial GAN"""
fp_gen = os.path.join(CONFIG_DIR, 'spatial/gen_2x_2f.json')
fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json')
model1 = Sup3rGan(fp_gen, fp_disc)
_ = model1.generate(np.ones((4, 10, 10, len(FEATURES))))

fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_3x_4x_2f.json')
fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json')
model2 = Sup3rGan(fp_gen, fp_disc)
_ = model2.generate(np.ones((4, 10, 10, 6, len(FEATURES))))

model1.set_norm_stats([0.1, 0.2], [0.04, 0.02])
model2.set_norm_stats([0.3, 0.9], [0.02, 0.07])
model1.set_model_params(training_features=FEATURES,
output_features=FEATURES)
model2.set_model_params(training_features=FEATURES,
output_features=FEATURES)

with tempfile.TemporaryDirectory() as td:
fp1 = os.path.join(td, 'model1')
fp2 = os.path.join(td, 'model2')
model1.save(fp1)
model2.save(fp2)

ms_model = TemporalThenSpatialGan.load(fp1, fp2)

x = np.ones((1, 10, 10, 4, len(FEATURES)))
out = ms_model.generate(x)
assert out.shape == (1, 60, 60, 16, 2)


def test_spatial_gan_then_linear_interp():
"""Test the 2-step spatial GAN then linear spatiotemporal interpolation"""
fp_gen = os.path.join(CONFIG_DIR, 'spatial/gen_2x_2f.json')
Expand Down