Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Option to train second moment separately from first moment #125

Merged
merged 6 commits into from
Dec 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ TO DO
#. |ss| Nomenclature upsampling enhance |se|
#. |ss| See of Sup3rCondMom can instead inherit from abstract Gan class |se|
#. |ss| Add option to crop output |se|
#. |ss| Add option to train first and second moment separately |se|
#. Figure out why Validation loss is always lower than training
#. Train network with increasing complexity
#. |ss| Include number of parameter in loss plotting |se|
#. Show training results
Expand Down
119 changes: 119 additions & 0 deletions sup3r/preprocessing/batch_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,39 @@ def make_output(low_res, high_res,
return (high_res - out)**2


class BatchMom2Sep(Batch):
"""Batch of low_res, high_res and output data
when learning second moment separate from first moment"""

@staticmethod
def make_output(low_res, high_res,
s_enhance=None, t_enhance=None,
model_mom1=None, output_features_ind=None):
"""Make custom batch output

Parameters
----------
low_res : np.ndarray
4D | 5D array
(batch_size, spatial_1, spatial_2, features)
(batch_size, spatial_1, spatial_2, temporal, features)
high_res : np.ndarray
4D | 5D array
(batch_size, spatial_1, spatial_2, features)
(batch_size, spatial_1, spatial_2, temporal, features)
s_enhance : int | None
Spatial enhancement factor
t_enhance : int | None
Temporal enhancement factor
model_mom1 : Sup3rCondMom | None
Model used to modify the make the batch output
output_features_ind : list | np.ndarray | None
List/array of feature channel indices that are used for generative
output, without any feature indices used only for training.
"""
return high_res**2


class BatchMom2SF(Batch):
"""Batch of low_res, high_res and output data
when learning second moment of subfilter vel"""
Expand Down Expand Up @@ -338,6 +371,46 @@ def make_output(low_res, high_res,
return (high_res - enhanced_lr - out)**2


class BatchMom2SepSF(Batch):
"""Batch of low_res, high_res and output data
when learning second moment of subfilter vel
separate from first moment"""

@staticmethod
def make_output(low_res, high_res,
s_enhance=None, t_enhance=None,
model_mom1=None, output_features_ind=None):
"""Make custom batch output

Parameters
----------
low_res : np.ndarray
4D | 5D array
(batch_size, spatial_1, spatial_2, features)
(batch_size, spatial_1, spatial_2, temporal, features)
high_res : np.ndarray
4D | 5D array
(batch_size, spatial_1, spatial_2, features)
(batch_size, spatial_1, spatial_2, temporal, features)
s_enhance : int | None
Spatial enhancement factor
t_enhance : int | None
Temporal enhancement factor
model_mom1 : Sup3rCondMom | None
Model used to modify the make the batch output
output_features_ind : list | np.ndarray | None
List/array of feature channel indices that are used for generative
output, without any feature indices used only for training.
"""
# Remove first moment from high res and square it
enhanced_lr = spatial_simple_enhancing(low_res,
s_enhance=s_enhance)
enhanced_lr = temporal_simple_enhancing(enhanced_lr,
t_enhance=t_enhance)
enhanced_lr = Batch.reduce_features(enhanced_lr, output_features_ind)
return (high_res - enhanced_lr)**2


class BatchMom1SF(Batch):
"""Batch of low_res, high_res and output data
when learning first moment of subfilter vel"""
Expand Down Expand Up @@ -1397,40 +1470,86 @@ class ValidationDataMom2(ValidationData):
BATCH_CLASS = BatchMom2


class ValidationDataMom2Sep(ValidationData):
"""Iterator for subfilter validation data for
second conditional moment separate from first
moment"""
BATCH_CLASS = BatchMom2Sep


class BatchHandlerMom2(BatchHandler):
"""Sup3r batch handling class for
second conditional moment"""
VAL_CLASS = ValidationDataMom2
BATCH_CLASS = BatchMom2


class BatchHandlerMom2Sep(BatchHandler):
"""Sup3r batch handling class for
second conditional moment separate from first
moment"""
VAL_CLASS = ValidationDataMom2Sep
BATCH_CLASS = BatchMom2Sep


class SpatialBatchHandlerMom2(SpatialBatchHandler):
"""Sup3r spatial batch handling class for
second conditional moment"""
VAL_CLASS = ValidationDataMom2
BATCH_CLASS = BatchMom2


class SpatialBatchHandlerMom2Sep(SpatialBatchHandler):
"""Sup3r spatial batch handling class for
second conditional moment separate from first
moment"""
VAL_CLASS = ValidationDataMom2Sep
BATCH_CLASS = BatchMom2Sep


class ValidationDataMom2SF(ValidationData):
"""Iterator for validation data for
second conditional moment of subfilter velocity"""
BATCH_CLASS = BatchMom2SF


class ValidationDataMom2SepSF(ValidationData):
"""Iterator for validation data for
second conditional moment of subfilter velocity
separate from first moment"""
BATCH_CLASS = BatchMom2SepSF


class BatchHandlerMom2SF(BatchHandler):
"""Sup3r batch handling class for
second conditional moment of subfilter velocity"""
VAL_CLASS = ValidationDataMom2SF
BATCH_CLASS = BatchMom2SF


class BatchHandlerMom2SepSF(BatchHandler):
"""Sup3r batch handling class for
second conditional moment of subfilter velocity
separate from first moment"""
VAL_CLASS = ValidationDataMom2SepSF
BATCH_CLASS = BatchMom2SepSF


class SpatialBatchHandlerMom2SF(SpatialBatchHandler):
"""Sup3r spatial batch handling class for
second conditional moment of subfilter velocity"""
VAL_CLASS = ValidationDataMom2SF
BATCH_CLASS = BatchMom2SF


class SpatialBatchHandlerMom2SepSF(SpatialBatchHandler):
"""Sup3r spatial batch handling class for
second conditional moment of subfilter velocity
separate from first moment"""
VAL_CLASS = ValidationDataMom2SepSF
BATCH_CLASS = BatchMom2SepSF


class ValidationDataMom1SF(ValidationData):
"""Iterator for validation data for
first conditional moment of subfilter velocity"""
Expand Down
22 changes: 22 additions & 0 deletions tests/run_out_conditional_moments_feat.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
from test_out_conditional_moments import (test_out_spatial_mom1,
test_out_spatial_mom1_sf,
test_out_spatial_mom2,
test_out_spatial_mom2_sep,
test_out_spatial_mom2_sf,
test_out_spatial_mom2_sep_sf,
test_out_loss)

FP_WTK = os.path.join(TEST_DATA_DIR, 'test_wtk_co_2012.h5')
Expand Down Expand Up @@ -49,6 +51,15 @@
model_mom1_dir='s_mom1_feat/spatial_cond_mom',
TRAIN_FEATURES=TRAIN_FEATURES)

test_out_spatial_mom2_sep(plot=True, full_shape=(20, 20),
sample_shape=(10, 10, 1),
batch_size=4, n_batches=2,
s_enhance=2,
FEATURES=FEATURES,
model_dir='s_mom2_sep_feat/spatial_cond_mom',
model_mom1_dir='s_mom1_feat/spatial_cond_mom',
TRAIN_FEATURES=TRAIN_FEATURES)

test_out_spatial_mom1_sf(plot=True, full_shape=(20, 20),
sample_shape=(10, 10, 1),
batch_size=4, n_batches=2,
Expand All @@ -65,3 +76,14 @@
model_mom1_dir='s_mom1_sf_feat/spatial_cond_mom',
model_dir='s_mom2_sf_feat/spatial_cond_mom',
TRAIN_FEATURES=TRAIN_FEATURES)

test_out_spatial_mom2_sep_sf(
plot=True,
full_shape=(20, 20),
sample_shape=(10, 10, 1),
batch_size=4, n_batches=2,
s_enhance=2,
FEATURES=FEATURES,
model_mom1_dir='s_mom1_sf_feat/spatial_cond_mom',
model_dir='s_mom2_sep_sf_feat/spatial_cond_mom',
TRAIN_FEATURES=TRAIN_FEATURES)
18 changes: 17 additions & 1 deletion tests/run_out_conditional_moments_st.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
from test_out_conditional_moments import (test_out_loss,
test_out_st_mom1,
test_out_st_mom2,
test_out_st_mom2_sep,
test_out_st_mom1_sf,
test_out_st_mom2_sf)
test_out_st_mom2_sf,
test_out_st_mom2_sep_sf)

