-
Notifications
You must be signed in to change notification settings - Fork 7
/
wind_conditional_moments.py
98 lines (81 loc) · 3.63 KB
/
wind_conditional_moments.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# -*- coding: utf-8 -*-
"""Wind conditional moment estimator with handling of low and
high res topography inputs."""
import logging
import tensorflow as tf
from sup3r.models.abstract import AbstractWindInterface
from sup3r.models.conditional_moments import Sup3rCondMom
logger = logging.getLogger(__name__)
class WindCondMom(AbstractWindInterface, Sup3rCondMom):
    """Conditional moment estimator for wind data that is aware of low and
    high resolution topography inputs.

    Differences from the standard Sup3rCondMom:
        - The hi-res topography is expected to be the last feature channel
          of the true high resolution batch data.
        - If a custom Sup3rAdder or Sup3rConcat layer (from phygnn) is present
          in the network, the hi-res topography will be added or concatenated
          to the data at that point in the network during either training or
          the forward pass.
    """

    def set_model_params(self, **kwargs):
        """Set the parameters used for training the model.

        Delegates first to AbstractWindInterface.set_model_params and then to
        Sup3rCondMom.set_model_params, passing the same keyword arguments to
        both.

        Parameters
        ----------
        kwargs : dict
            Keyword arguments including 'training_features', 'output_features',
            'smoothed_features', 's_enhance', 't_enhance', 'smoothing'
        """
        # NOTE(review): this call is made on the class without ``self`` --
        # presumably a staticmethod on AbstractWindInterface; confirm there.
        # The call order is preserved deliberately in case the first call
        # adjusts the contents of kwargs in place before the second one runs.
        AbstractWindInterface.set_model_params(**kwargs)
        Sup3rCondMom.set_model_params(self, **kwargs)

    @tf.function
    def calc_loss(self, hi_res_true, hi_res_gen, mask, **kwargs):
        """Calculate the loss from generated and true high resolution data.

        The last feature channel of the true data (the hi-res topography, per
        the class convention) is concatenated onto the generated data along
        the last axis before the parent class loss calculation is invoked.

        Parameters
        ----------
        hi_res_true : tf.Tensor
            Ground truth high resolution spatiotemporal data. Its last
            feature channel is appended to hi_res_gen.
        hi_res_gen : tf.Tensor
            High resolution spatiotemporal data generated by the model.
        mask : tf.Tensor
            Mask to apply; forwarded unchanged to the parent calc_loss.
        kwargs : dict
            Additional keyword arguments forwarded unchanged to the parent
            calc_loss.

        Returns
        -------
        loss : tf.Tensor
            Loss value as returned by the parent calc_loss.
        loss_details : dict
            Namespace of the breakdown of loss components, as returned by the
            parent calc_loss.
        """
        # slice with -1: (not -1) so the channel axis is kept for the concat
        topo_true = hi_res_true[..., -1:]
        gen_with_topo = tf.concat((hi_res_gen, topo_true), axis=-1)
        return super().calc_loss(hi_res_true, gen_with_topo, mask, **kwargs)

    def calc_val_loss(self, batch_handler, loss_details):
        """Calculate the validation loss at the current state of model
        training.

        Parameters
        ----------
        batch_handler : sup3r.data_handling.preprocessing.BatchHandler
            BatchHandler object whose val_data attribute is iterated over.
        loss_details : dict
            Namespace of the breakdown of loss components. Its 'n_obs' entry
            is reset to 0 in place before the validation batches are
            processed.

        Returns
        -------
        loss_details : dict
            Result of the last self.update_loss_details call made with
            prefix='val_', or the input dict (with 'n_obs' reset) if there
            are no validation batches.
        """
        logger.debug('Starting end-of-epoch validation loss calculation...')
        loss_details['n_obs'] = 0

        for batch in batch_handler.val_data:
            # the generator receives the low-res data plus the last hi-res
            # channel, kept as its own trailing axis
            generated = self._tf_generate(batch.low_res,
                                          batch.high_res[..., -1:])

            # only the loss breakdown is needed here; the scalar is discarded
            _, batch_details = self.calc_loss(batch.output, generated,
                                              batch.mask)

            loss_details = self.update_loss_details(
                loss_details, batch_details, len(batch), prefix='val_')

        return loss_details