Skip to content

Commit

Permalink
added test for no regrid with dual data handler
Browse files Browse the repository at this point in the history
  • Loading branch information
grantbuster committed Apr 4, 2024
1 parent 685acd3 commit 7dcac89
Showing 1 changed file with 57 additions and 0 deletions.
57 changes: 57 additions & 0 deletions tests/data_handling/test_dual_data_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,63 @@ def test_normalization(cache,
rtol=rtol, atol=atol)


def test_no_regrid(log=False, full_shape=(20, 20), sample_shape=(10, 10, 4)):
"""Test no regridding of the LR data with correct normalization and
view/slice of the lr dataset"""
if log:
init_logger('sup3r', log_level='DEBUG')

s_enhance = 2
t_enhance = 2

hr_dh = DataHandlerH5(FP_WTK, FEATURES[0], target=TARGET_COORD,
shape=full_shape, sample_shape=sample_shape,
temporal_slice=slice(None, None, 10),
worker_kwargs=dict(max_workers=1),
val_split=0.0)
lr_handler = DataHandlerH5(FP_WTK, FEATURES[1], target=TARGET_COORD,
shape=full_shape,
sample_shape=(sample_shape[0] // s_enhance,
sample_shape[1] // s_enhance,
sample_shape[2] // t_enhance),
temporal_slice=slice(None, -10,
t_enhance * 10),
hr_spatial_coarsen=2, cache_pattern=None,
worker_kwargs=dict(max_workers=1),
val_split=0.0)

hr_dh0 = copy.deepcopy(hr_dh)
hr_dh1 = copy.deepcopy(hr_dh)
lr_dh0 = copy.deepcopy(lr_handler)
lr_dh1 = copy.deepcopy(lr_handler)

ddh0 = DualDataHandler(hr_dh0, lr_dh0, s_enhance=s_enhance,
t_enhance=t_enhance, regrid_lr=True)
ddh1 = DualDataHandler(hr_dh1, lr_dh1, s_enhance=s_enhance,
t_enhance=t_enhance, regrid_lr=False)

_ = DualBatchHandler([ddh0], norm=True)
_ = DualBatchHandler([ddh1], norm=True)

hr_m0 = np.mean(ddh0.hr_data, axis=(0, 1, 2))
lr_m0 = np.mean(ddh0.lr_data, axis=(0, 1, 2))
hr_m1 = np.mean(ddh1.hr_data, axis=(0, 1, 2))
lr_m1 = np.mean(ddh1.lr_data, axis=(0, 1, 2))
assert np.allclose(hr_m0, hr_m1)
assert np.allclose(lr_m0, lr_m1)
assert np.allclose(hr_m0, 0, atol=1e-3)
assert np.allclose(lr_m0, 0, atol=1e-6)

hr_s0 = np.std(ddh0.hr_data, axis=(0, 1, 2))
lr_s0 = np.std(ddh0.lr_data, axis=(0, 1, 2))
hr_s1 = np.std(ddh1.hr_data, axis=(0, 1, 2))
lr_s1 = np.std(ddh1.lr_data, axis=(0, 1, 2))
assert np.allclose(hr_s0, hr_s1)
assert np.allclose(lr_s0, lr_s1)
assert np.allclose(hr_s0, 1, atol=1e-3)
assert np.allclose(lr_s0, 1, atol=1e-6)


@pytest.mark.parametrize(['lr_features', 'hr_features', 'hr_exo_features'],
[(['U_100m'], ['U_100m', 'V_100m'], ['V_100m']),
(['U_100m'], ['U_100m', 'V_100m'], ('V_100m',)),
Expand Down

0 comments on commit 7dcac89

Please sign in to comment.