FP_WTK = os.path.join(TEST_DATA_DIR, 'test_wtk_co_2012.h5')
TARGET_COORD = (39.01, -105.15)
Expand Down Expand Up @@ -41,6 +43,13 @@
model_dir='st_mom2/st_cond_mom',
model_mom1_dir='st_mom1/st_cond_mom')

test_out_st_mom2_sep(plot=True, full_shape=(20, 20),
sample_shape=(12, 12, 24),
batch_size=1, n_batches=1,
s_enhance=3, t_enhance=4,
model_dir='st_mom2_sep/st_cond_mom',
model_mom1_dir='st_mom1/st_cond_mom')

test_out_st_mom1_sf(plot=True, full_shape=(20, 20),
sample_shape=(12, 12, 24),
batch_size=1, n_batches=1,
Expand All @@ -53,3 +62,10 @@
s_enhance=3, t_enhance=4,
model_mom1_dir='st_mom1_sf/st_cond_mom',
model_dir='st_mom2_sf/st_cond_mom')

test_out_st_mom2_sep_sf(plot=True, full_shape=(20, 20),
sample_shape=(12, 12, 24),
batch_size=1, n_batches=1,
s_enhance=3, t_enhance=4,
model_mom1_dir='st_mom1_sf/st_cond_mom',
model_dir='st_mom2_sep_sf/st_cond_mom')
23 changes: 22 additions & 1 deletion tests/run_train_conditional_moments_feat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
from sup3r import TEST_DATA_DIR
from test_train_conditional_moments import (test_train_spatial_mom1,
test_train_spatial_mom2,
test_train_spatial_mom2_sep,
test_train_spatial_mom1_sf,
test_train_spatial_mom2_sf)
test_train_spatial_mom2_sf,
test_train_spatial_mom2_sep_sf)


