Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gb/trh loss #142

Merged
merged 12 commits into from
Jan 20, 2023
Merged
Next Next commit
added special Sup3rCC TRH loss
  • Loading branch information
grantbuster committed Jan 12, 2023
commit 4b6643cd435c12b49ab5d8b89ef0a3bff8633a0f
38 changes: 37 additions & 1 deletion sup3r/utilities/loss_metrics.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Loss metrics for Sup3r"""

from tensorflow.keras.losses import MeanSquaredError
from tensorflow.keras.losses import MeanSquaredError, MeanAbsoluteError
import tensorflow as tf


Expand Down Expand Up @@ -171,3 +171,39 @@ def __call__(self, x1, x2):
x1_coarse = tf.reduce_mean(x1, axis=(1, 2))
x2_coarse = tf.reduce_mean(x2, axis=(1, 2))
return self.MSE_LOSS(x1_coarse, x2_coarse)


class TRHLoss(tf.keras.losses.Loss):
"""Loss class for Temperature and Relative Humidity Sup3rCC GAN that
encourages accuracy of the min/max values in the timeseries"""

MAE_LOSS = MeanAbsoluteError()

def __call__(self, x1, x2):
"""Custom Sup3rCC content loss function for Temp + RH

Parameters
----------
x1 : tf.tensor
synthetic generator output
(n_observations, spatial_1, spatial_2, temporal, features)
x2 : tf.tensor
high resolution data
(n_observations, spatial_1, spatial_2, temporal, features)

Returns
-------
tf.tensor
0D tensor with loss value
"""
x1_min = tf.reduce_min(x1, axis=3)
x2_min = tf.reduce_min(x2, axis=3)

x1_max = tf.reduce_max(x1, axis=3)
x2_max = tf.reduce_max(x2, axis=3)

mae = self.MAE_LOSS(x1, x2)
mae_min = self.MAE_LOSS(x1_min, x2_min)
mae_max = self.MAE_LOSS(x1_max, x2_max)

return mae + mae_min + mae_max
25 changes: 24 additions & 1 deletion tests/test_custom_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import tensorflow as tf

from sup3r.utilities.loss_metrics import MmdMseLoss, CoarseMseLoss
from sup3r.utilities.loss_metrics import MmdMseLoss, CoarseMseLoss, TRHLoss


def test_mmd_loss():
Expand Down Expand Up @@ -52,3 +52,26 @@ def test_coarse_mse_loss():
assert mse.numpy().size == 1
assert coarse_mse.numpy().size == 1
assert mse.numpy() > 10 * coarse_mse.numpy()


def test_trh_loss():
"""Test custom Sup3rCC Temp + RH loss function that looks at min/max values
in the timeseries."""
trh_loss = TRHLoss()

x = np.zeros((1, 1, 1, 72, 1))
y = np.zeros((1, 1, 1, 72, 1))

# loss should be dominated by special min/max values
x[..., 24, 0] = 20
y[..., 25, 0] = 25
loss = trh_loss(x, y)
assert loss.numpy() > 5
assert loss.numpy() < 6

# loss should be dominated by special min/max values
x[..., 24, 0] = -10
y[..., 25, 0] = -15
loss = trh_loss(x, y)
assert loss.numpy() > 5
assert loss.numpy() < 6