-
Notifications
You must be signed in to change notification settings - Fork 8
/
wind.py
137 lines (114 loc) · 5.21 KB
/
wind.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
# -*- coding: utf-8 -*-
"""Wind super resolution GAN with handling of low and high res topography
inputs."""
import numpy as np
import logging
import tensorflow as tf
from sup3r.models.base import Sup3rGan
from sup3r.models.abstract import AbstractWindInterface
logger = logging.getLogger(__name__)
class WindGan(AbstractWindInterface, Sup3rGan):
    """Wind super resolution GAN with handling of low and high res topography
    inputs.

    Modifications to standard Sup3rGan:
        - Hi res topography is expected as the last feature channel in the
          true data in the true batch observation. This topo channel is
          appended to the generated output so the discriminator can look at
          the wind fields compared to the associated hi res topo.
        - If a custom Sup3rAdder or Sup3rConcat layer (from phygnn) is
          present in the network, the hi-res topography will be added or
          concatenated to the data at that point in the network during
          either training or the forward pass.
    """

    def init_weights(self, lr_shape, hr_shape, device=None):
        """Initialize the generator and discriminator weights with device
        placement by running dummy forward passes.

        Parameters
        ----------
        lr_shape : tuple
            Shape of one batch of low res input data for sup3r resolution.
            Note that the batch size (axis=0) must be included, but the
            actual batch size doesn't really matter.
        hr_shape : tuple
            Shape of one batch of high res input data for sup3r resolution.
            Note that the batch size (axis=0) must be included, but the
            actual batch size doesn't really matter.
        device : str | None
            Option to place model weights on a device. If None,
            self.default_device will be used.
        """
        target_device = self.default_device if device is None else device
        msg = 'Initializing model weights on device "{}"'.format(target_device)
        logger.info(msg)

        dummy_lr = np.ones(lr_shape, dtype=np.float32)
        dummy_hr = np.ones(hr_shape, dtype=np.float32)
        # single-channel topography raster matching the hi-res spatial dims
        dummy_topo = np.ones(hr_shape[:-1] + (1,), dtype=np.float32)

        # a forward pass through each network builds its weights in place
        with tf.device(target_device):
            _ = self._tf_generate(dummy_lr, dummy_topo)
            _ = self._tf_discriminate(dummy_hr)

    def set_model_params(self, **kwargs):
        """Set parameters used for training the model.

        Parameters
        ----------
        kwargs : dict
            Keyword arguments including 'training_features',
            'output_features', 'smoothed_features', 's_enhance',
            't_enhance', 'smoothing'
        """
        # validate/normalize wind-specific params first, then apply the
        # standard GAN params
        AbstractWindInterface.set_model_params(**kwargs)
        Sup3rGan.set_model_params(self, **kwargs)

    @tf.function
    def calc_loss(self, hi_res_true, hi_res_gen, **kwargs):
        """Calculate the GAN loss function using generated and true high
        resolution data.

        Parameters
        ----------
        hi_res_true : tf.Tensor
            Ground truth high resolution spatiotemporal data.
        hi_res_gen : tf.Tensor
            Superresolved high resolution spatiotemporal data generated by
            the generative model.
        kwargs : dict
            Key word arguments for:
            Sup3rGan.calc_loss(hi_res_true, hi_res_gen, **kwargs)

        Returns
        -------
        loss : tf.Tensor
            0D tensor representing the loss value for the network being
            trained (either generator or one of the discriminators)
        loss_details : dict
            Namespace of the breakdown of loss components
        """
        # the discriminator must see the synthetic wind fields together with
        # their associated hi-res topography, which lives in the last channel
        # of the true data
        true_topo = hi_res_true[..., -1:]
        hi_res_gen = tf.concat((hi_res_gen, true_topo), axis=-1)
        return super().calc_loss(hi_res_true, hi_res_gen, **kwargs)

    def calc_val_loss(self, batch_handler, weight_gen_advers, loss_details):
        """Calculate the validation loss at the current state of model
        training.

        Parameters
        ----------
        batch_handler : sup3r.data_handling.preprocessing.BatchHandler
            BatchHandler object to iterate through
        weight_gen_advers : float
            Weight factor for the adversarial loss component of the
            generator vs. the discriminator.
        loss_details : dict
            Namespace of the breakdown of loss components

        Returns
        -------
        loss_details : dict
            Same as input but now includes val_* loss info
        """
        logger.debug('Starting end-of-epoch validation loss calculation...')
        loss_details['n_obs'] = 0
        for batch in batch_handler.val_data:
            # generate with the topo channel pulled from the true hi-res data
            gen_out = self._tf_generate(batch.low_res,
                                        batch.high_res[..., -1:])
            _, batch_loss_details = self.calc_loss(
                batch.high_res,
                gen_out,
                weight_gen_advers=weight_gen_advers,
                train_gen=False,
                train_disc=False)
            loss_details = self.update_loss_details(loss_details,
                                                    batch_loss_details,
                                                    len(batch),
                                                    prefix='val_')
        return loss_details