-
Notifications
You must be signed in to change notification settings - Fork 7
/
pytest.py
267 lines (234 loc) · 8.11 KB
/
pytest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
# -*- coding: utf-8 -*-
"""Utilities used for pytests"""
import os
import numpy as np
import xarray as xr
from sup3r.postprocessing.file_handling import OutputHandlerH5
from sup3r.utilities.utilities import pd_date_range
def make_fake_nc_files(td, input_file, n_files):
    """Make dummy nc files with increasing times

    Parameters
    ----------
    td : str
        Temporary directory
    input_file : str
        File to use as template for all dummy files
    n_files : int
        Number of dummy files to create

    Returns
    -------
    fake_files : list
        List of dummy files
    """
    # One file per hour on 2014-10-01; the hour is encoded both in the
    # filename and in the 'Times' variable written into each file
    hours = [str(i).zfill(2) for i in range(n_files)]
    fake_files = [
        os.path.join(td, f'input_2014-10-01_{hh}_00_00') for hh in hours
    ]
    for i, (hh, out_path) in enumerate(zip(hours, fake_files)):
        if os.path.exists(out_path):
            os.remove(out_path)
        stamp = f'2014-10-01 {hh}:00:00'
        with xr.open_dataset(input_file) as template:
            with xr.Dataset(template) as dummy:
                dummy['Times'][:] = np.array(
                    [stamp.encode('ASCII')], dtype='|S19')
                dummy['XTIME'][:] = i
                dummy.to_netcdf(out_path)
    return fake_files
def make_fake_multi_time_nc_files(td, input_file, n_steps, n_files):
    """Make dummy nc files, each with multiple timesteps

    Parameters
    ----------
    td : str
        Temporary directory
    input_file : str
        File to use as template for timesteps in dummy files
    n_steps : int
        Number of timesteps across all files
    n_files : int
        Number of files to split all timesteps across

    Returns
    -------
    dummy_files : list
        List of multi-timestep dummy files
    """
    # Build n_steps single-timestep files, then merge them into n_files
    # multi-timestep files
    fake_files = make_fake_nc_files(td, input_file, n_steps)
    fake_files = np.array_split(fake_files, n_files)
    dummy_files = []
    for i, files in enumerate(fake_files):
        dummy_file = os.path.join(
            td, f'multi_timestep_file_{str(i).zfill(3)}.nc')
        if os.path.exists(dummy_file):
            os.remove(dummy_file)
        dummy_files.append(dummy_file)
        # Concatenate this group of single-step files along the WRF 'Time'
        # dimension into one output file
        with xr.open_mfdataset(
                files, combine='nested', concat_dim='Time') as dset:
            dset.to_netcdf(dummy_file)
    return dummy_files
def make_fake_era_files(td, input_file, n_files):
    """Make dummy era files with increasing times. ERA files have a different
    naming convention than WRF.

    Parameters
    ----------
    td : str
        Temporary directory
    input_file : str
        File to use as template for all dummy files
    n_files : int
        Number of dummy files to create

    Returns
    -------
    fake_files : list
        List of dummy files
    """
    # One file per hour on 2014-10-01, same layout as make_fake_nc_files but
    # with ERA-style lowercase wind component names
    hours = [str(i).zfill(2) for i in range(n_files)]
    fake_files = [
        os.path.join(td, f'input_2014-10-01_{hh}_00_00') for hh in hours
    ]
    for i, (hh, out_path) in enumerate(zip(hours, fake_files)):
        if os.path.exists(out_path):
            os.remove(out_path)
        stamp = f'2014-10-01 {hh}:00:00'
        with xr.open_dataset(input_file) as template:
            with xr.Dataset(template) as dummy:
                dummy['Times'][:] = np.array(
                    [stamp.encode('ASCII')], dtype='|S19')
                dummy['XTIME'][:] = i
                # ERA naming: lowercase u/v instead of WRF's U/V
                dummy = dummy.rename({'U': 'u', 'V': 'v'})
                dummy.to_netcdf(out_path)
    return fake_files
