Skip to content

Commit

Permalink
added loss function that re-coarsens the synthetic and true high res …
Browse files Browse the repository at this point in the history
…fields and calculates loss on the low-res error
  • Loading branch information
grantbuster committed Feb 8, 2024
1 parent fc613de commit 37d2f56
Showing 1 changed file with 115 additions and 1 deletion.
116 changes: 115 additions & 1 deletion sup3r/utilities/loss_metrics.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Loss metrics for Sup3r"""
"""Content loss metrics for Sup3r"""

import numpy as np
import tensorflow as tf
Expand Down Expand Up @@ -512,3 +512,117 @@ def __call__(self, x1, x2):
"""
return (5 * self.st_ex_loss(x1, x2)
+ self.fft_weight * self.fft_loss(x1, x2)) / 6


class LowResLoss(tf.keras.losses.Loss):
"""Content loss that is calculated by coarsening the synthetic and true
high-resolution data pairs and then performing the pointwise content loss
on the low-resolution fields"""

def __init__(self, s_enhance=1, t_enhance=1, t_method='average',
tf_loss='MeanSquaredError'):
"""Initialize the loss with given weight
Parameters
----------
s_enhance : int
factor by which to coarsen spatial dimensions. 1 will keep the
spatial axes as high-res
t_enhance : int
factor by which to coarsen temporal dimension. 1 will keep the
temporal axes as high-res
t_method : str
Accepted options: [subsample, average]
Subsample will take every t_enhance-th time step, average will
average over t_enhance time steps
tf_loss : str
The tensorflow loss function to operate on the low-res fields. Must
be the name of a loss class that can be retrieved from
``tf.keras.losses`` e.g., "MeanSquaredError" or "MeanAbsoluteError"
"""

super().__init__()
self._s_enhance = s_enhance
self._t_enhance = t_enhance
self._t_method = str(t_method).casefold()
self._tf_loss = getattr(tf.keras.losses, tf_loss)()

def _s_coarsen_4d_tensor(self, tensor):
"""Perform spatial coarsening on a 4D tensor of shape
(n_obs, spatial_1, spatial_2, features)"""
shape = tensor.shape
tensor = tf.reshape(tensor, shape[0],
shape[1] // self._s_enhance, self._s_enhance,
shape[2] // self._s_enhance, self._s_enhance,
shape[3])
tensor = tf.math.reduce_sum(tensor, axis=(2, 4)) / self._s_enhance**2
return tensor

def _s_coarsen_5d_tensor(self, tensor):
"""Perform spatial coarsening on a 5D tensor of shape
(n_obs, spatial_1, spatial_2, time, features)"""
shape = tensor.shape
tensor = tf.reshape(tensor, shape[0],
shape[1] // self._s_enhance, self._s_enhance,
shape[2] // self._s_enhance, self._s_enhance,
shape[3], shape[4])
tensor = tf.math.reduce_sum(tensor, axis=(2, 4)) / self._s_enhance**2
return tensor

def _t_coarsen_sample(self, tensor):
"""Perform temporal subsampling on a 5D tensor of shape
(n_obs, spatial_1, spatial_2, time, features)"""
assert len(tensor.shape) == 5
tensor = tensor[:, :, :, ::self.t_enhance, :]
return tensor

def _t_coarsen_avg(self, tensor):
"""Perform temporal coarsening on a 5D tensor of shape
(n_obs, spatial_1, spatial_2, time, features)"""
shape = tensor.shape
assert len(shape) == 5
tensor = tf.reshape(tensor, (shape[0], shape[1], shape[2], -1,
self._t_enhance, shape[4]))
tensor = tf.math.reduce_sum(tensor, axis=4) / self._t_enhance
return tensor

def __call__(self, x1, x2):
"""Custom content loss calculated on re-coarsened low-res fields
Parameters
----------
x1 : tf.tensor
Synthetic high-res generator output, shape is either of these:
(n_obs, spatial_1, spatial_2, features)
(n_obs, spatial_1, spatial_2, temporal, features)
x2 : tf.tensor
True high resolution data, shape is either of these:
(n_obs, spatial_1, spatial_2, features)
(n_obs, spatial_1, spatial_2, temporal, features)
Returns
-------
tf.tensor
0D tensor loss value
"""

assert x1.shape == x2.shape
s_only = len(x1.shape) == 4

if self._s_enhance > 1 and s_only:
x1 = self._s_coarsen_4d_tensor(x1)
x2 = self._s_coarsen_4d_tensor(x2)

elif self._s_enhance > 1 and not s_only:
x1 = self._s_coarsen_5d_tensor(x1)
x2 = self._s_coarsen_5d_tensor(x2)

if self._t_enhance > 1 and self._t_method == 'average':
x1 = self._t_coarsen_avg(x1)
x2 = self._t_coarsen_avg(x2)

if self._t_enhance > 1 and self._t_method == 'subsample':
x1 = self._t_coarsen_sample(x1)
x2 = self._t_coarsen_sample(x2)

return self._tf_loss(x1, x2)

0 comments on commit 37d2f56

Please sign in to comment.