FP_WTK = os.path.join(TEST_DATA_DIR, 'test_wtk_co_2012.h5')
Expand Down Expand Up @@ -35,6 +37,15 @@
s_padding=None,
t_padding=None)

test_train_spatial_mom2_sep(n_epoch=2, log=True, full_shape=(20, 20),
sample_shape=(10, 10, 1),
batch_size=8, n_batches=5,
out_dir_root='s_mom2_sep_feat',
FEATURES=FEATURES,
TRAIN_FEATURES=TRAIN_FEATURES,
s_padding=None,
t_padding=None)

test_train_spatial_mom1_sf(n_epoch=2, log=True,
full_shape=(20, 20),
sample_shape=(10, 10, 1),
Expand All @@ -56,4 +67,14 @@
TRAIN_FEATURES=TRAIN_FEATURES,
s_padding=None,
t_padding=None)

test_train_spatial_mom2_sep_sf(n_epoch=2, log=True,
full_shape=(20, 20),
sample_shape=(10, 10, 1),
batch_size=8, n_batches=5,
out_dir_root='s_mom2_sep_sf_feat',
FEATURES=FEATURES,
TRAIN_FEATURES=TRAIN_FEATURES,
s_padding=None,
t_padding=None)
# pass
21 changes: 20 additions & 1 deletion tests/run_train_conditional_moments_pad.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@

from test_train_conditional_moments import (test_train_spatial_mom1,
test_train_spatial_mom2,
test_train_spatial_mom2_sep,
test_train_spatial_mom1_sf,
test_train_spatial_mom2_sf)
test_train_spatial_mom2_sf,
test_train_spatial_mom2_sep_sf)


FP_WTK = os.path.join(TEST_DATA_DIR, 'test_wtk_co_2012.h5')
Expand All @@ -33,6 +35,14 @@
TRAIN_FEATURES=TRAIN_FEATURES,
s_padding=1, t_padding=1)

test_train_spatial_mom2_sep(n_epoch=2, log=True, full_shape=(20, 20),
sample_shape=(10, 10, 1),
batch_size=8, n_batches=5,
out_dir_root='s_mom2_sep_pad',
FEATURES=FEATURES,
TRAIN_FEATURES=TRAIN_FEATURES,
s_padding=1, t_padding=1)

test_train_spatial_mom1_sf(n_epoch=2, log=True,
full_shape=(20, 20),
sample_shape=(10, 10, 1),
Expand All @@ -52,4 +62,13 @@
TRAIN_FEATURES=TRAIN_FEATURES,
s_padding=1, t_padding=1)

test_train_spatial_mom2_sep_sf(n_epoch=2, log=True,
full_shape=(20, 20),
sample_shape=(10, 10, 1),
batch_size=8, n_batches=5,
out_dir_root='s_mom2_sep_sf_pad',
FEATURES=FEATURES,
TRAIN_FEATURES=TRAIN_FEATURES,
s_padding=1, t_padding=1)

# pass
16 changes: 15 additions & 1 deletion tests/run_train_conditional_moments_st.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
from test_train_conditional_moments import (test_train_st_mom1,
test_train_st_mom1_sf,
test_train_st_mom2,
test_train_st_mom2_sf)
test_train_st_mom2_sep,
test_train_st_mom2_sf,
test_train_st_mom2_sep_sf)

FP_WTK = os.path.join(TEST_DATA_DIR, 'test_wtk_co_2012.h5')
TARGET_COORD = (39.01, -105.15)
Expand All @@ -33,11 +35,23 @@
out_dir_root='st_mom2',
model_mom1_dir='st_mom1/st_cond_mom',
FEATURES=FEATURES)
test_train_st_mom2_sep(n_epoch=2, log=True,
full_shape=(20, 20),
sample_shape=(12, 12, 24),
batch_size=2, n_batches=2,
out_dir_root='st_mom2_sep',
FEATURES=FEATURES)
test_train_st_mom2_sf(n_epoch=2, log=True,
full_shape=(20, 20),
sample_shape=(12, 12, 24),
batch_size=2, n_batches=2,
out_dir_root='st_mom2_sf',
model_mom1_dir='st_mom1_sf/st_cond_mom',
FEATURES=FEATURES)
test_train_st_mom2_sep_sf(n_epoch=2, log=True,
full_shape=(20, 20),
sample_shape=(12, 12, 24),
batch_size=2, n_batches=2,
out_dir_root='st_mom2_sep_sf',
FEATURES=FEATURES)
# pass
Loading