diff --git a/.flake8 b/.flake8 index 73bd375cdd..d4972524aa 100644 --- a/.flake8 +++ b/.flake8 @@ -1,4 +1,4 @@ [flake8] -ignore = E731,E402,F,W503 +ignore = E731,E402,F,W503,C901 exclude = .git,__pycache__,docs/source/conf.py,old,build,dist -max-complexity = 10 +max-complexity = 12 diff --git a/.github/linters/.flake8 b/.github/linters/.flake8 index ce856e587a..53e1a8eea5 100644 --- a/.github/linters/.flake8 +++ b/.github/linters/.flake8 @@ -1,4 +1,4 @@ [flake8] ignore = E731,E402,F,W503 exclude = .git,__pycache__,docs/source/conf.py,old,build,dist,bin/tmp/* -max-complexity = 10 +max-complexity = 12 diff --git a/sup3r/preprocessing/conditional_moment_batch_handling.py b/sup3r/preprocessing/conditional_moment_batch_handling.py index 26c72e725a..08eb5c3454 100644 --- a/sup3r/preprocessing/conditional_moment_batch_handling.py +++ b/sup3r/preprocessing/conditional_moment_batch_handling.py @@ -71,7 +71,8 @@ def mask(self): @staticmethod def make_output(low_res, high_res, s_enhance=None, t_enhance=None, - model_mom1=None, output_features_ind=None): + model_mom1=None, output_features_ind=None, + t_enhance_mode='constant'): """Make custom batch output Parameters @@ -93,12 +94,17 @@ def make_output(low_res, high_res, output_features_ind : list | np.ndarray | None List/array of feature channel indices that are used for generative output, without any feature indices used only for training. + t_enhance_mode : str + Enhancing mode for temporal subfilter. + Can be either constant or linear """ return high_res # pylint: disable=E1130 @staticmethod - def make_mask(high_res, s_padding=None, t_padding=None): + def make_mask(high_res, + s_padding=None, t_padding=None, + end_t_padding=False, t_enhance=None): """Make mask for output. The mask is used to ensure consistency when training conditional moments. @@ -128,6 +134,14 @@ def make_mask(high_res, s_padding=None, t_padding=None): t_padding : int | None Temporal padding size. If None or 0, no padding is applied. None by default + end_t_padding : bool | False + Zero pad the end of temporal space. + Ensures that loss is calculated only if snapshot is surrounded + by temporal landmarks. + False by default + t_enhance : int | None + Temporal enhancement factor to define end padding. + None by default model_mom1 : Sup3rCondMom | None Model used to modify the make the batch output output_features_ind : list | np.ndarray | None @@ -139,6 +153,11 @@ def make_mask(high_res, s_padding=None, t_padding=None): t_min = t_padding if t_padding is not None else 0 s_max = -s_padding if s_min > 0 else None t_max = -t_padding if t_min > 0 else None + if end_t_padding and t_enhance > 1: + if t_max is None: + t_max = -(t_enhance - 1) + else: + t_max = -(t_enhance - 1) - t_padding if len(high_res.shape) == 4: mask[:, s_min:s_max, s_min:s_max, :] = 1.0 @@ -152,6 +171,7 @@ def make_mask(high_res, s_padding=None, t_padding=None): def get_coarse_batch(cls, high_res, s_enhance, t_enhance=1, temporal_coarsening_method='subsample', + temporal_enhancing_method='constant', output_features_ind=None, output_features=None, training_features=None, @@ -159,7 +179,8 @@ def get_coarse_batch(cls, high_res, smoothing_ignore=None, model_mom1=None, s_padding=None, - t_padding=None): + t_padding=None, + end_t_padding=False): """Coarsen high res data and return Batch with high res and low res data @@ -178,6 +199,15 @@ def get_coarse_batch(cls, high_res, temporal_coarsening_method : str Method to use for temporal coarsening. 
Can be subsample, average, or total + temporal_enhancing_method : str + [constant, linear] + Method to enhance temporally when constructing subfilter. + At every temporal location, a low-res temporal data is substracted + from the high-res temporal data predicted. + constant will assume that the low-res temporal data is constant + between landmarks. + linear will linearly interpolate between landmarks to generate the + low-res data to remove from the high-res. output_features_ind : list | np.ndarray | None List/array of feature channel indices that are used for generative output, without any feature indices used only for training. @@ -202,6 +232,11 @@ def get_coarse_batch(cls, high_res, t_padding : int | None Width of temporal padding to predict only middle part. If None, no padding is used + end_t_padding : bool | False + Zero pad the end of temporal space. + Ensures that loss is calculated only if snapshot is surrounded + by temporal landmarks. + False by default Returns ------- @@ -225,9 +260,10 @@ def get_coarse_batch(cls, high_res, high_res = cls.reduce_features(high_res, output_features_ind) output = cls.make_output(low_res, high_res, s_enhance, t_enhance, - model_mom1, output_features_ind) + model_mom1, output_features_ind, + temporal_enhancing_method) mask = cls.make_mask(high_res, - s_padding, t_padding) + s_padding, t_padding, end_t_padding, t_enhance) batch = cls(low_res, high_res, output, mask) return batch @@ -240,7 +276,8 @@ class BatchMom1SF(BatchMom1): @staticmethod def make_output(low_res, high_res, s_enhance=None, t_enhance=None, - model_mom1=None, output_features_ind=None): + model_mom1=None, output_features_ind=None, + t_enhance_mode='constant'): """Make custom batch output Parameters @@ -262,12 +299,16 @@ def make_output(low_res, high_res, output_features_ind : list | np.ndarray | None List/array of feature channel indices that are used for generative output, without any feature indices used only for training. + t_enhance_mode : str + Enhancing mode for temporal subfilter. + Can be either constant or linear """ # Remove LR from HR enhanced_lr = spatial_simple_enhancing(low_res, s_enhance=s_enhance) enhanced_lr = temporal_simple_enhancing(enhanced_lr, - t_enhance=t_enhance) + t_enhance=t_enhance, + mode=t_enhance_mode) enhanced_lr = Batch.reduce_features(enhanced_lr, output_features_ind) return high_res - enhanced_lr @@ -279,7 +320,8 @@ class BatchMom2(BatchMom1): @staticmethod def make_output(low_res, high_res, s_enhance=None, t_enhance=None, - model_mom1=None, output_features_ind=None): + model_mom1=None, output_features_ind=None, + t_enhance_mode='constant'): """Make custom batch output Parameters @@ -301,6 +343,9 @@ def make_output(low_res, high_res, output_features_ind : list | np.ndarray | None List/array of feature channel indices that are used for generative output, without any feature indices used only for training. + t_enhance_mode : str + Enhancing mode for temporal subfilter. 
+ Can be either constant or linear """ # Remove first moment from HR and square it out = model_mom1._tf_generate(low_res).numpy() @@ -314,7 +359,8 @@ class BatchMom2Sep(BatchMom1): @staticmethod def make_output(low_res, high_res, s_enhance=None, t_enhance=None, - model_mom1=None, output_features_ind=None): + model_mom1=None, output_features_ind=None, + t_enhance_mode='constant'): """Make custom batch output Parameters @@ -336,12 +382,16 @@ def make_output(low_res, high_res, output_features_ind : list | np.ndarray | None List/array of feature channel indices that are used for generative output, without any feature indices used only for training. + t_enhance_mode : str + Enhancing mode for temporal subfilter. + Can be either constant or linear """ return super(BatchMom2Sep, BatchMom2Sep).make_output(low_res, high_res, s_enhance, t_enhance, model_mom1, - output_features_ind)**2 + output_features_ind, + t_enhance_mode)**2 class BatchMom2SF(BatchMom1): @@ -351,7 +401,8 @@ class BatchMom2SF(BatchMom1): @staticmethod def make_output(low_res, high_res, s_enhance=None, t_enhance=None, - model_mom1=None, output_features_ind=None): + model_mom1=None, output_features_ind=None, + t_enhance_mode='constant'): """Make custom batch output Parameters @@ -373,13 +424,17 @@ def make_output(low_res, high_res, output_features_ind : list | np.ndarray | None List/array of feature channel indices that are used for generative output, without any feature indices used only for training. + t_enhance_mode : str + Enhancing mode for temporal subfilter. + Can be either 'constant' or 'linear' """ # Remove LR and first moment from HR and square it out = model_mom1._tf_generate(low_res).numpy() enhanced_lr = spatial_simple_enhancing(low_res, s_enhance=s_enhance) enhanced_lr = temporal_simple_enhancing(enhanced_lr, - t_enhance=t_enhance) + t_enhance=t_enhance, + mode=t_enhance_mode) enhanced_lr = Batch.reduce_features(enhanced_lr, output_features_ind) return (high_res - enhanced_lr - out)**2 @@ -392,7 +447,8 @@ class BatchMom2SepSF(BatchMom1SF): @staticmethod def make_output(low_res, high_res, s_enhance=None, t_enhance=None, - model_mom1=None, output_features_ind=None): + model_mom1=None, output_features_ind=None, + t_enhance_mode='constant'): """Make custom batch output Parameters @@ -414,13 +470,17 @@ def make_output(low_res, high_res, output_features_ind : list | np.ndarray | None List/array of feature channel indices that are used for generative output, without any feature indices used only for training. + t_enhance_mode : str + Enhancing mode for temporal subfilter. 
+ Can be either constant or linear """ # Remove LR from HR and square it return super(BatchMom2SepSF, BatchMom2SepSF).make_output(low_res, high_res, s_enhance, t_enhance, model_mom1, - output_features_ind)**2 + output_features_ind, + t_enhance_mode)**2 class ValidationDataMom1(ValidationData): @@ -431,11 +491,12 @@ class ValidationDataMom1(ValidationData): def __init__(self, data_handlers, batch_size=8, s_enhance=3, t_enhance=1, temporal_coarsening_method='subsample', + temporal_enhancing_method='constant', output_features_ind=None, output_features=None, smoothing=None, smoothing_ignore=None, model_mom1=None, - s_padding=None, t_padding=None): + s_padding=None, t_padding=None, end_t_padding=False): """ Parameters ---------- @@ -454,6 +515,15 @@ def __init__(self, data_handlers, batch_size=8, s_enhance=3, t_enhance=1, Subsample will take every t_enhance-th time step, average will average over t_enhance time steps, total will sum over t_enhance time steps + temporal_enhancing_method : str + [constant, linear] + Method to enhance temporally when constructing subfilter. + At every temporal location, a low-res temporal data is substracted + from the high-res temporal data predicted. + constant will assume that the low-res temporal data is constant + between landmarks. + linear will linearly interpolate between landmarks to generate the + low-res data to remove from the high-res. output_features_ind : list | np.ndarray | None List/array of feature channel indices that are used for generative output, without any feature indices used only for training. @@ -477,6 +547,11 @@ def __init__(self, data_handlers, batch_size=8, s_enhance=3, t_enhance=1, t_padding : int | None Width of temporal padding to predict only middle part. If None, no padding is used + end_t_padding : bool | False + Zero pad the end of temporal space. + Ensures that loss is calculated only if snapshot is surrounded + by temporal landmarks. 
+ False by default """ handler_shapes = np.array([d.sample_shape for d in data_handlers]) @@ -492,8 +567,10 @@ def __init__(self, data_handlers, batch_size=8, s_enhance=3, t_enhance=1, self.t_enhance = t_enhance self.s_padding = s_padding self.t_padding = t_padding + self.end_t_padding = end_t_padding self._remaining_observations = len(self.val_indices) self.temporal_coarsening_method = temporal_coarsening_method + self.temporal_enhancing_method = temporal_enhancing_method self._i = 0 self.output_features_ind = output_features_ind self.output_features = output_features @@ -519,13 +596,15 @@ def batch_next(self, high_res): high_res, self.s_enhance, t_enhance=self.t_enhance, temporal_coarsening_method=self.temporal_coarsening_method, + temporal_enhancing_method=self.temporal_enhancing_method, output_features_ind=self.output_features_ind, smoothing=self.smoothing, smoothing_ignore=self.smoothing_ignore, output_features=self.output_features, model_mom1=self.model_mom1, s_padding=self.s_padding, - t_padding=self.t_padding) + t_padding=self.t_padding, + end_t_padding=self.end_t_padding) class BatchHandlerMom1(BatchHandler): @@ -538,11 +617,12 @@ class BatchHandlerMom1(BatchHandler): def __init__(self, data_handlers, batch_size=8, s_enhance=3, t_enhance=1, means=None, stds=None, norm=True, n_batches=10, - temporal_coarsening_method='subsample', stdevs_file=None, + temporal_coarsening_method='subsample', + temporal_enhancing_method='constant', stdevs_file=None, means_file=None, overwrite_stats=False, smoothing=None, smoothing_ignore=None, stats_workers=None, norm_workers=None, load_workers=None, max_workers=None, model_mom1=None, - s_padding=None, t_padding=None): + s_padding=None, t_padding=None, end_t_padding=False): """ Parameters ---------- @@ -576,6 +656,15 @@ def __init__(self, data_handlers, batch_size=8, s_enhance=3, t_enhance=1, Subsample will take every t_enhance-th time step, average will average over t_enhance time steps, total will sum over t_enhance time steps + temporal_enhancing_method : str + [constant, linear] + Method to enhance temporally when constructing subfilter. + At every temporal location, a low-res temporal data is substracted + from the high-res temporal data predicted. + constant will assume that the low-res temporal data is constant + between landmarks. + linear will linearly interpolate between landmarks to generate the + low-res data to remove from the high-res. stdevs_file : str | None Path to stdevs data or where to save data after calling get_stats means_file : str | None @@ -613,6 +702,11 @@ def __init__(self, data_handlers, batch_size=8, s_enhance=3, t_enhance=1, t_padding : int | None Width of temporal padding to predict only middle part. If None, no padding is used + end_t_padding : bool | False + Zero pad the end of temporal space. + Ensures that loss is calculated only if snapshot is surrounded + by temporal landmarks. 
+ False by default """ if max_workers is not None: norm_workers = stats_workers = load_workers = max_workers @@ -632,11 +726,13 @@ def __init__(self, data_handlers, batch_size=8, s_enhance=3, t_enhance=1, self.t_enhance = t_enhance self.s_padding = s_padding self.t_padding = t_padding + self.end_t_padding = end_t_padding self.sample_shape = handler_shapes[0] self.means = means self.stds = stds self.n_batches = n_batches self.temporal_coarsening_method = temporal_coarsening_method + self.temporal_enhancing_method = temporal_enhancing_method self.current_batch_indices = None self.current_handler_index = None self.stdevs_file = stdevs_file @@ -671,13 +767,15 @@ def __init__(self, data_handlers, batch_size=8, s_enhance=3, t_enhance=1, data_handlers, batch_size=batch_size, s_enhance=s_enhance, t_enhance=t_enhance, temporal_coarsening_method=temporal_coarsening_method, + temporal_enhancing_method=temporal_enhancing_method, output_features_ind=self.output_features_ind, output_features=self.output_features, smoothing=self.smoothing, smoothing_ignore=self.smoothing_ignore, model_mom1=self.model_mom1, s_padding=self.s_padding, - t_padding=self.t_padding) + t_padding=self.t_padding, + end_t_padding=self.end_t_padding) logger.info('Finished initializing BatchHandler.') log_mem(logger, log_level='INFO') @@ -707,6 +805,7 @@ def __next__(self): batch = self.BATCH_CLASS.get_coarse_batch( high_res, self.s_enhance, t_enhance=self.t_enhance, temporal_coarsening_method=self.temporal_coarsening_method, + temporal_enhancing_method=self.temporal_enhancing_method, output_features_ind=self.output_features_ind, output_features=self.output_features, training_features=self.training_features, @@ -714,7 +813,8 @@ def __next__(self): smoothing_ignore=self.smoothing_ignore, model_mom1=self.model_mom1, s_padding=self.s_padding, - t_padding=self.t_padding) + t_padding=self.t_padding, + end_t_padding=self.end_t_padding) self._i += 1 return batch @@ -744,7 +844,8 @@ def __next__(self): smoothing_ignore=self.smoothing_ignore, model_mom1=self.model_mom1, s_padding=self.s_padding, - t_padding=self.t_padding) + t_padding=self.t_padding, + end_t_padding=self.end_t_padding) self._i += 1 return batch diff --git a/sup3r/utilities/utilities.py b/sup3r/utilities/utilities.py index ba84b96d1f..15110e84bd 100644 --- a/sup3r/utilities/utilities.py +++ b/sup3r/utilities/utilities.py @@ -9,8 +9,9 @@ import glob from scipy import ndimage as nd from scipy.interpolate import RegularGridInterpolator +from scipy.interpolate import interp1d +from scipy.ndimage import zoom from scipy.ndimage.filters import gaussian_filter -from scipy.ndimage import interpolation from fnmatch import fnmatch import os import re @@ -634,7 +635,7 @@ def temporal_coarsening(data, t_enhance=4, method='subsample'): return coarse_data -def temporal_simple_enhancing(data, t_enhance=4): +def temporal_simple_enhancing(data, t_enhance=4, mode='constant'): """"Upsample data according to t_enhance resolution Parameters @@ -654,10 +655,21 @@ def temporal_simple_enhancing(data, t_enhance=4): if t_enhance in [None, 1]: enhanced_data = data elif t_enhance not in [None, 1] and len(data.shape) == 5: - enhancement = [1, 1, 1, t_enhance, 1] - enhanced_data = interpolation.zoom(data, - enhancement, - order=0) + if mode == 'constant': + enhancement = [1, 1, 1, t_enhance, 1] + enhanced_data = zoom(data, + enhancement, + order=0, + mode='nearest', + grid_mode=True) + elif mode == 'linear': + index_t_hr = np.array(list(range(data.shape[3] * t_enhance))) + index_t_lr = 
index_t_hr[::t_enhance] + enhanced_data = interp1d(index_t_lr, + data, + axis=3, + fill_value='extrapolate')(index_t_hr) + enhanced_data = np.array(enhanced_data, dtype=np.float32) elif len(data.shape) != 5: msg = ('Data must be 5D to do temporal enhancing, but ' f'received: {data.shape}') @@ -855,27 +867,35 @@ def spatial_simple_enhancing(data, s_enhance=2, obs_axis=True): if obs_axis and len(data.shape) == 5: enhancement = [1, s_enhance, s_enhance, 1, 1] - enhanced_data = interpolation.zoom(data, - enhancement, - order=0) + enhanced_data = zoom(data, + enhancement, + order=0, + mode='nearest', + grid_mode=True) elif obs_axis and len(data.shape) == 4: enhancement = [1, s_enhance, s_enhance, 1] - enhanced_data = interpolation.zoom(data, - enhancement, - order=0) + enhanced_data = zoom(data, + enhancement, + order=0, + mode='nearest', + grid_mode=True) elif not obs_axis and len(data.shape) == 4: enhancement = [s_enhance, s_enhance, 1, 1] - enhanced_data = interpolation.zoom(data, - enhancement, - order=0) + enhanced_data = zoom(data, + enhancement, + order=0, + mode='nearest', + grid_mode=True) elif not obs_axis and len(data.shape) == 3: enhancement = [s_enhance, s_enhance, 1] - enhanced_data = interpolation.zoom(data, - enhancement, - order=0) + enhanced_data = zoom(data, + enhancement, + order=0, + mode='nearest', + grid_mode=True) else: msg = ('Data must be 3D, 4D, or 5D to do spatial enhancing, but ' f'received: {data.shape}') diff --git a/tests/forward_pass/test_out_conditional_moments.py b/tests/forward_pass/test_out_conditional_moments.py index ae05afbee5..229b2e450c 100644 --- a/tests/forward_pass/test_out_conditional_moments.py +++ b/tests/forward_pass/test_out_conditional_moments.py @@ -721,6 +721,7 @@ def test_out_st_mom1(plot=False, full_shape=(20, 20), sample_shape=(12, 12, 24), batch_size=4, n_batches=4, s_enhance=3, t_enhance=4, + end_t_padding=False, model_dir=None): """Test basic spatiotemporal model outputing for first conditional moment.""" @@ -735,7 +736,8 @@ def test_out_st_mom1(plot=False, full_shape=(20, 20), batch_size=batch_size, s_enhance=s_enhance, t_enhance=t_enhance, - n_batches=n_batches) + n_batches=n_batches, + end_t_padding=end_t_padding) # Load Model if model_dir is None: @@ -786,13 +788,17 @@ def test_out_st_mom1(plot=False, full_shape=(20, 20), + batch_handler.means[0]) aug_lr = np.reshape(lr, (1,) + lr.shape + (1,)) tup_lr = temporal_simple_enhancing(aug_lr, - t_enhance=t_enhance) + t_enhance=t_enhance, + mode='constant') tup_lr = tup_lr[0, :, :, :, 0] hr = (batch.output[i, :, :, :, 0] * batch_handler.stds[0] + batch_handler.means[0]) gen = (out[i, :, :, :, 0] * batch_handler.stds[0] + batch_handler.means[0]) - for j in range(batch.output.shape[3]): + max_t_ind = batch.output.shape[3] + if end_t_padding: + max_t_ind -= t_enhance + for j in range(max_t_ind): fig = plot_multi_contour( [tup_lr[:, :, j], hr[:, :, j], gen[:, :, j]], [0, batch.output.shape[1]], @@ -820,6 +826,8 @@ def test_out_st_mom1_sf(plot=False, full_shape=(20, 20), sample_shape=(12, 12, 24), batch_size=4, n_batches=4, s_enhance=3, t_enhance=4, + end_t_padding=False, + t_enhance_mode='constant', model_dir=None): """Test basic spatiotemporal model outputing for first conditional moment of subfilter velocity.""" @@ -830,11 +838,14 @@ def test_out_st_mom1_sf(plot=False, full_shape=(20, 20), val_split=0, worker_kwargs=dict(max_workers=1)) - batch_handler = BatchHandlerMom1SF([handler], - batch_size=batch_size, - s_enhance=s_enhance, - t_enhance=t_enhance, - n_batches=n_batches) + batch_handler 
= BatchHandlerMom1SF( + [handler], + batch_size=batch_size, + s_enhance=s_enhance, + t_enhance=t_enhance, + n_batches=n_batches, + end_t_padding=end_t_padding, + temporal_enhancing_method=t_enhance_mode) # Load Model if model_dir is None: @@ -866,14 +877,16 @@ def test_out_st_mom1_sf(plot=False, full_shape=(20, 20), b_lr_aug = np.reshape(b_lr, (1,) + b_lr.shape + (1,)) tup_lr = temporal_simple_enhancing(b_lr_aug, - t_enhance=t_enhance) + t_enhance=t_enhance, + mode='constant') tup_lr = (tup_lr[0, :, :, :, 0] * batch_handler.stds[0] + batch_handler.means[0]) up_lr_tmp = spatial_simple_enhancing(b_lr_aug, s_enhance=s_enhance) up_lr = temporal_simple_enhancing(up_lr_tmp, - t_enhance=t_enhance) + t_enhance=t_enhance, + mode=t_enhance_mode) up_lr = up_lr[0, :, :, :, 0] hr = (batch.high_res[i, :, :, :, 0] @@ -890,8 +903,10 @@ def test_out_st_mom1_sf(plot=False, full_shape=(20, 20), * batch_handler.stds[0] + batch_handler.means[0] + sf_pred) - - for j in range(batch.output.shape[3]): + max_t_ind = batch.output.shape[3] + if end_t_padding: + max_t_ind -= t_enhance + for j in range(max_t_ind): fig = plot_multi_contour( [tup_lr[:, :, j], hr[:, :, j], hr_pred[:, :, j], sf[:, :, j], @@ -926,6 +941,7 @@ def test_out_st_mom2(plot=False, full_shape=(20, 20), sample_shape=(12, 12, 24), batch_size=4, n_batches=4, s_enhance=3, t_enhance=4, + end_t_padding=False, model_dir=None, model_mom1_dir=None): """Test basic spatiotemporal model outputing @@ -952,7 +968,8 @@ def test_out_st_mom2(plot=False, full_shape=(20, 20), s_enhance=s_enhance, t_enhance=t_enhance, n_batches=n_batches, - model_mom1=model_mom1) + model_mom1=model_mom1, + end_t_padding=end_t_padding) # Load Mom2 Model if model_dir is None: @@ -987,7 +1004,8 @@ def test_out_st_mom2(plot=False, full_shape=(20, 20), b_lr_aug = np.reshape(b_lr, (1,) + b_lr.shape + (1,)) tup_lr = temporal_simple_enhancing(b_lr_aug, - t_enhance=t_enhance) + t_enhance=t_enhance, + mode='constant') tup_lr = (tup_lr[0, :, :, :, 0] * batch_handler.stds[0] + batch_handler.means[0]) @@ -1000,7 +1018,10 @@ def test_out_st_mom2(plot=False, full_shape=(20, 20), * batch_handler.stds[0]**2) integratedSigma.append(np.mean(sigma, axis=(0, 1))) - for j in range(batch.output.shape[3]): + max_t_ind = batch.output.shape[3] + if end_t_padding: + max_t_ind -= t_enhance + for j in range(max_t_ind): fig = plot_multi_contour( [tup_lr[:, :, j], hr[:, :, j], hr_to_mean[:, :, j], sigma[:, :, j]], @@ -1048,6 +1069,8 @@ def test_out_st_mom2_sf(plot=False, full_shape=(20, 20), sample_shape=(12, 12, 24), batch_size=4, n_batches=4, s_enhance=3, t_enhance=4, + end_t_padding=False, + t_enhance_mode='constant', model_dir=None, model_mom1_dir=None): """Test basic spatiotemporal model outputing for second conditional moment @@ -1069,12 +1092,15 @@ def test_out_st_mom2_sf(plot=False, full_shape=(20, 20), fp_gen = os.path.join(model_mom1_dir, 'model_params.json') model_mom1 = Sup3rCondMom(fp_gen).load(model_mom1_dir) - batch_handler = BatchHandlerMom2SF([handler], - batch_size=batch_size, - s_enhance=s_enhance, - t_enhance=t_enhance, - n_batches=n_batches, - model_mom1=model_mom1) + batch_handler = BatchHandlerMom2SF( + [handler], + batch_size=batch_size, + s_enhance=s_enhance, + t_enhance=t_enhance, + n_batches=n_batches, + model_mom1=model_mom1, + end_t_padding=end_t_padding, + temporal_enhancing_method=t_enhance_mode) # Load Mom2 Model if model_dir is None: @@ -1108,7 +1134,8 @@ def test_out_st_mom2_sf(plot=False, full_shape=(20, 20), b_lr_aug = np.reshape(b_lr, (1,) + b_lr.shape + (1,)) tup_lr = 
temporal_simple_enhancing(b_lr_aug, - t_enhance=t_enhance) + t_enhance=t_enhance, + mode='constant') tup_lr = (tup_lr[0, :, :, :, 0] * batch_handler.stds[0] + batch_handler.means[0]) @@ -1116,7 +1143,8 @@ def test_out_st_mom2_sf(plot=False, full_shape=(20, 20), up_lr_tmp = spatial_simple_enhancing(b_lr_aug, s_enhance=s_enhance) up_lr = temporal_simple_enhancing(up_lr_tmp, - t_enhance=t_enhance) + t_enhance=t_enhance, + mode=t_enhance_mode) up_lr = up_lr[0, :, :, :, 0] hr = (batch.high_res[i, :, :, :, 0] @@ -1131,7 +1159,11 @@ def test_out_st_mom2_sf(plot=False, full_shape=(20, 20), sigma = np.sqrt(out[i, :, :, :, 0] * batch_handler.stds[0]**2) integratedSigma.append(np.mean(sigma, axis=(0, 1))) - for j in range(batch.output.shape[3]): + + max_t_ind = batch.output.shape[3] + if end_t_padding: + max_t_ind -= t_enhance + for j in range(max_t_ind): fig = plot_multi_contour( [tup_lr[:, :, j], hr[:, :, j], sf[:, :, j], sf_to_mean[:, :, j], @@ -1183,6 +1215,7 @@ def test_out_st_mom2_sep(plot=False, full_shape=(20, 20), sample_shape=(12, 12, 24), batch_size=4, n_batches=4, s_enhance=3, t_enhance=4, + end_t_padding=False, model_dir=None, model_mom1_dir=None): """Test basic spatiotemporal model outputing @@ -1208,7 +1241,8 @@ def test_out_st_mom2_sep(plot=False, full_shape=(20, 20), batch_size=batch_size, s_enhance=s_enhance, t_enhance=t_enhance, - n_batches=n_batches) + n_batches=n_batches, + end_t_padding=end_t_padding) # Load Mom2 Model if model_dir is None: @@ -1246,7 +1280,8 @@ def test_out_st_mom2_sep(plot=False, full_shape=(20, 20), b_lr_aug = np.reshape(b_lr, (1,) + b_lr.shape + (1,)) tup_lr = temporal_simple_enhancing(b_lr_aug, - t_enhance=t_enhance) + t_enhance=t_enhance, + mode='constant') tup_lr = (tup_lr[0, :, :, :, 0] * batch_handler.stds[0] + batch_handler.means[0]) @@ -1267,7 +1302,10 @@ def test_out_st_mom2_sep(plot=False, full_shape=(20, 20), a_min=0, a_max=None)) integratedSigma.append(np.mean(sigma_pred, axis=(0, 1))) - for j in range(batch.output.shape[3]): + max_t_ind = batch.output.shape[3] + if end_t_padding: + max_t_ind -= t_enhance + for j in range(max_t_ind): fig = plot_multi_contour( [tup_lr[:, :, j], hr[:, :, j], hr_to_mean[:, :, j], sigma_pred[:, :, j]], @@ -1315,6 +1353,8 @@ def test_out_st_mom2_sep_sf(plot=False, full_shape=(20, 20), sample_shape=(12, 12, 24), batch_size=4, n_batches=4, s_enhance=3, t_enhance=4, + end_t_padding=False, + t_enhance_mode='constant', model_dir=None, model_mom1_dir=None): """Test basic spatiotemporal model outputing for second conditional moment @@ -1336,11 +1376,14 @@ def test_out_st_mom2_sep_sf(plot=False, full_shape=(20, 20), fp_gen = os.path.join(model_mom1_dir, 'model_params.json') model_mom1 = Sup3rCondMom(fp_gen).load(model_mom1_dir) - batch_handler = BatchHandlerMom2SepSF([handler], - batch_size=batch_size, - s_enhance=s_enhance, - t_enhance=t_enhance, - n_batches=n_batches) + batch_handler = BatchHandlerMom2SepSF( + [handler], + batch_size=batch_size, + s_enhance=s_enhance, + t_enhance=t_enhance, + n_batches=n_batches, + end_t_padding=end_t_padding, + temporal_enhancing_method=t_enhance_mode) # Load Mom2 Model if model_dir is None: @@ -1378,7 +1421,8 @@ def test_out_st_mom2_sep_sf(plot=False, full_shape=(20, 20), b_lr_aug = np.reshape(b_lr, (1,) + b_lr.shape + (1,)) tup_lr = temporal_simple_enhancing(b_lr_aug, - t_enhance=t_enhance) + t_enhance=t_enhance, + mode='constant') tup_lr = (tup_lr[0, :, :, :, 0] * batch_handler.stds[0] + batch_handler.means[0]) @@ -1386,7 +1430,8 @@ def test_out_st_mom2_sep_sf(plot=False, full_shape=(20, 
20), up_lr_tmp = spatial_simple_enhancing(b_lr_aug, s_enhance=s_enhance) up_lr = temporal_simple_enhancing(up_lr_tmp, - t_enhance=t_enhance) + t_enhance=t_enhance, + mode=t_enhance_mode) up_lr = up_lr[0, :, :, :, 0] hr = (batch.high_res[i, :, :, :, 0] @@ -1407,7 +1452,10 @@ def test_out_st_mom2_sep_sf(plot=False, full_shape=(20, 20), a_min=0, a_max=None)) integratedSigma.append(np.mean(sigma_pred, axis=(0, 1))) - for j in range(batch.output.shape[3]): + max_t_ind = batch.output.shape[3] + if end_t_padding: + max_t_ind -= t_enhance + for j in range(max_t_ind): fig = plot_multi_contour( [tup_lr[:, :, j], hr[:, :, j], sf[:, :, j], sf_to_mean[:, :, j], diff --git a/tests/training/test_train_conditional_moments.py b/tests/training/test_train_conditional_moments.py index ecc789debc..927da3753a 100644 --- a/tests/training/test_train_conditional_moments.py +++ b/tests/training/test_train_conditional_moments.py @@ -400,9 +400,11 @@ def test_train_s_mom2_sep_sf(FEATURES, TRAIN_FEATURES, model_mom2.save(out_dir) -@pytest.mark.parametrize('FEATURES', - (['U_100m', 'V_100m'],)) +@pytest.mark.parametrize('FEATURES, end_t_padding', + [(['U_100m', 'V_100m'], False), + (['U_100m', 'V_100m'], True)]) def test_train_st_mom1(FEATURES, + end_t_padding, log=False, full_shape=(20, 20), sample_shape=(12, 12, 24), n_epoch=2, batch_size=2, n_batches=2, @@ -428,7 +430,8 @@ def test_train_st_mom1(FEATURES, batch_handler = BatchHandlerMom1([handler], batch_size=batch_size, s_enhance=3, t_enhance=4, - n_batches=n_batches) + n_batches=n_batches, + end_t_padding=end_t_padding) with tempfile.TemporaryDirectory() as td: if out_dir_root is None: @@ -442,12 +445,16 @@ def test_train_st_mom1(FEATURES, model.save(out_dir) -@pytest.mark.parametrize('FEATURES', - (['U_100m', 'V_100m'],)) +@pytest.mark.parametrize('FEATURES, t_enhance_mode', + [(['U_100m', 'V_100m'], 'constant'), + (['U_100m', 'V_100m'], 'linear')]) def test_train_st_mom1_sf(FEATURES, + t_enhance_mode, + end_t_padding=False, log=False, full_shape=(20, 20), sample_shape=(12, 12, 24), n_epoch=2, batch_size=2, n_batches=2, + temporal_slice=slice(None, None, 1), out_dir_root=None): """Test basic spatiotemporal model training for first conditional moment of the subfilter velocity.""" @@ -464,13 +471,16 @@ def test_train_st_mom1_sf(FEATURES, handler = DataHandlerH5(FP_WTK, FEATURES, target=TARGET_COORD, shape=full_shape, sample_shape=sample_shape, - temporal_slice=slice(None, None, 1), + temporal_slice=temporal_slice, val_split=0.005, worker_kwargs=dict(max_workers=1)) - batch_handler = BatchHandlerMom1SF([handler], batch_size=batch_size, - s_enhance=3, t_enhance=4, - n_batches=n_batches) + batch_handler = BatchHandlerMom1SF( + [handler], batch_size=batch_size, + s_enhance=3, t_enhance=4, + n_batches=n_batches, + end_t_padding=end_t_padding, + temporal_enhancing_method=t_enhance_mode) with tempfile.TemporaryDirectory() as td: if out_dir_root is None: @@ -487,9 +497,11 @@ def test_train_st_mom1_sf(FEATURES, @pytest.mark.parametrize('FEATURES', (['U_100m', 'V_100m'],)) def test_train_st_mom2(FEATURES, + end_t_padding=False, log=False, full_shape=(20, 20), sample_shape=(12, 12, 16), n_epoch=2, batch_size=2, n_batches=2, + temporal_slice=slice(None, None, 1), out_dir_root=None, model_mom1_dir=None): """Test basic spatiotemporal model training @@ -516,14 +528,15 @@ def test_train_st_mom2(FEATURES, handler = DataHandlerH5(FP_WTK, FEATURES, target=TARGET_COORD, shape=full_shape, sample_shape=sample_shape, - temporal_slice=slice(None, None, 1), + temporal_slice=temporal_slice, 
val_split=0.005, worker_kwargs=dict(max_workers=1)) batch_handler = BatchHandlerMom2([handler], batch_size=batch_size, s_enhance=3, t_enhance=4, n_batches=n_batches, - model_mom1=model_mom1) + model_mom1=model_mom1, + end_t_padding=end_t_padding) with tempfile.TemporaryDirectory() as td: if out_dir_root is None: @@ -539,8 +552,11 @@ def test_train_st_mom2(FEATURES, @pytest.mark.parametrize('FEATURES', (['U_100m', 'V_100m'],)) def test_train_st_mom2_sf(FEATURES, + t_enhance_mode='constant', + end_t_padding=False, log=False, full_shape=(20, 20), sample_shape=(12, 12, 16), n_epoch=2, + temporal_slice=slice(None, None, 1), batch_size=2, n_batches=2, out_dir_root=None, model_mom1_dir=None): @@ -568,14 +584,17 @@ def test_train_st_mom2_sf(FEATURES, handler = DataHandlerH5(FP_WTK, FEATURES, target=TARGET_COORD, shape=full_shape, sample_shape=sample_shape, - temporal_slice=slice(None, None, 1), + temporal_slice=temporal_slice, val_split=0.005, worker_kwargs=dict(max_workers=1)) - batch_handler = BatchHandlerMom2SF([handler], batch_size=batch_size, - s_enhance=3, t_enhance=4, - n_batches=n_batches, - model_mom1=model_mom1) + batch_handler = BatchHandlerMom2SF( + [handler], batch_size=batch_size, + s_enhance=3, t_enhance=4, + n_batches=n_batches, + model_mom1=model_mom1, + end_t_padding=end_t_padding, + temporal_enhancing_method=t_enhance_mode) with tempfile.TemporaryDirectory() as td: if out_dir_root is None: @@ -591,8 +610,10 @@ def test_train_st_mom2_sf(FEATURES, @pytest.mark.parametrize('FEATURES', (['U_100m', 'V_100m'],)) def test_train_st_mom2_sep(FEATURES, + end_t_padding=False, log=False, full_shape=(20, 20), sample_shape=(12, 12, 16), n_epoch=2, + temporal_slice=slice(None, None, 1), batch_size=2, n_batches=2, out_dir_root=None): """Test basic spatiotemporal model training @@ -610,7 +631,7 @@ def test_train_st_mom2_sep(FEATURES, handler = DataHandlerH5(FP_WTK, FEATURES, target=TARGET_COORD, shape=full_shape, sample_shape=sample_shape, - temporal_slice=slice(None, None, 1), + temporal_slice=temporal_slice, val_split=0.005, worker_kwargs=dict(max_workers=1)) @@ -618,7 +639,8 @@ def test_train_st_mom2_sep(FEATURES, batch_size=batch_size, s_enhance=3, t_enhance=4, - n_batches=n_batches) + n_batches=n_batches, + end_t_padding=end_t_padding) with tempfile.TemporaryDirectory() as td: if out_dir_root is None: @@ -634,6 +656,8 @@ def test_train_st_mom2_sep(FEATURES, @pytest.mark.parametrize('FEATURES', (['U_100m', 'V_100m'],)) def test_train_st_mom2_sep_sf(FEATURES, + t_enhance_mode='constant', + end_t_padding=False, log=False, full_shape=(20, 20), sample_shape=(12, 12, 16), n_epoch=2, batch_size=2, n_batches=2, @@ -656,10 +680,13 @@ def test_train_st_mom2_sep_sf(FEATURES, val_split=0.005, worker_kwargs=dict(max_workers=1)) - batch_handler = BatchHandlerMom2SepSF([handler], - batch_size=batch_size, - s_enhance=3, t_enhance=4, - n_batches=n_batches) + batch_handler = BatchHandlerMom2SepSF( + [handler], + batch_size=batch_size, + s_enhance=3, t_enhance=4, + n_batches=n_batches, + end_t_padding=end_t_padding, + temporal_enhancing_method=t_enhance_mode) with tempfile.TemporaryDirectory() as td: if out_dir_root is None:
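Note on the new temporal_enhancing_method / mode option: the patch upsamples low-res data along time either by repeating each snapshot ('constant', via scipy.ndimage.zoom with order=0) or by linearly interpolating between temporal landmarks ('linear', via scipy.interpolate.interp1d with extrapolation past the last landmark). Below is a minimal standalone sketch of the two modes for a 5D array shaped (obs, lat, lon, time, features), as handled by temporal_simple_enhancing; the function name and example shapes are illustrative, not the library code.

import numpy as np
from scipy.interpolate import interp1d
from scipy.ndimage import zoom


def upsample_time(data, t_enhance=4, mode='constant'):
    """Upsample a 5D (obs, lat, lon, time, features) array along time."""
    if t_enhance in (None, 1):
        return data
    if mode == 'constant':
        # order=0 nearest-neighbour zoom repeats each snapshot t_enhance times
        return zoom(data, [1, 1, 1, t_enhance, 1], order=0,
                    mode='nearest', grid_mode=True)
    # 'linear': low-res snapshots sit at every t_enhance-th high-res index
    index_t_hr = np.arange(data.shape[3] * t_enhance)
    index_t_lr = index_t_hr[::t_enhance]
    out = interp1d(index_t_lr, data, axis=3,
                   fill_value='extrapolate')(index_t_hr)
    return out.astype(np.float32)


lr = np.random.rand(1, 4, 4, 6, 2).astype(np.float32)
print(upsample_time(lr, 4, 'constant').shape)  # (1, 4, 4, 24, 2)
print(upsample_time(lr, 4, 'linear').shape)    # (1, 4, 4, 24, 2)

In 'linear' mode the steps past the last landmark are linear extrapolations rather than copies, which lines up with the new end_t_padding flag that can zero those trailing steps out of the loss mask (see the next sketch).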
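Sketch of the end_t_padding behaviour added to make_mask. This is a hypothetical, self-contained rewrite of just the 5D masking logic (the 4D spatial-only branch and the class context are omitted); the slicing mirrors the patch, and the extra None check on t_enhance is a defensive addition of this sketch, not part of the patch.

import numpy as np


def make_mask(high_res, s_padding=None, t_padding=None,
              end_t_padding=False, t_enhance=None):
    """Return a {0, 1} mask with the same shape as a 5D high_res batch."""
    mask = np.zeros(high_res.shape, dtype=np.float32)
    s_min = s_padding if s_padding is not None else 0
    t_min = t_padding if t_padding is not None else 0
    s_max = -s_padding if s_min > 0 else None
    t_max = -t_padding if t_min > 0 else None
    if end_t_padding and t_enhance is not None and t_enhance > 1:
        # zero the trailing steps that are not bracketed by low-res landmarks
        if t_max is None:
            t_max = -(t_enhance - 1)
        else:
            t_max = -(t_enhance - 1) - t_padding
    mask[:, s_min:s_max, s_min:s_max, t_min:t_max, :] = 1.0
    return mask


hr = np.ones((2, 12, 12, 24, 2), dtype=np.float32)
print(make_mask(hr, end_t_padding=True, t_enhance=4)[0, 0, 0, :, 0])
# -> ones except for the last t_enhance - 1 = 3 time steps, which stay zero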
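Why the 'linear' option matters for the subfilter (SF) batches: the SF targets are built as high_res minus the temporally enhanced low_res (and, for the second-moment classes, additionally minus the first-moment prediction before squaring), so the enhancement mode decides how much of the signal is attributed to the low-res field. A toy 1-D illustration, independent of the library:

import numpy as np
from scipy.interpolate import interp1d

t_enhance = 4
hr = np.sin(np.linspace(0, 2 * np.pi, 24, endpoint=False))  # "high-res" series
lr = hr[::t_enhance]                                         # temporal landmarks

idx_hr = np.arange(hr.size)
idx_lr = idx_hr[::t_enhance]

lr_constant = np.repeat(lr, t_enhance)                       # 'constant' mode
lr_linear = interp1d(idx_lr, lr, fill_value='extrapolate')(idx_hr)  # 'linear'

# The subfilter signal the model is asked to reproduce in each mode:
sf_constant = hr - lr_constant
sf_linear = hr - lr_linear
print(np.abs(sf_constant).mean())  # larger: steps between landmarks remain
print(np.abs(sf_linear).mean())    # smaller for a smooth signal like this

In this 1-D toy both modes give zero SF exactly at the landmark steps; they differ only between landmarks and past the last one, where 'linear' extrapolates.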
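For reference, a usage sketch showing how the new keywords are threaded through a handler, mirroring the updated tests. The H5 path and target coordinate are placeholders, and the import locations are assumed from how the repo's tests import these classes, so treat this as illustrative rather than canonical.

# Placeholder data path and coordinates; the keyword arguments below include
# the ones added by this patch (end_t_padding, temporal_enhancing_method).
from sup3r.preprocessing.data_handling import DataHandlerH5
from sup3r.preprocessing.conditional_moment_batch_handling import (
    BatchHandlerMom1SF)

handler = DataHandlerH5('/path/to/wtk_data.h5',       # hypothetical file
                        ['U_100m', 'V_100m'],
                        target=(39.01, -105.15),      # hypothetical coordinate
                        shape=(20, 20),
                        sample_shape=(12, 12, 24),
                        temporal_slice=slice(None, None, 1),
                        val_split=0.005,
                        worker_kwargs=dict(max_workers=1))

batch_handler = BatchHandlerMom1SF([handler],
                                   batch_size=2, n_batches=2,
                                   s_enhance=3, t_enhance=4,
                                   end_t_padding=True,
                                   temporal_enhancing_method='linear')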