Skip to content

Commit

Permalink
Merge pull request #185 from NREL/gb/lr_loss
Browse files Browse the repository at this point in the history
Gb/lr loss
  • Loading branch information
grantbuster committed Feb 14, 2024
2 parents fc613de + 91ba1ad commit e5b3ab6
Show file tree
Hide file tree
Showing 5 changed files with 315 additions and 84 deletions.
6 changes: 4 additions & 2 deletions sup3r/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,12 @@ def __init__(self,
discriminative model (spatial or spatiotemporal discriminator). Can
also be a str filepath to a .json config file containing the input
layers argument or a .pkl for a saved pre-trained model.
loss : str
loss : str | dict
Loss function class name from sup3r.utilities.loss_metrics
(prioritized) or tensorflow.keras.losses. Defaults to
tf.keras.losses.MeanSquaredError.
tf.keras.losses.MeanSquaredError. This can be provided as a dict
with kwargs for loss functions with extra parameters.
e.g. {'SpatialExtremesLoss': {'weight': 0.5}}
optimizer : tf.keras.optimizers.Optimizer | dict | None | str
Instantiated tf.keras.optimizers object or a dict optimizer config
from tf.keras.optimizers.get_config(). None defaults to Adam.
Expand Down
133 changes: 132 additions & 1 deletion sup3r/utilities/loss_metrics.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Loss metrics for Sup3r"""
"""Content loss metrics for Sup3r"""

import numpy as np
import tensorflow as tf
Expand Down Expand Up @@ -512,3 +512,134 @@ def __call__(self, x1, x2):
"""
return (5 * self.st_ex_loss(x1, x2)
+ self.fft_weight * self.fft_loss(x1, x2)) / 6


class LowResLoss(tf.keras.losses.Loss):
"""Content loss that is calculated by coarsening the synthetic and true
high-resolution data pairs and then performing the pointwise content loss
on the low-resolution fields"""

EX_LOSS_METRICS = {'SpatialExtremesOnlyLoss': SpatialExtremesOnlyLoss,
'TemporalExtremesOnlyLoss': TemporalExtremesOnlyLoss,
}

def __init__(self, s_enhance=1, t_enhance=1, t_method='average',
tf_loss='MeanSquaredError', ex_loss=None):
"""Initialize the loss with given weight
Parameters
----------
s_enhance : int
factor by which to coarsen spatial dimensions. 1 will keep the
spatial axes as high-res
t_enhance : int
factor by which to coarsen temporal dimension. 1 will keep the
temporal axes as high-res
t_method : str
Accepted options: [subsample, average]
Subsample will take every t_enhance-th time step, average will
average over t_enhance time steps
tf_loss : str
The tensorflow loss function to operate on the low-res fields. Must
be the name of a loss class that can be retrieved from
``tf.keras.losses`` e.g., "MeanSquaredError" or "MeanAbsoluteError"
ex_loss : None | str
Optional additional loss metric evaluating the spatial or temporal
extremes of the high-res data. Can be "SpatialExtremesOnlyLoss" or
"TemporalExtremesOnlyLoss" (keys in ``EX_LOSS_METRICS``).
"""

super().__init__()
self._s_enhance = s_enhance
self._t_enhance = t_enhance
self._t_method = str(t_method).casefold()
self._tf_loss = getattr(tf.keras.losses, tf_loss)()
self._ex_loss = ex_loss
if self._ex_loss is not None:
self._ex_loss = self.EX_LOSS_METRICS[self._ex_loss]()

def _s_coarsen_4d_tensor(self, tensor):
"""Perform spatial coarsening on a 4D tensor of shape
(n_obs, spatial_1, spatial_2, features)"""
shape = tensor.shape
tensor = tf.reshape(tensor,
(shape[0],
shape[1] // self._s_enhance, self._s_enhance,
shape[2] // self._s_enhance, self._s_enhance,
shape[3]))
tensor = tf.math.reduce_sum(tensor, axis=(2, 4)) / self._s_enhance**2
return tensor

def _s_coarsen_5d_tensor(self, tensor):
"""Perform spatial coarsening on a 5D tensor of shape
(n_obs, spatial_1, spatial_2, time, features)"""
shape = tensor.shape
tensor = tf.reshape(tensor,
(shape[0],
shape[1] // self._s_enhance, self._s_enhance,
shape[2] // self._s_enhance, self._s_enhance,
shape[3], shape[4]))
tensor = tf.math.reduce_sum(tensor, axis=(2, 4)) / self._s_enhance**2
return tensor

def _t_coarsen_sample(self, tensor):
"""Perform temporal subsampling on a 5D tensor of shape
(n_obs, spatial_1, spatial_2, time, features)"""
assert len(tensor.shape) == 5
tensor = tensor[:, :, :, ::self._t_enhance, :]
return tensor

def _t_coarsen_avg(self, tensor):
"""Perform temporal coarsening on a 5D tensor of shape
(n_obs, spatial_1, spatial_2, time, features)"""
shape = tensor.shape
assert len(shape) == 5
tensor = tf.reshape(tensor, (shape[0], shape[1], shape[2], -1,
self._t_enhance, shape[4]))
tensor = tf.math.reduce_sum(tensor, axis=4) / self._t_enhance
return tensor

def __call__(self, x1, x2):
"""Custom content loss calculated on re-coarsened low-res fields
Parameters
----------
x1 : tf.tensor
Synthetic high-res generator output, shape is either of these:
(n_obs, spatial_1, spatial_2, features)
(n_obs, spatial_1, spatial_2, temporal, features)
x2 : tf.tensor
True high resolution data, shape is either of these:
(n_obs, spatial_1, spatial_2, features)
(n_obs, spatial_1, spatial_2, temporal, features)
Returns
-------
tf.tensor
0D tensor loss value
"""

assert x1.shape == x2.shape
s_only = len(x1.shape) == 4

ex_loss = tf.constant(0, dtype=x1.dtype)
if self._ex_loss is not None:
ex_loss = self._ex_loss(x1, x2)

if self._s_enhance > 1 and s_only:
x1 = self._s_coarsen_4d_tensor(x1)
x2 = self._s_coarsen_4d_tensor(x2)

elif self._s_enhance > 1 and not s_only:
x1 = self._s_coarsen_5d_tensor(x1)
x2 = self._s_coarsen_5d_tensor(x2)

if self._t_enhance > 1 and self._t_method == 'average':
x1 = self._t_coarsen_avg(x1)
x2 = self._t_coarsen_avg(x2)

if self._t_enhance > 1 and self._t_method == 'subsample':
x1 = self._t_coarsen_sample(x1)
x2 = self._t_coarsen_sample(x2)

return self._tf_loss(x1, x2) + ex_loss
76 changes: 0 additions & 76 deletions tests/training/test_custom_loss.py

This file was deleted.

145 changes: 145 additions & 0 deletions tests/utilities/test_loss_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# -*- coding: utf-8 -*-
"""Test the basic training of super resolution GAN"""
import numpy as np
import tensorflow as tf

from sup3r.utilities.loss_metrics import (MmdMseLoss, CoarseMseLoss,
TemporalExtremesLoss, LowResLoss)
from sup3r.utilities.utilities import spatial_coarsening, temporal_coarsening


def test_mmd_loss():
"""Test content loss using mse + mmd for content loss."""

x = np.zeros((6, 10, 10, 8, 3))
y = np.zeros((6, 10, 10, 8, 3))
x[:, 7:9, 7:9, :, :] = 1
y[:, 2:5, 2:5, :, :] = 1

# distributions differing by only a small peak should give small mse and
# larger mmd
mse_fun = tf.keras.losses.MeanSquaredError()
mmd_mse_fun = MmdMseLoss()

mse = mse_fun(x, y)
mmd_plus_mse = mmd_mse_fun(x, y)

assert mmd_plus_mse > mse

x = np.random.rand(6, 10, 10, 8, 3)
x /= np.max(x)
y = np.random.rand(6, 10, 10, 8, 3)
y /= np.max(y)

# scaling the same distribution should give high mse and smaller mmd
mse = mse_fun(5 * x, x)
mmd_plus_mse = mmd_mse_fun(5 * x, x)

assert mmd_plus_mse < mse


def test_coarse_mse_loss():
"""Test the coarse MSE loss on spatial average data"""
x = np.random.uniform(0, 1, (6, 10, 10, 8, 3))
y = np.random.uniform(0, 1, (6, 10, 10, 8, 3))

mse_fun = tf.keras.losses.MeanSquaredError()
cmse_fun = CoarseMseLoss()

mse = mse_fun(x, y)
coarse_mse = cmse_fun(x, y)

assert isinstance(mse, tf.Tensor)
assert isinstance(coarse_mse, tf.Tensor)
assert mse.numpy().size == 1
assert coarse_mse.numpy().size == 1
assert mse.numpy() > 10 * coarse_mse.numpy()


def test_tex_loss():
"""Test custom TemporalExtremesLoss function that looks at min/max values
in the timeseries."""
loss_obj = TemporalExtremesLoss()

x = np.zeros((1, 1, 1, 72, 1))
y = np.zeros((1, 1, 1, 72, 1))

# loss should be dominated by special min/max values
x[..., 24, 0] = 20
y[..., 25, 0] = 25
loss = loss_obj(x, y)
assert loss.numpy() > 1.5

# loss should be dominated by special min/max values
x[..., 24, 0] = -20
y[..., 25, 0] = -25
loss = loss_obj(x, y)
assert loss.numpy() > 1.5


def test_lr_loss():
"""Test custom LowResLoss that re-coarsens synthetic and true high-res
fields and calculates pointwise loss on the low-res fields"""

# test w/o enhance
t_meth = 'average'
loss_obj = LowResLoss(s_enhance=1, t_enhance=1, t_method=t_meth,
tf_loss='MeanSquaredError')
xarr = np.random.uniform(-1, 1, (3, 10, 10, 48, 2))
yarr = np.random.uniform(-1, 1, (3, 10, 10, 48, 2))
xtensor = tf.convert_to_tensor(xarr)
ytensor = tf.convert_to_tensor(yarr)
loss = loss_obj(xtensor, ytensor)
assert np.allclose(loss, loss_obj._tf_loss(xtensor, ytensor))

# test 5D with s_enhance
s_enhance = 5
loss_obj = LowResLoss(s_enhance=s_enhance, t_enhance=1, t_method=t_meth,
tf_loss='MeanSquaredError')
xarr_lr = spatial_coarsening(xarr, s_enhance=s_enhance, obs_axis=True)
yarr_lr = spatial_coarsening(yarr, s_enhance=s_enhance, obs_axis=True)
loss = loss_obj(xtensor, ytensor)
assert np.allclose(loss, loss_obj._tf_loss(xarr_lr, yarr_lr))

# test 5D with s/t enhance
s_enhance = 5
t_enhance = 12
loss_obj = LowResLoss(s_enhance=s_enhance, t_enhance=t_enhance,
t_method=t_meth, tf_loss='MeanSquaredError')
xarr_lr = spatial_coarsening(xarr, s_enhance=s_enhance, obs_axis=True)
yarr_lr = spatial_coarsening(yarr, s_enhance=s_enhance, obs_axis=True)
xarr_lr = temporal_coarsening(xarr_lr, t_enhance=t_enhance, method=t_meth)
yarr_lr = temporal_coarsening(yarr_lr, t_enhance=t_enhance, method=t_meth)
loss = loss_obj(xtensor, ytensor)
assert np.allclose(loss, loss_obj._tf_loss(xarr_lr, yarr_lr))

# test 5D with subsample
t_meth = 'subsample'
loss_obj = LowResLoss(s_enhance=s_enhance, t_enhance=t_enhance,
t_method=t_meth, tf_loss='MeanSquaredError')
xarr_lr = spatial_coarsening(xarr, s_enhance=s_enhance, obs_axis=True)
yarr_lr = spatial_coarsening(yarr, s_enhance=s_enhance, obs_axis=True)
xarr_lr = temporal_coarsening(xarr_lr, t_enhance=t_enhance, method=t_meth)
yarr_lr = temporal_coarsening(yarr_lr, t_enhance=t_enhance, method=t_meth)
loss = loss_obj(xtensor, ytensor)
assert np.allclose(loss, loss_obj._tf_loss(xarr_lr, yarr_lr))

# test 4D spatial only
xarr = np.random.uniform(-1, 1, (3, 10, 10, 2))
yarr = np.random.uniform(-1, 1, (3, 10, 10, 2))
xtensor = tf.convert_to_tensor(xarr)
ytensor = tf.convert_to_tensor(yarr)
s_enhance = 5
loss_obj = LowResLoss(s_enhance=s_enhance, t_enhance=1, t_method=t_meth,
tf_loss='MeanSquaredError')
xarr_lr = spatial_coarsening(xarr, s_enhance=s_enhance, obs_axis=True)
yarr_lr = spatial_coarsening(yarr, s_enhance=s_enhance, obs_axis=True)
loss = loss_obj(xtensor, ytensor)
assert np.allclose(loss, loss_obj._tf_loss(xarr_lr, yarr_lr))

# test 4D spatial only with spatial extremes
loss_obj = LowResLoss(s_enhance=s_enhance, t_enhance=1, t_method=t_meth,
tf_loss='MeanSquaredError',
ex_loss='SpatialExtremesOnlyLoss')
ex_loss = loss_obj(xtensor, ytensor)
assert ex_loss > loss

0 comments on commit e5b3ab6

Please sign in to comment.