Skip to content

Commit

Permalink
added lr loss tests
Browse files Browse the repository at this point in the history
  • Loading branch information
grantbuster committed Feb 8, 2024
1 parent 37d2f56 commit b8f6d42
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 85 deletions.
20 changes: 11 additions & 9 deletions sup3r/utilities/loss_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,29 +551,31 @@ def _s_coarsen_4d_tensor(self, tensor):
"""Perform spatial coarsening on a 4D tensor of shape
(n_obs, spatial_1, spatial_2, features)"""
shape = tensor.shape
tensor = tf.reshape(tensor, shape[0],
shape[1] // self._s_enhance, self._s_enhance,
shape[2] // self._s_enhance, self._s_enhance,
shape[3])
tensor = tf.reshape(tensor,
(shape[0],
shape[1] // self._s_enhance, self._s_enhance,
shape[2] // self._s_enhance, self._s_enhance,
shape[3]))
tensor = tf.math.reduce_sum(tensor, axis=(2, 4)) / self._s_enhance**2
return tensor

def _s_coarsen_5d_tensor(self, tensor):
"""Perform spatial coarsening on a 5D tensor of shape
(n_obs, spatial_1, spatial_2, time, features)"""
shape = tensor.shape
tensor = tf.reshape(tensor, shape[0],
shape[1] // self._s_enhance, self._s_enhance,
shape[2] // self._s_enhance, self._s_enhance,
shape[3], shape[4])
tensor = tf.reshape(tensor,
(shape[0],
shape[1] // self._s_enhance, self._s_enhance,
shape[2] // self._s_enhance, self._s_enhance,
shape[3], shape[4]))
tensor = tf.math.reduce_sum(tensor, axis=(2, 4)) / self._s_enhance**2
return tensor

def _t_coarsen_sample(self, tensor):
"""Perform temporal subsampling on a 5D tensor of shape
(n_obs, spatial_1, spatial_2, time, features)"""
assert len(tensor.shape) == 5
tensor = tensor[:, :, :, ::self.t_enhance, :]
tensor = tensor[:, :, :, ::self._t_enhance, :]
return tensor

def _t_coarsen_avg(self, tensor):
Expand Down
76 changes: 0 additions & 76 deletions tests/training/test_custom_loss.py

This file was deleted.

138 changes: 138 additions & 0 deletions tests/utilities/test_loss_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# -*- coding: utf-8 -*-
"""Test the basic training of super resolution GAN"""
import numpy as np
import tensorflow as tf

from sup3r.utilities.loss_metrics import (MmdMseLoss, CoarseMseLoss,
TemporalExtremesLoss, LowResLoss)
from sup3r.utilities.utilities import spatial_coarsening, temporal_coarsening


def test_mmd_loss():
"""Test content loss using mse + mmd for content loss."""

x = np.zeros((6, 10, 10, 8, 3))
y = np.zeros((6, 10, 10, 8, 3))
x[:, 7:9, 7:9, :, :] = 1
y[:, 2:5, 2:5, :, :] = 1

# distributions differing by only a small peak should give small mse and
# larger mmd
mse_fun = tf.keras.losses.MeanSquaredError()
mmd_mse_fun = MmdMseLoss()

mse = mse_fun(x, y)
mmd_plus_mse = mmd_mse_fun(x, y)

assert mmd_plus_mse > mse

x = np.random.rand(6, 10, 10, 8, 3)
x /= np.max(x)
y = np.random.rand(6, 10, 10, 8, 3)
y /= np.max(y)

# scaling the same distribution should give high mse and smaller mmd
mse = mse_fun(5 * x, x)
mmd_plus_mse = mmd_mse_fun(5 * x, x)

assert mmd_plus_mse < mse


def test_coarse_mse_loss():
"""Test the coarse MSE loss on spatial average data"""
x = np.random.uniform(0, 1, (6, 10, 10, 8, 3))
y = np.random.uniform(0, 1, (6, 10, 10, 8, 3))

mse_fun = tf.keras.losses.MeanSquaredError()
cmse_fun = CoarseMseLoss()

mse = mse_fun(x, y)
coarse_mse = cmse_fun(x, y)

assert isinstance(mse, tf.Tensor)
assert isinstance(coarse_mse, tf.Tensor)
assert mse.numpy().size == 1
assert coarse_mse.numpy().size == 1
assert mse.numpy() > 10 * coarse_mse.numpy()


