From 24619ff885b1b8e80c7c86fe93544ba1a1618b3d Mon Sep 17 00:00:00 2001 From: Gui Castelao Date: Fri, 10 May 2024 10:46:12 -0600 Subject: [PATCH] Initializing test with the local case --- tests/bias/test_qdm_bias_correction.py | 105 +++++++++++++++++++++---- 1 file changed, 89 insertions(+), 16 deletions(-) diff --git a/tests/bias/test_qdm_bias_correction.py b/tests/bias/test_qdm_bias_correction.py index 33a539624..8dedbb113 100644 --- a/tests/bias/test_qdm_bias_correction.py +++ b/tests/bias/test_qdm_bias_correction.py @@ -8,10 +8,12 @@ import pytest import xarray as xr -from sup3r import TEST_DATA_DIR +from sup3r import CONFIG_DIR, TEST_DATA_DIR +from sup3r.models import Sup3rGan +from sup3r.pipeline.forward_pass import ForwardPass, ForwardPassStrategy from sup3r.bias.bias_calc import QuantileDeltaMappingCorrection from sup3r.bias.bias_transforms import local_qdm_bc -from sup3r.preprocessing.data_handling import DataHandlerNC +from sup3r.preprocessing.data_handling import DataHandlerNC, DataHandlerNCforCC FP_NSRDB = os.path.join(TEST_DATA_DIR, "test_nsrdb_co_2018.h5") FP_CC = os.path.join(TEST_DATA_DIR, "rsds_test.nc") @@ -361,18 +363,89 @@ def test_bc_trend_same_hist(tmp_path, fp_fut_cc, dist_params): assert np.allclose(corrected[idx], original[idx]) -def test_fwd_integration(): - +def test_fwp_integration(tmp_path, fp_fut_cc): + """Test the integration of the bias correction method into the forward pass + framework""" + fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_3x_4x_2f.json') + fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') + features = ['U_100m', 'V_100m'] + target = (13.67, 125.0) + shape = (8, 8) + temporal_slice = slice(None, None, 1) + fwp_chunk_shape = (4, 4, 150) + input_files = [os.path.join(TEST_DATA_DIR, 'ua_test.nc'), + os.path.join(TEST_DATA_DIR, 'va_test.nc'), + os.path.join(TEST_DATA_DIR, 'orog_test.nc'), + os.path.join(TEST_DATA_DIR, 'zg_test.nc')] + + lat_lon = DataHandlerNCforCC(input_files, features=[], target=target, + shape=shape, + worker_kwargs={'max_workers': 1}).lat_lon + + Sup3rGan.seed() + model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) + _ = model.generate(np.ones((4, 10, 10, 6, len(features)))) + model.meta['lr_features'] = features + model.meta['hr_out_features'] = features + model.meta['s_enhance'] = 3 + model.meta['t_enhance'] = 4 + + bias_fp = os.path.join(tmp_path, 'bc.h5') + out_dir = os.path.join(tmp_path, 'st_gan') + model.save(out_dir) + + scalar = np.random.uniform(0.5, 1, (8, 8, 1)) + adder = np.random.uniform(0, 1, (8, 8, 1)) + + with h5py.File(bias_fp, 'w') as f: + f.create_dataset('U_100m_scalar', data=scalar) + f.create_dataset('U_100m_adder', data=adder) + f.create_dataset('V_100m_scalar', data=scalar) + f.create_dataset('V_100m_adder', data=adder) + f.create_dataset('latitude', data=lat_lon[..., 0]) + f.create_dataset('longitude', data=lat_lon[..., 1]) + + bias_correct_kwargs = {'U_100m': {'feature_name': 'U_100m', + 'bias_fp': bias_fp}, + 'V_100m': {'feature_name': 'V_100m', + 'bias_fp': bias_fp}} + + strat = ForwardPassStrategy( + input_files, + model_kwargs={'model_dir': out_dir}, + fwp_chunk_shape=fwp_chunk_shape, + spatial_pad=0, temporal_pad=0, + input_handler_kwargs=dict(target=target, shape=shape, + temporal_slice=temporal_slice, + worker_kwargs=dict(max_workers=1)), + out_pattern=os.path.join(tmp_path, 'out_{file_id}.nc'), + worker_kwargs=dict(max_workers=1), + input_handler='DataHandlerNCforCC') bc_strat = ForwardPassStrategy( - input_files, - model_kwargs={'model_dir': out_dir}, - fwp_chunk_shape=fwp_chunk_shape, - spatial_pad=0, temporal_pad=0, - input_handler_kwargs=dict(target=target, shape=shape, - temporal_slice=temporal_slice, - worker_kwargs=dict(max_workers=1)), - out_pattern=os.path.join(td, 'out_{file_id}.nc'), - worker_kwargs=dict(max_workers=1), - input_handler='DataHandlerNCforCC', - bias_correct_method='local_linear_bc', - bias_correct_kwargs=bias_correct_kwargs) + input_files, + model_kwargs={'model_dir': out_dir}, + fwp_chunk_shape=fwp_chunk_shape, + spatial_pad=0, temporal_pad=0, + input_handler_kwargs=dict(target=target, shape=shape, + temporal_slice=temporal_slice, + worker_kwargs=dict(max_workers=1)), + out_pattern=os.path.join(tmp_path, 'out_{file_id}.nc'), + worker_kwargs=dict(max_workers=1), + input_handler='DataHandlerNCforCC', + bias_correct_method='local_linear_bc', + bias_correct_kwargs=bias_correct_kwargs) + + for ichunk in range(strat.chunks): + + fwp = ForwardPass(strat, chunk_index=ichunk) + bc_fwp = ForwardPass(bc_strat, chunk_index=ichunk) + + i_scalar = np.expand_dims(scalar, axis=-1) + i_adder = np.expand_dims(adder, axis=-1) + i_scalar = i_scalar[bc_fwp.lr_padded_slice[0], + bc_fwp.lr_padded_slice[1]] + i_adder = i_adder[bc_fwp.lr_padded_slice[0], + bc_fwp.lr_padded_slice[1]] + truth = fwp.input_data * i_scalar + i_adder + + assert np.allclose(bc_fwp.input_data, truth, equal_nan=True)