def make_fake_h5_chunks(td):
    """Make fake h5 chunked output files for a 5x spatial 2x temporal
    multi-node forward pass output.

    Parameters
    ----------
    td : tempfile.TemporaryDirectory
        Test TemporaryDirectory

    Returns
    -------
    out_files : list
        List of filepaths to chunked files.
    data : ndarray
        (spatial_1, spatial_2, temporal, features)
        High resolution forward pass output
    ws_true : ndarray
        Windspeed between 0 and 20 in shape (spatial_1, spatial_2, temporal, 1)
    wd_true : ndarray
        Winddirection between 0 and 360 in shape
        (spatial_1, spatial_2, temporal, 1)
    features : list
        List of feature names corresponding to the last dimension of data
        ['windspeed_100m', 'winddirection_100m']
    t_slices_lr : list
        List of low res temporal slices
    t_slices_hr : list
        List of high res temporal slices
    s_slices_lr : list
        List of low res spatial slices
    s_slices_hr : list
        List of high res spatial slices
    low_res_lat_lon : ndarray
        Array of lat/lon for input data. (spatial_1, spatial_2, 2)
        Last dimension has ordering (lat, lon)
    low_res_times : list
        List of np.datetime64 objects for coarse data.
    """
    features = ['windspeed_100m', 'winddirection_100m']
    model_meta_data = {'foo': 'bar'}
    shape = (50, 50, 96, 1)

    # Random "truth" fields; stacked along the feature axis in the same order
    # as the features list
    ws_true = np.random.uniform(0, 20, shape)
    wd_true = np.random.uniform(0, 360, shape)
    data = np.concatenate((ws_true, wd_true), axis=3)

    # 10x10 low-res lat/lon grid with (lat, lon) ordering in the last dim
    lat_1d = np.linspace(90, 0, 10)
    lon_1d = np.linspace(-180, 0, 10)
    lon_grid, lat_grid = np.meshgrid(lon_1d, lat_1d)
    low_res_lat_lon = np.dstack((lat_grid, lon_grid))

    # One gid per high-res spatial pixel
    gids = np.arange(np.prod(shape[:2])).reshape(shape[:2])

    low_res_times = pd_date_range(
        '20220101', '20220103', freq='3600s', inclusive='left'
    )

    # Two chunks along each axis: 2x temporal and 5x spatial enhancement
    t_slices_lr = [slice(0, 24), slice(24, None)]
    t_slices_hr = [slice(0, 48), slice(48, None)]
    s_slices_lr = [slice(0, 5), slice(5, 10)]
    s_slices_hr = [slice(0, 25), slice(25, 50)]

    out_pattern = os.path.join(td, 'fp_out_{t}_{i}_{j}.h5')
    out_files = []
    for it, (tslice_lr, tslice_hr) in enumerate(
            zip(t_slices_lr, t_slices_hr)):
        for i1, (row_lr, row_hr) in enumerate(
                zip(s_slices_lr, s_slices_hr)):
            for i2, (col_lr, col_hr) in enumerate(
                    zip(s_slices_lr, s_slices_hr)):
                fp_out = out_pattern.format(
                    t=str(it).zfill(3),
                    i=str(i1).zfill(3),
                    j=str(i2).zfill(3),
                )
                out_files.append(fp_out)
                OutputHandlerH5.write_output(
                    data[row_hr, col_hr, tslice_hr, :],
                    features,
                    low_res_lat_lon[row_lr, col_lr],
                    low_res_times[tslice_lr],
                    fp_out,
                    meta_data=model_meta_data,
                    max_workers=1,
                    gids=gids[row_hr, col_hr],
                )

    return (out_files, data, ws_true, wd_true, features, t_slices_lr,
            t_slices_hr, s_slices_lr, s_slices_hr, low_res_lat_lon,
            low_res_times)
def make_fake_cs_ratio_files(td, low_res_times, low_res_lat_lon, gan_meta):
    """Make a set of dummy clearsky ratio files that match the GAN fwp outputs

    Parameters
    ----------
    td : tempfile.TemporaryDirectory
        Test TemporaryDirectory
    low_res_times : list
        List of times for low res input data. If there is only a single low
        res timestep, it is assumed the data is daily.
    low_res_lat_lon : ndarray
        Array of lat/lon for input data.
        (spatial_1, spatial_2, 2)
        Last dimension has ordering (lat, lon)
    gan_meta : dict
        Meta data for model to write to file.

    Returns
    -------
    fps : list
        List of clearsky ratio .h5 chunked files.
    fp_pattern : str
        Glob pattern string to find fps
    """
    fps = []
    chunk_dir = os.path.join(td, 'chunks/')
    fp_pattern = os.path.join(chunk_dir, 'sup3r_chunk_*.h5')
    # exist_ok avoids a FileExistsError if a test reuses the same temp dir
    os.makedirs(chunk_dir, exist_ok=True)

    for idt, timestamp in enumerate(low_res_times):
        fn = 'sup3r_chunk_{}_{}.h5'.format(str(idt).zfill(6), str(0).zfill(6))
        out_file = os.path.join(chunk_dir, fn)
        fps.append(out_file)

        # One random 20x20 spatial field per timestep, repeated for 24 hours
        cs_ratio = np.random.uniform(0, 1, (20, 20, 1, 1))
        cs_ratio = np.repeat(cs_ratio, 24, axis=2)

        OutputHandlerH5.write_output(
            cs_ratio,
            ['clearsky_ratio'],
            low_res_lat_lon,
            [timestamp],
            out_file,
            max_workers=1,
            meta_data=gan_meta,
        )
    return fps, fp_pattern