Skip to content

Commit

Permalink
Initializing test with the local case
Browse files Browse the repository at this point in the history
  • Loading branch information
castelao committed May 13, 2024
1 parent f7346e4 commit 24619ff
Showing 1 changed file with 89 additions and 16 deletions.
105 changes: 89 additions & 16 deletions tests/bias/test_qdm_bias_correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@
import pytest
import xarray as xr

from sup3r import TEST_DATA_DIR
from sup3r import CONFIG_DIR, TEST_DATA_DIR
from sup3r.models import Sup3rGan
from sup3r.pipeline.forward_pass import ForwardPass, ForwardPassStrategy
from sup3r.bias.bias_calc import QuantileDeltaMappingCorrection
from sup3r.bias.bias_transforms import local_qdm_bc
from sup3r.preprocessing.data_handling import DataHandlerNC
from sup3r.preprocessing.data_handling import DataHandlerNC, DataHandlerNCforCC

FP_NSRDB = os.path.join(TEST_DATA_DIR, "test_nsrdb_co_2018.h5")
FP_CC = os.path.join(TEST_DATA_DIR, "rsds_test.nc")
Expand Down Expand Up @@ -361,18 +363,89 @@ def test_bc_trend_same_hist(tmp_path, fp_fut_cc, dist_params):
assert np.allclose(corrected[idx], original[idx])


def test_fwd_integration():

def test_fwp_integration(tmp_path, fp_fut_cc):
"""Test the integration of the bias correction method into the forward pass
framework"""
fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_3x_4x_2f.json')
fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json')
features = ['U_100m', 'V_100m']
target = (13.67, 125.0)
shape = (8, 8)
temporal_slice = slice(None, None, 1)
fwp_chunk_shape = (4, 4, 150)
input_files = [os.path.join(TEST_DATA_DIR, 'ua_test.nc'),
os.path.join(TEST_DATA_DIR, 'va_test.nc'),
os.path.join(TEST_DATA_DIR, 'orog_test.nc'),
os.path.join(TEST_DATA_DIR, 'zg_test.nc')]

lat_lon = DataHandlerNCforCC(input_files, features=[], target=target,
shape=shape,
worker_kwargs={'max_workers': 1}).lat_lon

Sup3rGan.seed()
model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4)
_ = model.generate(np.ones((4, 10, 10, 6, len(features))))
model.meta['lr_features'] = features
model.meta['hr_out_features'] = features
model.meta['s_enhance'] = 3
model.meta['t_enhance'] = 4

bias_fp = os.path.join(tmp_path, 'bc.h5')
out_dir = os.path.join(tmp_path, 'st_gan')
model.save(out_dir)

scalar = np.random.uniform(0.5, 1, (8, 8, 1))
adder = np.random.uniform(0, 1, (8, 8, 1))

with h5py.File(bias_fp, 'w') as f:
f.create_dataset('U_100m_scalar', data=scalar)
f.create_dataset('U_100m_adder', data=adder)
f.create_dataset('V_100m_scalar', data=scalar)
f.create_dataset('V_100m_adder', data=adder)
f.create_dataset('latitude', data=lat_lon[..., 0])
f.create_dataset('longitude', data=lat_lon[..., 1])

bias_correct_kwargs = {'U_100m': {'feature_name': 'U_100m',
'bias_fp': bias_fp},
'V_100m': {'feature_name': 'V_100m',
'bias_fp': bias_fp}}

strat = ForwardPassStrategy(
input_files,
model_kwargs={'model_dir': out_dir},
fwp_chunk_shape=fwp_chunk_shape,
spatial_pad=0, temporal_pad=0,
input_handler_kwargs=dict(target=target, shape=shape,
temporal_slice=temporal_slice,
worker_kwargs=dict(max_workers=1)),
out_pattern=os.path.join(tmp_path, 'out_{file_id}.nc'),
worker_kwargs=dict(max_workers=1),
input_handler='DataHandlerNCforCC')
bc_strat = ForwardPassStrategy(
input_files,
model_kwargs={'model_dir': out_dir},
fwp_chunk_shape=fwp_chunk_shape,
spatial_pad=0, temporal_pad=0,
input_handler_kwargs=dict(target=target, shape=shape,
temporal_slice=temporal_slice,
worker_kwargs=dict(max_workers=1)),
out_pattern=os.path.join(td, 'out_{file_id}.nc'),
worker_kwargs=dict(max_workers=1),
input_handler='DataHandlerNCforCC',
bias_correct_method='local_linear_bc',
bias_correct_kwargs=bias_correct_kwargs)
input_files,
model_kwargs={'model_dir': out_dir},
fwp_chunk_shape=fwp_chunk_shape,
spatial_pad=0, temporal_pad=0,
input_handler_kwargs=dict(target=target, shape=shape,
temporal_slice=temporal_slice,
worker_kwargs=dict(max_workers=1)),
out_pattern=os.path.join(tmp_path, 'out_{file_id}.nc'),
worker_kwargs=dict(max_workers=1),
input_handler='DataHandlerNCforCC',
bias_correct_method='local_linear_bc',
bias_correct_kwargs=bias_correct_kwargs)

for ichunk in range(strat.chunks):

fwp = ForwardPass(strat, chunk_index=ichunk)
bc_fwp = ForwardPass(bc_strat, chunk_index=ichunk)

i_scalar = np.expand_dims(scalar, axis=-1)
i_adder = np.expand_dims(adder, axis=-1)
i_scalar = i_scalar[bc_fwp.lr_padded_slice[0],
bc_fwp.lr_padded_slice[1]]
i_adder = i_adder[bc_fwp.lr_padded_slice[0],
bc_fwp.lr_padded_slice[1]]
truth = fwp.input_data * i_scalar + i_adder

assert np.allclose(bc_fwp.input_data, truth, equal_nan=True)

0 comments on commit 24619ff

Please sign in to comment.