Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bug fix for conditional moment with topography #149

Merged
merged 11 commits into from
Jul 17, 2023
1 change: 1 addition & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ Option 2: Clone repo (recommended for developers)
6) Install ``sup3r`` and its dependencies by running:
``pip install .`` (or ``pip install -e .`` if running a dev branch
or working on the source code)
7) *Optional*: Set up the pre-commit hooks with ``pip install pre-commit`` and ``pre-commit install``

Recommended Citation
====================
Expand Down
43 changes: 43 additions & 0 deletions sup3r/configs/sup3rcc/gen_wind_3x_4x_2f.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
{
"hidden_layers": [
{"n": 2, "repeat": [
{"class": "FlexiblePadding", "paddings": [[0,0], [3,3], [3,3], [3,3], [0,0]], "mode": "REFLECT"},
{"class": "Conv3D", "filters": 1, "kernel_size": 3, "strides": 1},
{"class": "Cropping3D", "cropping": 2},
{"alpha": 0.2, "class": "LeakyReLU"},
{"class": "SpatioTemporalExpansion", "temporal_mult": 2, "temporal_method": "nearest"}
]
},
{"class": "SkipConnection", "name": "a"},

{"n": 1, "repeat": [
{"class": "SkipConnection", "name": "b"},
{"class": "FlexiblePadding", "paddings": [[0,0], [3,3], [3,3], [3,3], [0,0]], "mode": "REFLECT"},
{"class": "Conv3D", "filters": 1, "kernel_size": 3, "strides": 1},
{"class": "Cropping3D", "cropping": 2},
{"alpha": 0.2, "class": "LeakyReLU"},
{"class": "FlexiblePadding", "paddings": [[0,0], [3,3], [3,3], [3,3], [0,0]], "mode": "REFLECT"},
{"class": "Conv3D", "filters": 1, "kernel_size": 3, "strides": 1},
{"class": "Cropping3D", "cropping": 2},
{"class": "SkipConnection", "name": "b"}
]
},

{"class": "FlexiblePadding", "paddings": [[0,0], [3,3], [3,3], [3,3], [0,0]], "mode": "REFLECT"},
{"class": "Conv3D", "filters": 1, "kernel_size": 3, "strides": 1},
{"class": "Cropping3D", "cropping": 2},
{"class": "SkipConnection", "name": "a"},

{"class": "FlexiblePadding", "paddings": [[0,0], [3,3], [3,3], [3,3], [0,0]], "mode": "REFLECT"},
{"class": "Conv3D", "filters": 36, "kernel_size": 3, "strides": 1},
{"class": "Cropping3D", "cropping": 2},
{"class": "SpatioTemporalExpansion", "spatial_mult": 3},
{"alpha": 0.2, "class": "LeakyReLU"},

{"class": "Sup3rConcat"},

{"class": "FlexiblePadding", "paddings": [[0,0], [3,3], [3,3], [3,3], [0,0]], "mode": "REFLECT"},
{"class": "Conv3D", "filters": 2, "kernel_size": 3, "strides": 1},
{"class": "Cropping3D", "cropping": 2}
]
}
2 changes: 1 addition & 1 deletion sup3r/models/conditional_moments.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
"""Sup3r model software"""
"""Sup3r conditional moment model software"""
import os
import time
import logging
Expand Down
145 changes: 97 additions & 48 deletions sup3r/preprocessing/conditional_moment_batch_handling.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
"""
Sup3r batch_handling module.
Sup3r conditional moment batch_handling module.
"""
import logging
import numpy as np
Expand Down Expand Up @@ -97,6 +97,14 @@ def make_output(low_res, high_res,
t_enhance_mode : str
Enhancing mode for temporal subfilter.
Can be either constant or linear

Returns
-------
HR: np.ndarray
4D | 5D array
(batch_size, spatial_1, spatial_2, features)
(batch_size, spatial_1, spatial_2, temporal, features)
HR is high-res and LR is low-res
"""
return high_res

Expand Down Expand Up @@ -147,6 +155,13 @@ def make_mask(high_res,
output_features_ind : list | np.ndarray | None
List/array of feature channel indices that are used for generative
output, without any feature indices used only for training.

Returns
-------
mask: np.ndarray
4D | 5D array
(batch_size, spatial_1, spatial_2, features)
(batch_size, spatial_1, spatial_2, temporal, features)
"""
mask = np.zeros(high_res.shape, dtype=np.float32)
s_min = s_padding if s_padding is not None else 0
Expand Down Expand Up @@ -270,8 +285,8 @@ def get_coarse_batch(cls, high_res,


class BatchMom1SF(BatchMom1):
"""Batch of low_res, high_res and output data
when learning first moment of subfilter vel"""
"""Batch of low_res, high_res and output data when learning first moment
of subfilter vel"""

@staticmethod
def make_output(low_res, high_res,
Expand Down Expand Up @@ -302,6 +317,15 @@ def make_output(low_res, high_res,
t_enhance_mode : str
Enhancing mode for temporal subfilter.
Can be either constant or linear

Returns
-------
SF: np.ndarray
4D | 5D array
(batch_size, spatial_1, spatial_2, features)
(batch_size, spatial_1, spatial_2, temporal, features)
SF is subfilter, HR is high-res and LR is low-res
SF = HR - LR
"""
# Remove LR from HR
enhanced_lr = spatial_simple_enhancing(low_res,
Expand All @@ -310,12 +334,13 @@ def make_output(low_res, high_res,
t_enhance=t_enhance,
mode=t_enhance_mode)
enhanced_lr = Batch.reduce_features(enhanced_lr, output_features_ind)

return high_res - enhanced_lr


class BatchMom2(BatchMom1):
"""Batch of low_res, high_res and output data
when learning second moment"""
"""Batch of low_res, high_res and output data when learning second
moment"""

@staticmethod
def make_output(low_res, high_res,
Expand Down Expand Up @@ -346,15 +371,23 @@ def make_output(low_res, high_res,
t_enhance_mode : str
Enhancing mode for temporal subfilter.
Can be either constant or linear

Returns
-------
(HR - <HR|LR>)**2: np.ndarray
4D | 5D array
(batch_size, spatial_1, spatial_2, features)
(batch_size, spatial_1, spatial_2, temporal, features)
HR is high-res and LR is low-res
"""
# Remove first moment from HR and square it
out = model_mom1._tf_generate(low_res).numpy()
return (high_res - out)**2


class BatchMom2Sep(BatchMom1):
"""Batch of low_res, high_res and output data
when learning second moment separate from first moment"""
"""Batch of low_res, high_res and output data when learning second moment
separate from first moment"""

@staticmethod
def make_output(low_res, high_res,
Expand Down Expand Up @@ -385,6 +418,14 @@ def make_output(low_res, high_res,
t_enhance_mode : str
Enhancing mode for temporal subfilter.
Can be either constant or linear

Returns
-------
HR**2: np.ndarray
4D | 5D array
(batch_size, spatial_1, spatial_2, features)
(batch_size, spatial_1, spatial_2, temporal, features)
HR is high-res
"""
return super(BatchMom2Sep,
BatchMom2Sep).make_output(low_res, high_res,
Expand All @@ -395,8 +436,8 @@ def make_output(low_res, high_res,


class BatchMom2SF(BatchMom1):
"""Batch of low_res, high_res and output data
when learning second moment of subfilter vel"""
"""Batch of low_res, high_res and output data when learning second moment
of subfilter vel"""

@staticmethod
def make_output(low_res, high_res,
Expand Down Expand Up @@ -427,6 +468,15 @@ def make_output(low_res, high_res,
t_enhance_mode : str
Enhancing mode for temporal subfilter.
Can be either 'constant' or 'linear'

Returns
-------
(SF - <SF|LR>)**2: np.ndarray
4D | 5D array
(batch_size, spatial_1, spatial_2, features)
(batch_size, spatial_1, spatial_2, temporal, features)
SF is subfilter, HR is high-res and LR is low-res
SF = HR - LR
"""
# Remove LR and first moment from HR and square it
out = model_mom1._tf_generate(low_res).numpy()
Expand All @@ -440,9 +490,8 @@ def make_output(low_res, high_res,


class BatchMom2SepSF(BatchMom1SF):
"""Batch of low_res, high_res and output data
when learning second moment of subfilter vel
separate from first moment"""
"""Batch of low_res, high_res and output data when learning second moment
of subfilter vel separate from first moment"""

@staticmethod
def make_output(low_res, high_res,
Expand Down Expand Up @@ -473,6 +522,15 @@ def make_output(low_res, high_res,
t_enhance_mode : str
Enhancing mode for temporal subfilter.
Can be either constant or linear

Returns
-------
SF**2: np.ndarray
4D | 5D array
(batch_size, spatial_1, spatial_2, features)
(batch_size, spatial_1, spatial_2, temporal, features)
SF is subfilter, HR is high-res and LR is low-res
SF = HR - LR
"""
# Remove LR from HR and square it
return super(BatchMom2SepSF,
Expand Down Expand Up @@ -854,106 +912,97 @@ def __next__(self):


class ValidationDataMom1SF(ValidationDataMom1):
"""Iterator for validation data for
first conditional moment of subfilter velocity"""
"""Iterator for validation data for first conditional moment of subfilter
velocity"""
BATCH_CLASS = BatchMom1SF


class ValidationDataMom2(ValidationDataMom1):
"""Iterator for subfilter validation data for
second conditional moment"""
"""Iterator for subfilter validation data for second conditional moment"""
BATCH_CLASS = BatchMom2


class ValidationDataMom2Sep(ValidationDataMom1):
"""Iterator for subfilter validation data for
second conditional moment separate from first
moment"""
"""Iterator for subfilter validation data for second conditional moment
separate from first moment"""
BATCH_CLASS = BatchMom2Sep


class ValidationDataMom2SF(ValidationDataMom1):
"""Iterator for validation data for
second conditional moment of subfilter velocity"""
"""Iterator for validation data for second conditional moment of subfilter
velocity"""
BATCH_CLASS = BatchMom2SF


class ValidationDataMom2SepSF(ValidationDataMom1):
"""Iterator for validation data for
second conditional moment of subfilter velocity
separate from first moment"""
"""Iterator for validation data for second conditional moment of subfilter
velocity separate from first moment"""
BATCH_CLASS = BatchMom2SepSF


class BatchHandlerMom1SF(BatchHandlerMom1):
"""Sup3r batch handling class for
first conditional moment of subfilter velocity"""
"""Sup3r batch handling class for first conditional moment of subfilter
velocity"""
VAL_CLASS = ValidationDataMom1SF
BATCH_CLASS = VAL_CLASS.BATCH_CLASS


class SpatialBatchHandlerMom1SF(SpatialBatchHandlerMom1):
"""Sup3r spatial batch handling class for
first conditional moment of subfilter velocity"""
"""Sup3r spatial batch handling class for first conditional moment of
subfilter velocity"""
VAL_CLASS = ValidationDataMom1SF
BATCH_CLASS = VAL_CLASS.BATCH_CLASS


class BatchHandlerMom2(BatchHandlerMom1):
"""Sup3r batch handling class for
second conditional moment"""
"""Sup3r batch handling class for second conditional moment"""
VAL_CLASS = ValidationDataMom2
BATCH_CLASS = VAL_CLASS.BATCH_CLASS


class BatchHandlerMom2Sep(BatchHandlerMom1):
"""Sup3r batch handling class for
second conditional moment separate from first
moment"""
"""Sup3r batch handling class for second conditional moment separate from
first moment"""
VAL_CLASS = ValidationDataMom2Sep
BATCH_CLASS = VAL_CLASS.BATCH_CLASS


class SpatialBatchHandlerMom2(SpatialBatchHandlerMom1):
"""Sup3r spatial batch handling class for
second conditional moment"""
"""Sup3r spatial batch handling class for second conditional moment"""
VAL_CLASS = ValidationDataMom2
BATCH_CLASS = VAL_CLASS.BATCH_CLASS


class SpatialBatchHandlerMom2Sep(SpatialBatchHandlerMom1):
"""Sup3r spatial batch handling class for
second conditional moment separate from first
moment"""
"""Sup3r spatial batch handling class for second conditional moment
separate from first moment"""
VAL_CLASS = ValidationDataMom2Sep
BATCH_CLASS = VAL_CLASS.BATCH_CLASS


class BatchHandlerMom2SF(BatchHandlerMom1):
"""Sup3r batch handling class for
second conditional moment of subfilter velocity"""
"""Sup3r batch handling class for second conditional moment of subfilter
velocity"""
VAL_CLASS = ValidationDataMom2SF
BATCH_CLASS = VAL_CLASS.BATCH_CLASS


class BatchHandlerMom2SepSF(BatchHandlerMom1):
"""Sup3r batch handling class for
second conditional moment of subfilter velocity
separate from first moment"""
"""Sup3r batch handling class for second conditional moment of subfilter
velocity separate from first moment"""
VAL_CLASS = ValidationDataMom2SepSF
BATCH_CLASS = VAL_CLASS.BATCH_CLASS


class SpatialBatchHandlerMom2SF(SpatialBatchHandlerMom1):
"""Sup3r spatial batch handling class for
second conditional moment of subfilter velocity"""
"""Sup3r spatial batch handling class for second conditional moment of
subfilter velocity"""
VAL_CLASS = ValidationDataMom2SF
BATCH_CLASS = VAL_CLASS.BATCH_CLASS


class SpatialBatchHandlerMom2SepSF(SpatialBatchHandlerMom1):
"""Sup3r spatial batch handling class for
second conditional moment of subfilter velocity
separate from first moment"""
"""Sup3r spatial batch handling class for second conditional moment of
subfilter velocity separate from first moment"""
VAL_CLASS = ValidationDataMom2SepSF
BATCH_CLASS = VAL_CLASS.BATCH_CLASS
Loading