From ce813c877b5b6db0f129cc1f331bb876cc1eba59 Mon Sep 17 00:00:00 2001 From: Malik Date: Wed, 22 Mar 2023 17:14:32 -0600 Subject: [PATCH 01/13] add end time padding --- .../conditional_moment_batch_handling.py | 47 +++++++++++++++---- 1 file changed, 38 insertions(+), 9 deletions(-) diff --git a/sup3r/preprocessing/conditional_moment_batch_handling.py b/sup3r/preprocessing/conditional_moment_batch_handling.py index 26c72e725..00bf0c829 100644 --- a/sup3r/preprocessing/conditional_moment_batch_handling.py +++ b/sup3r/preprocessing/conditional_moment_batch_handling.py @@ -98,7 +98,9 @@ def make_output(low_res, high_res, # pylint: disable=E1130 @staticmethod - def make_mask(high_res, s_padding=None, t_padding=None): + def make_mask(high_res, + s_padding=None, t_padding=None, + end_t_padding=False, t_enhance=None): """Make mask for output. The mask is used to ensure consistency when training conditional moments. @@ -128,6 +130,12 @@ def make_mask(high_res, s_padding=None, t_padding=None): t_padding : int | None Temporal padding size. If None or 0, no padding is applied. None by default + end_t_padding : bool | False + Pad end of temporal space. + False by default + t_enhance : int | None + Temporal enhancement factor to define end padding. + None by default model_mom1 : Sup3rCondMom | None Model used to modify the make the batch output output_features_ind : list | np.ndarray | None @@ -139,6 +147,11 @@ def make_mask(high_res, s_padding=None, t_padding=None): t_min = t_padding if t_padding is not None else 0 s_max = -s_padding if s_min > 0 else None t_max = -t_padding if t_min > 0 else None + if end_t_padding and t_enhance > 1: + if t_max is None: + t_max = -(t_enhance - 1) + else: + t_max = -(t_enhance - 1) - t_padding if len(high_res.shape) == 4: mask[:, s_min:s_max, s_min:s_max, :] = 1.0 @@ -159,7 +172,8 @@ def get_coarse_batch(cls, high_res, smoothing_ignore=None, model_mom1=None, s_padding=None, - t_padding=None): + t_padding=None, + end_t_padding=False): """Coarsen high res data and return Batch with high res and low res data @@ -202,6 +216,9 @@ def get_coarse_batch(cls, high_res, t_padding : int | None Width of temporal padding to predict only middle part. If None, no padding is used + end_t_padding : bool | False + Pad end of temporal space + False by default Returns ------- @@ -227,7 +244,7 @@ def get_coarse_batch(cls, high_res, s_enhance, t_enhance, model_mom1, output_features_ind) mask = cls.make_mask(high_res, - s_padding, t_padding) + s_padding, t_padding, end_t_padding, t_enhance) batch = cls(low_res, high_res, output, mask) return batch @@ -435,7 +452,7 @@ def __init__(self, data_handlers, batch_size=8, s_enhance=3, t_enhance=1, output_features=None, smoothing=None, smoothing_ignore=None, model_mom1=None, - s_padding=None, t_padding=None): + s_padding=None, t_padding=None, end_t_padding=False): """ Parameters ---------- @@ -477,6 +494,9 @@ def __init__(self, data_handlers, batch_size=8, s_enhance=3, t_enhance=1, t_padding : int | None Width of temporal padding to predict only middle part. If None, no padding is used + end_t_padding : bool | False + Pad end of temporal space + False by default """ handler_shapes = np.array([d.sample_shape for d in data_handlers]) @@ -492,6 +512,7 @@ def __init__(self, data_handlers, batch_size=8, s_enhance=3, t_enhance=1, self.t_enhance = t_enhance self.s_padding = s_padding self.t_padding = t_padding + self.end_t_padding = end_t_padding self._remaining_observations = len(self.val_indices) self.temporal_coarsening_method = temporal_coarsening_method self._i = 0 @@ -525,7 +546,8 @@ def batch_next(self, high_res): output_features=self.output_features, model_mom1=self.model_mom1, s_padding=self.s_padding, - t_padding=self.t_padding) + t_padding=self.t_padding, + end_t_padding=self.end_t_padding) class BatchHandlerMom1(BatchHandler): @@ -542,7 +564,7 @@ def __init__(self, data_handlers, batch_size=8, s_enhance=3, t_enhance=1, means_file=None, overwrite_stats=False, smoothing=None, smoothing_ignore=None, stats_workers=None, norm_workers=None, load_workers=None, max_workers=None, model_mom1=None, - s_padding=None, t_padding=None): + s_padding=None, t_padding=None, end_t_padding=False): """ Parameters ---------- @@ -613,6 +635,9 @@ def __init__(self, data_handlers, batch_size=8, s_enhance=3, t_enhance=1, t_padding : int | None Width of temporal padding to predict only middle part. If None, no padding is used + end_t_padding : bool | False + Pad end of temporal space + False by default """ if max_workers is not None: norm_workers = stats_workers = load_workers = max_workers @@ -632,6 +657,7 @@ def __init__(self, data_handlers, batch_size=8, s_enhance=3, t_enhance=1, self.t_enhance = t_enhance self.s_padding = s_padding self.t_padding = t_padding + self.end_t_padding = end_t_padding self.sample_shape = handler_shapes[0] self.means = means self.stds = stds @@ -677,7 +703,8 @@ def __init__(self, data_handlers, batch_size=8, s_enhance=3, t_enhance=1, smoothing_ignore=self.smoothing_ignore, model_mom1=self.model_mom1, s_padding=self.s_padding, - t_padding=self.t_padding) + t_padding=self.t_padding, + end_t_padding=self.end_t_padding) logger.info('Finished initializing BatchHandler.') log_mem(logger, log_level='INFO') @@ -714,7 +741,8 @@ def __next__(self): smoothing_ignore=self.smoothing_ignore, model_mom1=self.model_mom1, s_padding=self.s_padding, - t_padding=self.t_padding) + t_padding=self.t_padding, + end_t_padding=self.end_t_padding) self._i += 1 return batch @@ -744,7 +772,8 @@ def __next__(self): smoothing_ignore=self.smoothing_ignore, model_mom1=self.model_mom1, s_padding=self.s_padding, - t_padding=self.t_padding) + t_padding=self.t_padding, + end_t_padding=self.end_t_padding) self._i += 1 return batch From 50346007f70a28818d43efd2a7da4497c51b129d Mon Sep 17 00:00:00 2001 From: Malik Date: Wed, 22 Mar 2023 17:28:08 -0600 Subject: [PATCH 02/13] add end_t_padding to tests --- tests/test_out_conditional_moments.py | 24 +++++++++++---- tests/test_train_conditional_moments.py | 41 +++++++++++++++++-------- 2 files changed, 47 insertions(+), 18 deletions(-) diff --git a/tests/test_out_conditional_moments.py b/tests/test_out_conditional_moments.py index ae05afbee..4566dd19c 100644 --- a/tests/test_out_conditional_moments.py +++ b/tests/test_out_conditional_moments.py @@ -721,6 +721,7 @@ def test_out_st_mom1(plot=False, full_shape=(20, 20), sample_shape=(12, 12, 24), batch_size=4, n_batches=4, s_enhance=3, t_enhance=4, + end_t_padding=False, model_dir=None): """Test basic spatiotemporal model outputing for first conditional moment.""" @@ -735,7 +736,8 @@ def test_out_st_mom1(plot=False, full_shape=(20, 20), batch_size=batch_size, s_enhance=s_enhance, t_enhance=t_enhance, - n_batches=n_batches) + n_batches=n_batches, + end_t_padding=end_t_padding) # Load Model if model_dir is None: @@ -820,6 +822,7 @@ def test_out_st_mom1_sf(plot=False, full_shape=(20, 20), sample_shape=(12, 12, 24), batch_size=4, n_batches=4, s_enhance=3, t_enhance=4, + end_t_padding=False, model_dir=None): """Test basic spatiotemporal model outputing for first conditional moment of subfilter velocity.""" @@ -834,7 +837,8 @@ def test_out_st_mom1_sf(plot=False, full_shape=(20, 20), batch_size=batch_size, s_enhance=s_enhance, t_enhance=t_enhance, - n_batches=n_batches) + n_batches=n_batches, + end_t_padding=end_t_padding) # Load Model if model_dir is None: @@ -926,6 +930,7 @@ def test_out_st_mom2(plot=False, full_shape=(20, 20), sample_shape=(12, 12, 24), batch_size=4, n_batches=4, s_enhance=3, t_enhance=4, + end_t_padding=False, model_dir=None, model_mom1_dir=None): """Test basic spatiotemporal model outputing @@ -952,7 +957,8 @@ def test_out_st_mom2(plot=False, full_shape=(20, 20), s_enhance=s_enhance, t_enhance=t_enhance, n_batches=n_batches, - model_mom1=model_mom1) + model_mom1=model_mom1, + end_t_padding=end_t_padding) # Load Mom2 Model if model_dir is None: @@ -1048,6 +1054,7 @@ def test_out_st_mom2_sf(plot=False, full_shape=(20, 20), sample_shape=(12, 12, 24), batch_size=4, n_batches=4, s_enhance=3, t_enhance=4, + end_t_padding=False, model_dir=None, model_mom1_dir=None): """Test basic spatiotemporal model outputing for second conditional moment @@ -1074,7 +1081,8 @@ def test_out_st_mom2_sf(plot=False, full_shape=(20, 20), s_enhance=s_enhance, t_enhance=t_enhance, n_batches=n_batches, - model_mom1=model_mom1) + model_mom1=model_mom1, + end_t_padding=end_t_padding) # Load Mom2 Model if model_dir is None: @@ -1183,6 +1191,7 @@ def test_out_st_mom2_sep(plot=False, full_shape=(20, 20), sample_shape=(12, 12, 24), batch_size=4, n_batches=4, s_enhance=3, t_enhance=4, + end_t_padding=False, model_dir=None, model_mom1_dir=None): """Test basic spatiotemporal model outputing @@ -1208,7 +1217,8 @@ def test_out_st_mom2_sep(plot=False, full_shape=(20, 20), batch_size=batch_size, s_enhance=s_enhance, t_enhance=t_enhance, - n_batches=n_batches) + n_batches=n_batches, + end_t_padding=end_t_padding) # Load Mom2 Model if model_dir is None: @@ -1315,6 +1325,7 @@ def test_out_st_mom2_sep_sf(plot=False, full_shape=(20, 20), sample_shape=(12, 12, 24), batch_size=4, n_batches=4, s_enhance=3, t_enhance=4, + end_t_padding=False, model_dir=None, model_mom1_dir=None): """Test basic spatiotemporal model outputing for second conditional moment @@ -1340,7 +1351,8 @@ def test_out_st_mom2_sep_sf(plot=False, full_shape=(20, 20), batch_size=batch_size, s_enhance=s_enhance, t_enhance=t_enhance, - n_batches=n_batches) + n_batches=n_batches, + end_t_padding=end_t_padding) # Load Mom2 Model if model_dir is None: diff --git a/tests/test_train_conditional_moments.py b/tests/test_train_conditional_moments.py index ecc789deb..5bdfa3171 100644 --- a/tests/test_train_conditional_moments.py +++ b/tests/test_train_conditional_moments.py @@ -400,9 +400,11 @@ def test_train_s_mom2_sep_sf(FEATURES, TRAIN_FEATURES, model_mom2.save(out_dir) -@pytest.mark.parametrize('FEATURES', - (['U_100m', 'V_100m'],)) +@pytest.mark.parametrize('FEATURES, end_t_padding', + [(['U_100m', 'V_100m'], False), + (['U_100m', 'V_100m'], True)]) def test_train_st_mom1(FEATURES, + end_t_padding, log=False, full_shape=(20, 20), sample_shape=(12, 12, 24), n_epoch=2, batch_size=2, n_batches=2, @@ -428,7 +430,8 @@ def test_train_st_mom1(FEATURES, batch_handler = BatchHandlerMom1([handler], batch_size=batch_size, s_enhance=3, t_enhance=4, - n_batches=n_batches) + n_batches=n_batches, + end_t_padding=end_t_padding) with tempfile.TemporaryDirectory() as td: if out_dir_root is None: @@ -445,9 +448,11 @@ def test_train_st_mom1(FEATURES, @pytest.mark.parametrize('FEATURES', (['U_100m', 'V_100m'],)) def test_train_st_mom1_sf(FEATURES, + end_t_padding=False, log=False, full_shape=(20, 20), sample_shape=(12, 12, 24), n_epoch=2, batch_size=2, n_batches=2, + temporal_slice=slice(None, None, 1), out_dir_root=None): """Test basic spatiotemporal model training for first conditional moment of the subfilter velocity.""" @@ -464,13 +469,14 @@ def test_train_st_mom1_sf(FEATURES, handler = DataHandlerH5(FP_WTK, FEATURES, target=TARGET_COORD, shape=full_shape, sample_shape=sample_shape, - temporal_slice=slice(None, None, 1), + temporal_slice=temporal_slice, val_split=0.005, worker_kwargs=dict(max_workers=1)) batch_handler = BatchHandlerMom1SF([handler], batch_size=batch_size, s_enhance=3, t_enhance=4, - n_batches=n_batches) + n_batches=n_batches, + end_t_padding=end_t_padding) with tempfile.TemporaryDirectory() as td: if out_dir_root is None: @@ -487,9 +493,11 @@ def test_train_st_mom1_sf(FEATURES, @pytest.mark.parametrize('FEATURES', (['U_100m', 'V_100m'],)) def test_train_st_mom2(FEATURES, + end_t_padding=False, log=False, full_shape=(20, 20), sample_shape=(12, 12, 16), n_epoch=2, batch_size=2, n_batches=2, + temporal_slice=slice(None, None, 1), out_dir_root=None, model_mom1_dir=None): """Test basic spatiotemporal model training @@ -516,14 +524,15 @@ def test_train_st_mom2(FEATURES, handler = DataHandlerH5(FP_WTK, FEATURES, target=TARGET_COORD, shape=full_shape, sample_shape=sample_shape, - temporal_slice=slice(None, None, 1), + temporal_slice=temporal_slice, val_split=0.005, worker_kwargs=dict(max_workers=1)) batch_handler = BatchHandlerMom2([handler], batch_size=batch_size, s_enhance=3, t_enhance=4, n_batches=n_batches, - model_mom1=model_mom1) + model_mom1=model_mom1, + end_t_padding=end_t_padding) with tempfile.TemporaryDirectory() as td: if out_dir_root is None: @@ -539,8 +548,10 @@ def test_train_st_mom2(FEATURES, @pytest.mark.parametrize('FEATURES', (['U_100m', 'V_100m'],)) def test_train_st_mom2_sf(FEATURES, + end_t_padding=False, log=False, full_shape=(20, 20), sample_shape=(12, 12, 16), n_epoch=2, + temporal_slice=slice(None, None, 1), batch_size=2, n_batches=2, out_dir_root=None, model_mom1_dir=None): @@ -568,14 +579,15 @@ def test_train_st_mom2_sf(FEATURES, handler = DataHandlerH5(FP_WTK, FEATURES, target=TARGET_COORD, shape=full_shape, sample_shape=sample_shape, - temporal_slice=slice(None, None, 1), + temporal_slice=temporal_slice, val_split=0.005, worker_kwargs=dict(max_workers=1)) batch_handler = BatchHandlerMom2SF([handler], batch_size=batch_size, s_enhance=3, t_enhance=4, n_batches=n_batches, - model_mom1=model_mom1) + model_mom1=model_mom1, + end_t_padding=end_t_padding) with tempfile.TemporaryDirectory() as td: if out_dir_root is None: @@ -591,8 +603,10 @@ def test_train_st_mom2_sf(FEATURES, @pytest.mark.parametrize('FEATURES', (['U_100m', 'V_100m'],)) def test_train_st_mom2_sep(FEATURES, + end_t_padding=False, log=False, full_shape=(20, 20), sample_shape=(12, 12, 16), n_epoch=2, + temporal_slice=slice(None, None, 1), batch_size=2, n_batches=2, out_dir_root=None): """Test basic spatiotemporal model training @@ -610,7 +624,7 @@ def test_train_st_mom2_sep(FEATURES, handler = DataHandlerH5(FP_WTK, FEATURES, target=TARGET_COORD, shape=full_shape, sample_shape=sample_shape, - temporal_slice=slice(None, None, 1), + temporal_slice=temporal_slice, val_split=0.005, worker_kwargs=dict(max_workers=1)) @@ -618,7 +632,8 @@ def test_train_st_mom2_sep(FEATURES, batch_size=batch_size, s_enhance=3, t_enhance=4, - n_batches=n_batches) + n_batches=n_batches, + end_t_padding=end_t_padding) with tempfile.TemporaryDirectory() as td: if out_dir_root is None: @@ -634,6 +649,7 @@ def test_train_st_mom2_sep(FEATURES, @pytest.mark.parametrize('FEATURES', (['U_100m', 'V_100m'],)) def test_train_st_mom2_sep_sf(FEATURES, + end_t_padding=False, log=False, full_shape=(20, 20), sample_shape=(12, 12, 16), n_epoch=2, batch_size=2, n_batches=2, @@ -659,7 +675,8 @@ def test_train_st_mom2_sep_sf(FEATURES, batch_handler = BatchHandlerMom2SepSF([handler], batch_size=batch_size, s_enhance=3, t_enhance=4, - n_batches=n_batches) + n_batches=n_batches, + end_t_padding=end_t_padding) with tempfile.TemporaryDirectory() as td: if out_dir_root is None: From 5364ee37febaf3782dd6ee01e057cf2550c28985 Mon Sep 17 00:00:00 2001 From: Malik Date: Wed, 22 Mar 2023 22:09:59 -0600 Subject: [PATCH 03/13] fix 0th order temporal upsampling --- sup3r/utilities/utilities.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/sup3r/utilities/utilities.py b/sup3r/utilities/utilities.py index 69046c0f9..0b48a8f3e 100644 --- a/sup3r/utilities/utilities.py +++ b/sup3r/utilities/utilities.py @@ -9,9 +9,10 @@ import glob from scipy import ndimage as nd from scipy.interpolate import RegularGridInterpolator -from scipy.ndimage.filters import gaussian_filter from scipy.interpolate import interp1d from scipy.ndimage import interpolation +from scipy.ndimage import zoom +from ndimage.filters import gaussian_filter from fnmatch import fnmatch import os import re @@ -632,9 +633,11 @@ def temporal_simple_enhancing(data, t_enhance=4): enhanced_data = data elif t_enhance not in [None, 1] and len(data.shape) == 5: enhancement = [1, 1, 1, t_enhance, 1] - enhanced_data = interpolation.zoom(data, - enhancement, - order=0) + enhanced_data = zoom(data, + enhancement, + order=0, + mode='nearest', + grid_mode=True) elif len(data.shape) != 5: msg = ('Data must be 5D to do temporal enhancing, but ' f'received: {data.shape}') From 291d68cdc64bdc72684706a64119429f5cd779b5 Mon Sep 17 00:00:00 2001 From: Malik Date: Wed, 22 Mar 2023 22:11:44 -0600 Subject: [PATCH 04/13] fix 0th order temporal upsampling --- sup3r/utilities/utilities.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sup3r/utilities/utilities.py b/sup3r/utilities/utilities.py index 0b48a8f3e..c09533116 100644 --- a/sup3r/utilities/utilities.py +++ b/sup3r/utilities/utilities.py @@ -12,7 +12,7 @@ from scipy.interpolate import interp1d from scipy.ndimage import interpolation from scipy.ndimage import zoom -from ndimage.filters import gaussian_filter +from scipy.ndimage.filters import gaussian_filter from fnmatch import fnmatch import os import re From c268e3281fa3c4a7a5f9a54a5cdf940c390e40e0 Mon Sep 17 00:00:00 2001 From: Malik Date: Wed, 22 Mar 2023 22:26:38 -0600 Subject: [PATCH 05/13] fix spatial enhancing utility --- sup3r/utilities/utilities.py | 33 ++++++++++++++++++++------------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/sup3r/utilities/utilities.py b/sup3r/utilities/utilities.py index c09533116..cf456a88c 100644 --- a/sup3r/utilities/utilities.py +++ b/sup3r/utilities/utilities.py @@ -10,7 +10,6 @@ from scipy import ndimage as nd from scipy.interpolate import RegularGridInterpolator from scipy.interpolate import interp1d -from scipy.ndimage import interpolation from scipy.ndimage import zoom from scipy.ndimage.filters import gaussian_filter from fnmatch import fnmatch @@ -835,27 +834,35 @@ def spatial_simple_enhancing(data, s_enhance=2, obs_axis=True): if obs_axis and len(data.shape) == 5: enhancement = [1, s_enhance, s_enhance, 1, 1] - enhanced_data = interpolation.zoom(data, - enhancement, - order=0) + enhanced_data = zoom(data, + enhancement, + order=0, + mode='nearest', + grid_mode=True) elif obs_axis and len(data.shape) == 4: enhancement = [1, s_enhance, s_enhance, 1] - enhanced_data = interpolation.zoom(data, - enhancement, - order=0) + enhanced_data = zoom(data, + enhancement, + order=0, + mode='nearest', + grid_mode=True) elif not obs_axis and len(data.shape) == 4: enhancement = [s_enhance, s_enhance, 1, 1] - enhanced_data = interpolation.zoom(data, - enhancement, - order=0) + enhanced_data = zoom(data, + enhancement, + order=0, + mode='nearest', + grid_mode=True) elif not obs_axis and len(data.shape) == 3: enhancement = [s_enhance, s_enhance, 1] - enhanced_data = interpolation.zoom(data, - enhancement, - order=0) + enhanced_data = zoom(data, + enhancement, + order=0, + mode='nearest', + grid_mode=True) else: msg = ('Data must be 3D, 4D, or 5D to do spatial enhancing, but ' f'received: {data.shape}') From 0e15dcc93d6563166fb764d6509922bd2c0629f3 Mon Sep 17 00:00:00 2001 From: Malik Date: Thu, 23 Mar 2023 00:30:01 -0600 Subject: [PATCH 06/13] add linear time interp option --- sup3r/utilities/utilities.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/sup3r/utilities/utilities.py b/sup3r/utilities/utilities.py index cf456a88c..769c31fd6 100644 --- a/sup3r/utilities/utilities.py +++ b/sup3r/utilities/utilities.py @@ -611,7 +611,7 @@ def temporal_coarsening(data, t_enhance=4, method='subsample'): return coarse_data -def temporal_simple_enhancing(data, t_enhance=4): +def temporal_simple_enhancing(data, t_enhance=4, mode='constant'): """"Upsample data according to t_enhance resolution Parameters @@ -631,12 +631,21 @@ def temporal_simple_enhancing(data, t_enhance=4): if t_enhance in [None, 1]: enhanced_data = data elif t_enhance not in [None, 1] and len(data.shape) == 5: - enhancement = [1, 1, 1, t_enhance, 1] - enhanced_data = zoom(data, - enhancement, - order=0, - mode='nearest', - grid_mode=True) + if mode == 'constant': + enhancement = [1, 1, 1, t_enhance, 1] + enhanced_data = zoom(data, + enhancement, + order=0, + mode='nearest', + grid_mode=True) + elif mode == 'linear': + index_t_hr = np.array(list(range(data.shape[3] * t_enhance))) + index_t_lr = index_t_hr[::t_enhance] + enhanced_data = interp1d(index_t_lr, + data, + axis=3, + fill_value='extrapolate')(index_t_hr) + enhanced_data = np.array(enhanced_data, dtype=np.float32) elif len(data.shape) != 5: msg = ('Data must be 5D to do temporal enhancing, but ' f'received: {data.shape}') From e9eb1df5ffd56182199ff99be41b92534a1e46ff Mon Sep 17 00:00:00 2001 From: Malik Date: Thu, 23 Mar 2023 00:32:09 -0600 Subject: [PATCH 07/13] enable choosing temporal enhancement mode from BatchHandler --- .../conditional_moment_batch_handling.py | 72 +++++++++++++++---- 1 file changed, 60 insertions(+), 12 deletions(-) diff --git a/sup3r/preprocessing/conditional_moment_batch_handling.py b/sup3r/preprocessing/conditional_moment_batch_handling.py index 00bf0c829..19faaaa4d 100644 --- a/sup3r/preprocessing/conditional_moment_batch_handling.py +++ b/sup3r/preprocessing/conditional_moment_batch_handling.py @@ -71,7 +71,8 @@ def mask(self): @staticmethod def make_output(low_res, high_res, s_enhance=None, t_enhance=None, - model_mom1=None, output_features_ind=None): + model_mom1=None, output_features_ind=None, + t_enhance_mode='constant'): """Make custom batch output Parameters @@ -93,6 +94,9 @@ def make_output(low_res, high_res, output_features_ind : list | np.ndarray | None List/array of feature channel indices that are used for generative output, without any feature indices used only for training. + t_enhance_mode : str + Enhancing mode for temporal subfilter. + Can be either constant or linear """ return high_res @@ -165,6 +169,7 @@ def make_mask(high_res, def get_coarse_batch(cls, high_res, s_enhance, t_enhance=1, temporal_coarsening_method='subsample', + temporal_enhancing_method='constant', output_features_ind=None, output_features=None, training_features=None, @@ -192,6 +197,9 @@ def get_coarse_batch(cls, high_res, temporal_coarsening_method : str Method to use for temporal coarsening. Can be subsample, average, or total + temporal_enhancing_method : str + Method to enhance temporally when constructin subfilter. + Can be constant or linear output_features_ind : list | np.ndarray | None List/array of feature channel indices that are used for generative output, without any feature indices used only for training. @@ -242,7 +250,8 @@ def get_coarse_batch(cls, high_res, high_res = cls.reduce_features(high_res, output_features_ind) output = cls.make_output(low_res, high_res, s_enhance, t_enhance, - model_mom1, output_features_ind) + model_mom1, output_features_ind, + temporal_enhancing_method) mask = cls.make_mask(high_res, s_padding, t_padding, end_t_padding, t_enhance) batch = cls(low_res, high_res, output, mask) @@ -257,7 +266,8 @@ class BatchMom1SF(BatchMom1): @staticmethod def make_output(low_res, high_res, s_enhance=None, t_enhance=None, - model_mom1=None, output_features_ind=None): + model_mom1=None, output_features_ind=None, + t_enhance_mode='constant'): """Make custom batch output Parameters @@ -279,12 +289,16 @@ def make_output(low_res, high_res, output_features_ind : list | np.ndarray | None List/array of feature channel indices that are used for generative output, without any feature indices used only for training. + t_enhance_mode : str + Enhancing mode for temporal subfilter. + Can be either constant or linear """ # Remove LR from HR enhanced_lr = spatial_simple_enhancing(low_res, s_enhance=s_enhance) enhanced_lr = temporal_simple_enhancing(enhanced_lr, - t_enhance=t_enhance) + t_enhance=t_enhance, + mode=t_enhance_mode) enhanced_lr = Batch.reduce_features(enhanced_lr, output_features_ind) return high_res - enhanced_lr @@ -296,7 +310,8 @@ class BatchMom2(BatchMom1): @staticmethod def make_output(low_res, high_res, s_enhance=None, t_enhance=None, - model_mom1=None, output_features_ind=None): + model_mom1=None, output_features_ind=None, + t_enhance_mode='constant'): """Make custom batch output Parameters @@ -318,6 +333,9 @@ def make_output(low_res, high_res, output_features_ind : list | np.ndarray | None List/array of feature channel indices that are used for generative output, without any feature indices used only for training. + t_enhance_mode : str + Enhancing mode for temporal subfilter. + Can be either constant or linear """ # Remove first moment from HR and square it out = model_mom1._tf_generate(low_res).numpy() @@ -331,7 +349,8 @@ class BatchMom2Sep(BatchMom1): @staticmethod def make_output(low_res, high_res, s_enhance=None, t_enhance=None, - model_mom1=None, output_features_ind=None): + model_mom1=None, output_features_ind=None, + t_enhance_mode='constant'): """Make custom batch output Parameters @@ -353,12 +372,16 @@ def make_output(low_res, high_res, output_features_ind : list | np.ndarray | None List/array of feature channel indices that are used for generative output, without any feature indices used only for training. + t_enhance_mode : str + Enhancing mode for temporal subfilter. + Can be either constant or linear """ return super(BatchMom2Sep, BatchMom2Sep).make_output(low_res, high_res, s_enhance, t_enhance, model_mom1, - output_features_ind)**2 + output_features_ind, + t_enhance_mode)**2 class BatchMom2SF(BatchMom1): @@ -368,7 +391,8 @@ class BatchMom2SF(BatchMom1): @staticmethod def make_output(low_res, high_res, s_enhance=None, t_enhance=None, - model_mom1=None, output_features_ind=None): + model_mom1=None, output_features_ind=None, + t_enhance_mode='constant'): """Make custom batch output Parameters @@ -390,13 +414,17 @@ def make_output(low_res, high_res, output_features_ind : list | np.ndarray | None List/array of feature channel indices that are used for generative output, without any feature indices used only for training. + t_enhance_mode : str + Enhancing mode for temporal subfilter. + Can be either 'constant' or 'linear' """ # Remove LR and first moment from HR and square it out = model_mom1._tf_generate(low_res).numpy() enhanced_lr = spatial_simple_enhancing(low_res, s_enhance=s_enhance) enhanced_lr = temporal_simple_enhancing(enhanced_lr, - t_enhance=t_enhance) + t_enhance=t_enhance, + mode=t_enhance_mode) enhanced_lr = Batch.reduce_features(enhanced_lr, output_features_ind) return (high_res - enhanced_lr - out)**2 @@ -409,7 +437,8 @@ class BatchMom2SepSF(BatchMom1SF): @staticmethod def make_output(low_res, high_res, s_enhance=None, t_enhance=None, - model_mom1=None, output_features_ind=None): + model_mom1=None, output_features_ind=None, + t_enhance_mode='constant'): """Make custom batch output Parameters @@ -431,13 +460,17 @@ def make_output(low_res, high_res, output_features_ind : list | np.ndarray | None List/array of feature channel indices that are used for generative output, without any feature indices used only for training. + t_enhance_mode : str + Enhancing mode for temporal subfilter. + Can be either constant or linear """ # Remove LR from HR and square it return super(BatchMom2SepSF, BatchMom2SepSF).make_output(low_res, high_res, s_enhance, t_enhance, model_mom1, - output_features_ind)**2 + output_features_ind, + t_enhance_mode)**2 class ValidationDataMom1(ValidationData): @@ -448,6 +481,7 @@ class ValidationDataMom1(ValidationData): def __init__(self, data_handlers, batch_size=8, s_enhance=3, t_enhance=1, temporal_coarsening_method='subsample', + temporal_enhancing_method='constant', output_features_ind=None, output_features=None, smoothing=None, smoothing_ignore=None, @@ -471,6 +505,10 @@ def __init__(self, data_handlers, batch_size=8, s_enhance=3, t_enhance=1, Subsample will take every t_enhance-th time step, average will average over t_enhance time steps, total will sum over t_enhance time steps + temporal_enhancing_method : str + [constant, linear] + Constant will repeat the steps t_enhance times. Linear will + linearly interpolate between coarse steps. output_features_ind : list | np.ndarray | None List/array of feature channel indices that are used for generative output, without any feature indices used only for training. @@ -515,6 +553,7 @@ def __init__(self, data_handlers, batch_size=8, s_enhance=3, t_enhance=1, self.end_t_padding = end_t_padding self._remaining_observations = len(self.val_indices) self.temporal_coarsening_method = temporal_coarsening_method + self.temporal_enhancing_method = temporal_enhancing_method self._i = 0 self.output_features_ind = output_features_ind self.output_features = output_features @@ -540,6 +579,7 @@ def batch_next(self, high_res): high_res, self.s_enhance, t_enhance=self.t_enhance, temporal_coarsening_method=self.temporal_coarsening_method, + temporal_enhancing_method=self.temporal_enhancing_method, output_features_ind=self.output_features_ind, smoothing=self.smoothing, smoothing_ignore=self.smoothing_ignore, @@ -560,7 +600,8 @@ class BatchHandlerMom1(BatchHandler): def __init__(self, data_handlers, batch_size=8, s_enhance=3, t_enhance=1, means=None, stds=None, norm=True, n_batches=10, - temporal_coarsening_method='subsample', stdevs_file=None, + temporal_coarsening_method='subsample', + temporal_enhancing_method='constant', stdevs_file=None, means_file=None, overwrite_stats=False, smoothing=None, smoothing_ignore=None, stats_workers=None, norm_workers=None, load_workers=None, max_workers=None, model_mom1=None, @@ -598,6 +639,10 @@ def __init__(self, data_handlers, batch_size=8, s_enhance=3, t_enhance=1, Subsample will take every t_enhance-th time step, average will average over t_enhance time steps, total will sum over t_enhance time steps + temporal_enhancing_method : str + [constant, linear] + Constant will repeat the steps t_enhance times. Linear will + linearly interpolate between coarse steps. stdevs_file : str | None Path to stdevs data or where to save data after calling get_stats means_file : str | None @@ -663,6 +708,7 @@ def __init__(self, data_handlers, batch_size=8, s_enhance=3, t_enhance=1, self.stds = stds self.n_batches = n_batches self.temporal_coarsening_method = temporal_coarsening_method + self.temporal_enhancing_method = temporal_enhancing_method self.current_batch_indices = None self.current_handler_index = None self.stdevs_file = stdevs_file @@ -697,6 +743,7 @@ def __init__(self, data_handlers, batch_size=8, s_enhance=3, t_enhance=1, data_handlers, batch_size=batch_size, s_enhance=s_enhance, t_enhance=t_enhance, temporal_coarsening_method=temporal_coarsening_method, + temporal_enhancing_method=temporal_enhancing_method, output_features_ind=self.output_features_ind, output_features=self.output_features, smoothing=self.smoothing, @@ -734,6 +781,7 @@ def __next__(self): batch = self.BATCH_CLASS.get_coarse_batch( high_res, self.s_enhance, t_enhance=self.t_enhance, temporal_coarsening_method=self.temporal_coarsening_method, + temporal_enhancing_method=self.temporal_enhancing_method, output_features_ind=self.output_features_ind, output_features=self.output_features, training_features=self.training_features, From df09c7d60a83b049e3e47e9769ae0a69c6a0452c Mon Sep 17 00:00:00 2001 From: Malik Date: Thu, 23 Mar 2023 00:36:29 -0600 Subject: [PATCH 08/13] add test of temporal enhancing method --- tests/test_train_conditional_moments.py | 42 +++++++++++++++---------- 1 file changed, 26 insertions(+), 16 deletions(-) diff --git a/tests/test_train_conditional_moments.py b/tests/test_train_conditional_moments.py index 5bdfa3171..927da3753 100644 --- a/tests/test_train_conditional_moments.py +++ b/tests/test_train_conditional_moments.py @@ -445,9 +445,11 @@ def test_train_st_mom1(FEATURES, model.save(out_dir) -@pytest.mark.parametrize('FEATURES', - (['U_100m', 'V_100m'],)) +@pytest.mark.parametrize('FEATURES, t_enhance_mode', + [(['U_100m', 'V_100m'], 'constant'), + (['U_100m', 'V_100m'], 'linear')]) def test_train_st_mom1_sf(FEATURES, + t_enhance_mode, end_t_padding=False, log=False, full_shape=(20, 20), sample_shape=(12, 12, 24), n_epoch=2, @@ -473,10 +475,12 @@ def test_train_st_mom1_sf(FEATURES, val_split=0.005, worker_kwargs=dict(max_workers=1)) - batch_handler = BatchHandlerMom1SF([handler], batch_size=batch_size, - s_enhance=3, t_enhance=4, - n_batches=n_batches, - end_t_padding=end_t_padding) + batch_handler = BatchHandlerMom1SF( + [handler], batch_size=batch_size, + s_enhance=3, t_enhance=4, + n_batches=n_batches, + end_t_padding=end_t_padding, + temporal_enhancing_method=t_enhance_mode) with tempfile.TemporaryDirectory() as td: if out_dir_root is None: @@ -548,6 +552,7 @@ def test_train_st_mom2(FEATURES, @pytest.mark.parametrize('FEATURES', (['U_100m', 'V_100m'],)) def test_train_st_mom2_sf(FEATURES, + t_enhance_mode='constant', end_t_padding=False, log=False, full_shape=(20, 20), sample_shape=(12, 12, 16), n_epoch=2, @@ -583,11 +588,13 @@ def test_train_st_mom2_sf(FEATURES, val_split=0.005, worker_kwargs=dict(max_workers=1)) - batch_handler = BatchHandlerMom2SF([handler], batch_size=batch_size, - s_enhance=3, t_enhance=4, - n_batches=n_batches, - model_mom1=model_mom1, - end_t_padding=end_t_padding) + batch_handler = BatchHandlerMom2SF( + [handler], batch_size=batch_size, + s_enhance=3, t_enhance=4, + n_batches=n_batches, + model_mom1=model_mom1, + end_t_padding=end_t_padding, + temporal_enhancing_method=t_enhance_mode) with tempfile.TemporaryDirectory() as td: if out_dir_root is None: @@ -649,6 +656,7 @@ def test_train_st_mom2_sep(FEATURES, @pytest.mark.parametrize('FEATURES', (['U_100m', 'V_100m'],)) def test_train_st_mom2_sep_sf(FEATURES, + t_enhance_mode='constant', end_t_padding=False, log=False, full_shape=(20, 20), sample_shape=(12, 12, 16), n_epoch=2, @@ -672,11 +680,13 @@ def test_train_st_mom2_sep_sf(FEATURES, val_split=0.005, worker_kwargs=dict(max_workers=1)) - batch_handler = BatchHandlerMom2SepSF([handler], - batch_size=batch_size, - s_enhance=3, t_enhance=4, - n_batches=n_batches, - end_t_padding=end_t_padding) + batch_handler = BatchHandlerMom2SepSF( + [handler], + batch_size=batch_size, + s_enhance=3, t_enhance=4, + n_batches=n_batches, + end_t_padding=end_t_padding, + temporal_enhancing_method=t_enhance_mode) with tempfile.TemporaryDirectory() as td: if out_dir_root is None: From afd20174cd2fc4d1a7515b8243bcb2de62d28f6e Mon Sep 17 00:00:00 2001 From: Malik Date: Thu, 23 Mar 2023 00:50:10 -0600 Subject: [PATCH 09/13] escape C901 --- .flake8 | 3 +- .github/linters/.flake8 | 1 + tests/test_out_conditional_moments.py | 106 +++++++++++++++++--------- 3 files changed, 74 insertions(+), 36 deletions(-) diff --git a/.flake8 b/.flake8 index 73bd375cd..94150c54d 100644 --- a/.flake8 +++ b/.flake8 @@ -1,4 +1,5 @@ [flake8] -ignore = E731,E402,F,W503 +ignore = E731,E402,F,W503,C901 exclude = .git,__pycache__,docs/source/conf.py,old,build,dist max-complexity = 10 +per-file-ignores = tests/*.py:C901 diff --git a/.github/linters/.flake8 b/.github/linters/.flake8 index ce856e587..f82a3fa83 100644 --- a/.github/linters/.flake8 +++ b/.github/linters/.flake8 @@ -2,3 +2,4 @@ ignore = E731,E402,F,W503 exclude = .git,__pycache__,docs/source/conf.py,old,build,dist,bin/tmp/* max-complexity = 10 +per-file-ignores = tests/*.py:C901 diff --git a/tests/test_out_conditional_moments.py b/tests/test_out_conditional_moments.py index 4566dd19c..229b2e450 100644 --- a/tests/test_out_conditional_moments.py +++ b/tests/test_out_conditional_moments.py @@ -788,13 +788,17 @@ def test_out_st_mom1(plot=False, full_shape=(20, 20), + batch_handler.means[0]) aug_lr = np.reshape(lr, (1,) + lr.shape + (1,)) tup_lr = temporal_simple_enhancing(aug_lr, - t_enhance=t_enhance) + t_enhance=t_enhance, + mode='constant') tup_lr = tup_lr[0, :, :, :, 0] hr = (batch.output[i, :, :, :, 0] * batch_handler.stds[0] + batch_handler.means[0]) gen = (out[i, :, :, :, 0] * batch_handler.stds[0] + batch_handler.means[0]) - for j in range(batch.output.shape[3]): + max_t_ind = batch.output.shape[3] + if end_t_padding: + max_t_ind -= t_enhance + for j in range(max_t_ind): fig = plot_multi_contour( [tup_lr[:, :, j], hr[:, :, j], gen[:, :, j]], [0, batch.output.shape[1]], @@ -823,6 +827,7 @@ def test_out_st_mom1_sf(plot=False, full_shape=(20, 20), batch_size=4, n_batches=4, s_enhance=3, t_enhance=4, end_t_padding=False, + t_enhance_mode='constant', model_dir=None): """Test basic spatiotemporal model outputing for first conditional moment of subfilter velocity.""" @@ -833,12 +838,14 @@ def test_out_st_mom1_sf(plot=False, full_shape=(20, 20), val_split=0, worker_kwargs=dict(max_workers=1)) - batch_handler = BatchHandlerMom1SF([handler], - batch_size=batch_size, - s_enhance=s_enhance, - t_enhance=t_enhance, - n_batches=n_batches, - end_t_padding=end_t_padding) + batch_handler = BatchHandlerMom1SF( + [handler], + batch_size=batch_size, + s_enhance=s_enhance, + t_enhance=t_enhance, + n_batches=n_batches, + end_t_padding=end_t_padding, + temporal_enhancing_method=t_enhance_mode) # Load Model if model_dir is None: @@ -870,14 +877,16 @@ def test_out_st_mom1_sf(plot=False, full_shape=(20, 20), b_lr_aug = np.reshape(b_lr, (1,) + b_lr.shape + (1,)) tup_lr = temporal_simple_enhancing(b_lr_aug, - t_enhance=t_enhance) + t_enhance=t_enhance, + mode='constant') tup_lr = (tup_lr[0, :, :, :, 0] * batch_handler.stds[0] + batch_handler.means[0]) up_lr_tmp = spatial_simple_enhancing(b_lr_aug, s_enhance=s_enhance) up_lr = temporal_simple_enhancing(up_lr_tmp, - t_enhance=t_enhance) + t_enhance=t_enhance, + mode=t_enhance_mode) up_lr = up_lr[0, :, :, :, 0] hr = (batch.high_res[i, :, :, :, 0] @@ -894,8 +903,10 @@ def test_out_st_mom1_sf(plot=False, full_shape=(20, 20), * batch_handler.stds[0] + batch_handler.means[0] + sf_pred) - - for j in range(batch.output.shape[3]): + max_t_ind = batch.output.shape[3] + if end_t_padding: + max_t_ind -= t_enhance + for j in range(max_t_ind): fig = plot_multi_contour( [tup_lr[:, :, j], hr[:, :, j], hr_pred[:, :, j], sf[:, :, j], @@ -993,7 +1004,8 @@ def test_out_st_mom2(plot=False, full_shape=(20, 20), b_lr_aug = np.reshape(b_lr, (1,) + b_lr.shape + (1,)) tup_lr = temporal_simple_enhancing(b_lr_aug, - t_enhance=t_enhance) + t_enhance=t_enhance, + mode='constant') tup_lr = (tup_lr[0, :, :, :, 0] * batch_handler.stds[0] + batch_handler.means[0]) @@ -1006,7 +1018,10 @@ def test_out_st_mom2(plot=False, full_shape=(20, 20), * batch_handler.stds[0]**2) integratedSigma.append(np.mean(sigma, axis=(0, 1))) - for j in range(batch.output.shape[3]): + max_t_ind = batch.output.shape[3] + if end_t_padding: + max_t_ind -= t_enhance + for j in range(max_t_ind): fig = plot_multi_contour( [tup_lr[:, :, j], hr[:, :, j], hr_to_mean[:, :, j], sigma[:, :, j]], @@ -1055,6 +1070,7 @@ def test_out_st_mom2_sf(plot=False, full_shape=(20, 20), batch_size=4, n_batches=4, s_enhance=3, t_enhance=4, end_t_padding=False, + t_enhance_mode='constant', model_dir=None, model_mom1_dir=None): """Test basic spatiotemporal model outputing for second conditional moment @@ -1076,13 +1092,15 @@ def test_out_st_mom2_sf(plot=False, full_shape=(20, 20), fp_gen = os.path.join(model_mom1_dir, 'model_params.json') model_mom1 = Sup3rCondMom(fp_gen).load(model_mom1_dir) - batch_handler = BatchHandlerMom2SF([handler], - batch_size=batch_size, - s_enhance=s_enhance, - t_enhance=t_enhance, - n_batches=n_batches, - model_mom1=model_mom1, - end_t_padding=end_t_padding) + batch_handler = BatchHandlerMom2SF( + [handler], + batch_size=batch_size, + s_enhance=s_enhance, + t_enhance=t_enhance, + n_batches=n_batches, + model_mom1=model_mom1, + end_t_padding=end_t_padding, + temporal_enhancing_method=t_enhance_mode) # Load Mom2 Model if model_dir is None: @@ -1116,7 +1134,8 @@ def test_out_st_mom2_sf(plot=False, full_shape=(20, 20), b_lr_aug = np.reshape(b_lr, (1,) + b_lr.shape + (1,)) tup_lr = temporal_simple_enhancing(b_lr_aug, - t_enhance=t_enhance) + t_enhance=t_enhance, + mode='constant') tup_lr = (tup_lr[0, :, :, :, 0] * batch_handler.stds[0] + batch_handler.means[0]) @@ -1124,7 +1143,8 @@ def test_out_st_mom2_sf(plot=False, full_shape=(20, 20), up_lr_tmp = spatial_simple_enhancing(b_lr_aug, s_enhance=s_enhance) up_lr = temporal_simple_enhancing(up_lr_tmp, - t_enhance=t_enhance) + t_enhance=t_enhance, + mode=t_enhance_mode) up_lr = up_lr[0, :, :, :, 0] hr = (batch.high_res[i, :, :, :, 0] @@ -1139,7 +1159,11 @@ def test_out_st_mom2_sf(plot=False, full_shape=(20, 20), sigma = np.sqrt(out[i, :, :, :, 0] * batch_handler.stds[0]**2) integratedSigma.append(np.mean(sigma, axis=(0, 1))) - for j in range(batch.output.shape[3]): + + max_t_ind = batch.output.shape[3] + if end_t_padding: + max_t_ind -= t_enhance + for j in range(max_t_ind): fig = plot_multi_contour( [tup_lr[:, :, j], hr[:, :, j], sf[:, :, j], sf_to_mean[:, :, j], @@ -1256,7 +1280,8 @@ def test_out_st_mom2_sep(plot=False, full_shape=(20, 20), b_lr_aug = np.reshape(b_lr, (1,) + b_lr.shape + (1,)) tup_lr = temporal_simple_enhancing(b_lr_aug, - t_enhance=t_enhance) + t_enhance=t_enhance, + mode='constant') tup_lr = (tup_lr[0, :, :, :, 0] * batch_handler.stds[0] + batch_handler.means[0]) @@ -1277,7 +1302,10 @@ def test_out_st_mom2_sep(plot=False, full_shape=(20, 20), a_min=0, a_max=None)) integratedSigma.append(np.mean(sigma_pred, axis=(0, 1))) - for j in range(batch.output.shape[3]): + max_t_ind = batch.output.shape[3] + if end_t_padding: + max_t_ind -= t_enhance + for j in range(max_t_ind): fig = plot_multi_contour( [tup_lr[:, :, j], hr[:, :, j], hr_to_mean[:, :, j], sigma_pred[:, :, j]], @@ -1326,6 +1354,7 @@ def test_out_st_mom2_sep_sf(plot=False, full_shape=(20, 20), batch_size=4, n_batches=4, s_enhance=3, t_enhance=4, end_t_padding=False, + t_enhance_mode='constant', model_dir=None, model_mom1_dir=None): """Test basic spatiotemporal model outputing for second conditional moment @@ -1347,12 +1376,14 @@ def test_out_st_mom2_sep_sf(plot=False, full_shape=(20, 20), fp_gen = os.path.join(model_mom1_dir, 'model_params.json') model_mom1 = Sup3rCondMom(fp_gen).load(model_mom1_dir) - batch_handler = BatchHandlerMom2SepSF([handler], - batch_size=batch_size, - s_enhance=s_enhance, - t_enhance=t_enhance, - n_batches=n_batches, - end_t_padding=end_t_padding) + batch_handler = BatchHandlerMom2SepSF( + [handler], + batch_size=batch_size, + s_enhance=s_enhance, + t_enhance=t_enhance, + n_batches=n_batches, + end_t_padding=end_t_padding, + temporal_enhancing_method=t_enhance_mode) # Load Mom2 Model if model_dir is None: @@ -1390,7 +1421,8 @@ def test_out_st_mom2_sep_sf(plot=False, full_shape=(20, 20), b_lr_aug = np.reshape(b_lr, (1,) + b_lr.shape + (1,)) tup_lr = temporal_simple_enhancing(b_lr_aug, - t_enhance=t_enhance) + t_enhance=t_enhance, + mode='constant') tup_lr = (tup_lr[0, :, :, :, 0] * batch_handler.stds[0] + batch_handler.means[0]) @@ -1398,7 +1430,8 @@ def test_out_st_mom2_sep_sf(plot=False, full_shape=(20, 20), up_lr_tmp = spatial_simple_enhancing(b_lr_aug, s_enhance=s_enhance) up_lr = temporal_simple_enhancing(up_lr_tmp, - t_enhance=t_enhance) + t_enhance=t_enhance, + mode=t_enhance_mode) up_lr = up_lr[0, :, :, :, 0] hr = (batch.high_res[i, :, :, :, 0] @@ -1419,7 +1452,10 @@ def test_out_st_mom2_sep_sf(plot=False, full_shape=(20, 20), a_min=0, a_max=None)) integratedSigma.append(np.mean(sigma_pred, axis=(0, 1))) - for j in range(batch.output.shape[3]): + max_t_ind = batch.output.shape[3] + if end_t_padding: + max_t_ind -= t_enhance + for j in range(max_t_ind): fig = plot_multi_contour( [tup_lr[:, :, j], hr[:, :, j], sf[:, :, j], sf_to_mean[:, :, j], From 31591b68202244a4d1d2703d3cffddf306e6f166 Mon Sep 17 00:00:00 2001 From: Malik Date: Thu, 23 Mar 2023 00:54:40 -0600 Subject: [PATCH 10/13] fix lint warning escape --- .flake8 | 3 ++- .github/linters/.flake8 | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/.flake8 b/.flake8 index 94150c54d..dc4d128a4 100644 --- a/.flake8 +++ b/.flake8 @@ -2,4 +2,5 @@ ignore = E731,E402,F,W503,C901 exclude = .git,__pycache__,docs/source/conf.py,old,build,dist max-complexity = 10 -per-file-ignores = tests/*.py:C901 +per-file-ignores = + tests/*.py:C901 diff --git a/.github/linters/.flake8 b/.github/linters/.flake8 index f82a3fa83..5d6ea8b96 100644 --- a/.github/linters/.flake8 +++ b/.github/linters/.flake8 @@ -2,4 +2,5 @@ ignore = E731,E402,F,W503 exclude = .git,__pycache__,docs/source/conf.py,old,build,dist,bin/tmp/* max-complexity = 10 -per-file-ignores = tests/*.py:C901 +per-file-ignores = + tests/*.py:C901 From 302c9c3f6d1968fcc828a2cc8b9e965022754423 Mon Sep 17 00:00:00 2001 From: Malik Date: Thu, 23 Mar 2023 00:55:30 -0600 Subject: [PATCH 11/13] fix lint warning escape --- .flake8 | 4 +--- .github/linters/.flake8 | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/.flake8 b/.flake8 index dc4d128a4..d4972524a 100644 --- a/.flake8 +++ b/.flake8 @@ -1,6 +1,4 @@ [flake8] ignore = E731,E402,F,W503,C901 exclude = .git,__pycache__,docs/source/conf.py,old,build,dist -max-complexity = 10 -per-file-ignores = - tests/*.py:C901 +max-complexity = 12 diff --git a/.github/linters/.flake8 b/.github/linters/.flake8 index 5d6ea8b96..53e1a8eea 100644 --- a/.github/linters/.flake8 +++ b/.github/linters/.flake8 @@ -1,6 +1,4 @@ [flake8] ignore = E731,E402,F,W503 exclude = .git,__pycache__,docs/source/conf.py,old,build,dist,bin/tmp/* -max-complexity = 10 -per-file-ignores = - tests/*.py:C901 +max-complexity = 12 From 89a390fef75653e679f2a70129829fde96cbb538 Mon Sep 17 00:00:00 2001 From: Malik Date: Thu, 13 Apr 2023 14:28:29 -0600 Subject: [PATCH 12/13] expand description of end t padding --- .../conditional_moment_batch_handling.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/sup3r/preprocessing/conditional_moment_batch_handling.py b/sup3r/preprocessing/conditional_moment_batch_handling.py index 19faaaa4d..5d892bde2 100644 --- a/sup3r/preprocessing/conditional_moment_batch_handling.py +++ b/sup3r/preprocessing/conditional_moment_batch_handling.py @@ -135,7 +135,9 @@ def make_mask(high_res, Temporal padding size. If None or 0, no padding is applied. None by default end_t_padding : bool | False - Pad end of temporal space. + Zero pad the end of temporal space. + Ensures that loss is calculated only if snapshot is surrounded + by temporal landmarks. False by default t_enhance : int | None Temporal enhancement factor to define end padding. @@ -225,7 +227,9 @@ def get_coarse_batch(cls, high_res, Width of temporal padding to predict only middle part. If None, no padding is used end_t_padding : bool | False - Pad end of temporal space + Zero pad the end of temporal space. + Ensures that loss is calculated only if snapshot is surrounded + by temporal landmarks. False by default Returns @@ -533,7 +537,9 @@ def __init__(self, data_handlers, batch_size=8, s_enhance=3, t_enhance=1, Width of temporal padding to predict only middle part. If None, no padding is used end_t_padding : bool | False - Pad end of temporal space + Zero pad the end of temporal space. + Ensures that loss is calculated only if snapshot is surrounded + by temporal landmarks. False by default """ @@ -681,7 +687,9 @@ def __init__(self, data_handlers, batch_size=8, s_enhance=3, t_enhance=1, Width of temporal padding to predict only middle part. If None, no padding is used end_t_padding : bool | False - Pad end of temporal space + Zero pad the end of temporal space. + Ensures that loss is calculated only if snapshot is surrounded + by temporal landmarks. False by default """ if max_workers is not None: From 74ad227f9616cc759c792876074850216c64d6f3 Mon Sep 17 00:00:00 2001 From: Malik Date: Thu, 13 Apr 2023 14:33:49 -0600 Subject: [PATCH 13/13] expand description of temporal enhancing method --- .../conditional_moment_batch_handling.py | 28 +++++++++++++++---- 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/sup3r/preprocessing/conditional_moment_batch_handling.py b/sup3r/preprocessing/conditional_moment_batch_handling.py index 5d892bde2..08eb5c345 100644 --- a/sup3r/preprocessing/conditional_moment_batch_handling.py +++ b/sup3r/preprocessing/conditional_moment_batch_handling.py @@ -200,8 +200,14 @@ def get_coarse_batch(cls, high_res, Method to use for temporal coarsening. Can be subsample, average, or total temporal_enhancing_method : str - Method to enhance temporally when constructin subfilter. - Can be constant or linear + [constant, linear] + Method to enhance temporally when constructing subfilter. + At every temporal location, a low-res temporal data is substracted + from the high-res temporal data predicted. + constant will assume that the low-res temporal data is constant + between landmarks. + linear will linearly interpolate between landmarks to generate the + low-res data to remove from the high-res. output_features_ind : list | np.ndarray | None List/array of feature channel indices that are used for generative output, without any feature indices used only for training. @@ -511,8 +517,13 @@ def __init__(self, data_handlers, batch_size=8, s_enhance=3, t_enhance=1, time steps temporal_enhancing_method : str [constant, linear] - Constant will repeat the steps t_enhance times. Linear will - linearly interpolate between coarse steps. + Method to enhance temporally when constructing subfilter. + At every temporal location, a low-res temporal data is substracted + from the high-res temporal data predicted. + constant will assume that the low-res temporal data is constant + between landmarks. + linear will linearly interpolate between landmarks to generate the + low-res data to remove from the high-res. output_features_ind : list | np.ndarray | None List/array of feature channel indices that are used for generative output, without any feature indices used only for training. @@ -647,8 +658,13 @@ def __init__(self, data_handlers, batch_size=8, s_enhance=3, t_enhance=1, time steps temporal_enhancing_method : str [constant, linear] - Constant will repeat the steps t_enhance times. Linear will - linearly interpolate between coarse steps. + Method to enhance temporally when constructing subfilter. + At every temporal location, a low-res temporal data is substracted + from the high-res temporal data predicted. + constant will assume that the low-res temporal data is constant + between landmarks. + linear will linearly interpolate between landmarks to generate the + low-res data to remove from the high-res. stdevs_file : str | None Path to stdevs data or where to save data after calling get_stats means_file : str | None