Skip to content

Commit

Permalink
added extremes loss to lr loss
Browse files Browse the repository at this point in the history
  • Loading branch information
grantbuster committed Feb 14, 2024
1 parent ac0443d commit 5d266d9
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 2 deletions.
19 changes: 17 additions & 2 deletions sup3r/utilities/loss_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,8 +519,12 @@ class LowResLoss(tf.keras.losses.Loss):
high-resolution data pairs and then performing the pointwise content loss
on the low-resolution fields"""

EX_LOSS_METRICS = {'SpatialExtremesOnlyLoss': SpatialExtremesOnlyLoss,
'TemporalExtremesOnlyLoss': TemporalExtremesOnlyLoss,
}

def __init__(self, s_enhance=1, t_enhance=1, t_method='average',
tf_loss='MeanSquaredError'):
tf_loss='MeanSquaredError', ex_loss=None):
"""Initialize the loss with given weight
Parameters
Expand All @@ -539,13 +543,20 @@ def __init__(self, s_enhance=1, t_enhance=1, t_method='average',
The tensorflow loss function to operate on the low-res fields. Must
be the name of a loss class that can be retrieved from
``tf.keras.losses`` e.g., "MeanSquaredError" or "MeanAbsoluteError"
ex_loss : None | str
Optional additional loss metric evaluating the spatial or temporal
extremes of the high-res data. Can be "SpatialExtremesOnlyLoss" or
"TemporalExtremesOnlyLoss" (keys in ``EX_LOSS_METRICS``).
"""

super().__init__()
self._s_enhance = s_enhance
self._t_enhance = t_enhance
self._t_method = str(t_method).casefold()
self._tf_loss = getattr(tf.keras.losses, tf_loss)()
self._ex_loss = ex_loss
if self._ex_loss is not None:
self._ex_loss = self.EX_LOSS_METRICS[self._ex_loss]()

def _s_coarsen_4d_tensor(self, tensor):
"""Perform spatial coarsening on a 4D tensor of shape
Expand Down Expand Up @@ -611,6 +622,10 @@ def __call__(self, x1, x2):
assert x1.shape == x2.shape
s_only = len(x1.shape) == 4

ex_loss = tf.constant(0, dtype=x1.dtype)
if self._ex_loss is not None:
ex_loss = self._ex_loss(x1, x2)

if self._s_enhance > 1 and s_only:
x1 = self._s_coarsen_4d_tensor(x1)
x2 = self._s_coarsen_4d_tensor(x2)
Expand All @@ -627,4 +642,4 @@ def __call__(self, x1, x2):
x1 = self._t_coarsen_sample(x1)
x2 = self._t_coarsen_sample(x2)

return self._tf_loss(x1, x2)
return self._tf_loss(x1, x2) + ex_loss
7 changes: 7 additions & 0 deletions tests/utilities/test_loss_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,10 @@ def test_lr_loss():
yarr_lr = spatial_coarsening(yarr, s_enhance=s_enhance, obs_axis=True)
loss = loss_obj(xtensor, ytensor)
assert np.allclose(loss, loss_obj._tf_loss(xarr_lr, yarr_lr))

# test 4D spatial only with spatial extremes
loss_obj = LowResLoss(s_enhance=s_enhance, t_enhance=1, t_method=t_meth,
tf_loss='MeanSquaredError',
ex_loss='SpatialExtremesOnlyLoss')
ex_loss = loss_obj(xtensor, ytensor)
assert ex_loss > loss

0 comments on commit 5d266d9

Please sign in to comment.