def test_tex_loss():
"""Test custom TemporalExtremesLoss function that looks at min/max values
in the timeseries."""
loss_obj = TemporalExtremesLoss()

x = np.zeros((1, 1, 1, 72, 1))
y = np.zeros((1, 1, 1, 72, 1))

# loss should be dominated by special min/max values
x[..., 24, 0] = 20
y[..., 25, 0] = 25
loss = loss_obj(x, y)
assert loss.numpy() > 1.5

# loss should be dominated by special min/max values
x[..., 24, 0] = -20
y[..., 25, 0] = -25
loss = loss_obj(x, y)
assert loss.numpy() > 1.5


def test_lr_loss():
"""Test custom LowResLoss that re-coarsens synthetic and true high-res
fields and calculates pointwise loss on the low-res fields"""

# test w/o enhance
t_meth = 'average'
loss_obj = LowResLoss(s_enhance=1, t_enhance=1, t_method=t_meth,
tf_loss='MeanSquaredError')
xarr = np.random.uniform(-1, 1, (3, 10, 10, 48, 2))
yarr = np.random.uniform(-1, 1, (3, 10, 10, 48, 2))
xtensor = tf.convert_to_tensor(xarr)
ytensor = tf.convert_to_tensor(yarr)
loss = loss_obj(xtensor, ytensor)
assert np.allclose(loss, loss_obj._tf_loss(xtensor, ytensor))

# test 5D with s_enhance
s_enhance = 5
loss_obj = LowResLoss(s_enhance=s_enhance, t_enhance=1, t_method=t_meth,
tf_loss='MeanSquaredError')
xarr_lr = spatial_coarsening(xarr, s_enhance=s_enhance, obs_axis=True)
yarr_lr = spatial_coarsening(yarr, s_enhance=s_enhance, obs_axis=True)
loss = loss_obj(xtensor, ytensor)
assert np.allclose(loss, loss_obj._tf_loss(xarr_lr, yarr_lr))

# test 5D with s/t enhance
s_enhance = 5
t_enhance = 12
loss_obj = LowResLoss(s_enhance=s_enhance, t_enhance=t_enhance,
t_method=t_meth, tf_loss='MeanSquaredError')
xarr_lr = spatial_coarsening(xarr, s_enhance=s_enhance, obs_axis=True)
yarr_lr = spatial_coarsening(yarr, s_enhance=s_enhance, obs_axis=True)
xarr_lr = temporal_coarsening(xarr_lr, t_enhance=t_enhance, method=t_meth)
yarr_lr = temporal_coarsening(yarr_lr, t_enhance=t_enhance, method=t_meth)
loss = loss_obj(xtensor, ytensor)
assert np.allclose(loss, loss_obj._tf_loss(xarr_lr, yarr_lr))

# test 5D with subsample
t_meth = 'subsample'
loss_obj = LowResLoss(s_enhance=s_enhance, t_enhance=t_enhance,
t_method=t_meth, tf_loss='MeanSquaredError')
xarr_lr = spatial_coarsening(xarr, s_enhance=s_enhance, obs_axis=True)
yarr_lr = spatial_coarsening(yarr, s_enhance=s_enhance, obs_axis=True)
xarr_lr = temporal_coarsening(xarr_lr, t_enhance=t_enhance, method=t_meth)
yarr_lr = temporal_coarsening(yarr_lr, t_enhance=t_enhance, method=t_meth)
loss = loss_obj(xtensor, ytensor)
assert np.allclose(loss, loss_obj._tf_loss(xarr_lr, yarr_lr))

# test 4D spatial only
xarr = np.random.uniform(-1, 1, (3, 10, 10, 2))
yarr = np.random.uniform(-1, 1, (3, 10, 10, 2))
xtensor = tf.convert_to_tensor(xarr)
ytensor = tf.convert_to_tensor(yarr)
s_enhance = 5
loss_obj = LowResLoss(s_enhance=s_enhance, t_enhance=1, t_method=t_meth,
tf_loss='MeanSquaredError')
xarr_lr = spatial_coarsening(xarr, s_enhance=s_enhance, obs_axis=True)
yarr_lr = spatial_coarsening(yarr, s_enhance=s_enhance, obs_axis=True)
loss = loss_obj(xtensor, ytensor)
assert np.allclose(loss, loss_obj._tf_loss(xarr_lr, yarr_lr))

0 comments on commit b8f6d42

Please sign in to comment.