Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gb/trh loss #142

Merged
merged 12 commits into from
Jan 20, 2023
Merged
Prev Previous commit
Next Next commit
renamed overly specific loss function
  • Loading branch information
grantbuster committed Jan 13, 2023
commit ea31c8585aa84d4b1ac96f846bbb05afb6116024
8 changes: 4 additions & 4 deletions sup3r/utilities/loss_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,14 +173,14 @@ def __call__(self, x1, x2):
return self.MSE_LOSS(x1_coarse, x2_coarse)


class TRHLoss(tf.keras.losses.Loss):
"""Loss class for Temperature and Relative Humidity Sup3rCC GAN that
encourages accuracy of the min/max values in the timeseries"""
class TemporalExtremesLoss(tf.keras.losses.Loss):
"""Loss class that encourages accuracy of the min/max values in the
timeseries"""

MAE_LOSS = MeanAbsoluteError()

def __call__(self, x1, x2):
"""Custom Sup3rCC content loss function for Temp + RH
"""Custom content loss that encourages temporal min/max accuracy

Parameters
----------
Expand Down
13 changes: 7 additions & 6 deletions tests/test_custom_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import numpy as np
import tensorflow as tf

from sup3r.utilities.loss_metrics import MmdMseLoss, CoarseMseLoss, TRHLoss
from sup3r.utilities.loss_metrics import (MmdMseLoss, CoarseMseLoss,
TemporalExtremesLoss)


def test_mmd_loss():
Expand Down Expand Up @@ -54,24 +55,24 @@ def test_coarse_mse_loss():
assert mse.numpy() > 10 * coarse_mse.numpy()


def test_trh_loss():
"""Test custom Sup3rCC Temp + RH loss function that looks at min/max values
def test_tex_loss():
"""Test custom TemporalExtremesLoss function that looks at min/max values
in the timeseries."""
trh_loss = TRHLoss()
loss_obj = TemporalExtremesLoss()

x = np.zeros((1, 1, 1, 72, 1))
y = np.zeros((1, 1, 1, 72, 1))

# loss should be dominated by special min/max values
x[..., 24, 0] = 20
y[..., 25, 0] = 25
loss = trh_loss(x, y)
loss = loss_obj(x, y)
assert loss.numpy() > 5
assert loss.numpy() < 6

# loss should be dominated by special min/max values
x[..., 24, 0] = -10
y[..., 25, 0] = -15
loss = trh_loss(x, y)
loss = loss_obj(x, y)
assert loss.numpy() > 5
assert loss.numpy() < 6