Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gb/bias nc #181

Merged
merged 15 commits into from
Dec 18, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Next Next commit
refactor: sub-pixel aggregation for bias correction instead of NN distance lookup
  • Loading branch information
grantbuster committed Dec 6, 2023
commit 65ecf068a88b0a8f1ce86d5d52f0cf568a556447
116 changes: 68 additions & 48 deletions sup3r/bias/bias_calc.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,12 @@ def __init__(self,
bias_fps,
base_dset,
bias_feature,
distance_upper_bound,
target=None,
shape=None,
base_handler='Resource',
bias_handler='DataHandlerNCforCC',
base_handler_kwargs=None,
bias_handler_kwargs=None,
decimals=None):
"""
Expand All @@ -60,18 +62,25 @@ def __init__(self,
bias_feature : str
This is the biased feature from bias_fps to retrieve. This should
be a single feature name corresponding to base_dset
distance_upper_bound : float
Upper bound on the nearest neighbor distance in decimal degrees.
This should be the approximate resolution of the low-resolution
bias data.
target : tuple
(lat, lon) lower left corner of raster to retrieve from bias_fps.
If None then the lower left corner of the full domain will be used.
shape : tuple
(rows, cols) grid size to retrieve from bias_fps. If None then the
full domain shape will be used.
base_handler : str
Name of rex resource handler class to be retrieved from the rex
library.
Name of rex resource handler or sup3r.preprocessing.data_handling
class to be retrieved from the rex/sup3r library.
bias_handler : str
Name of the bias data handler class to be retrieved from the
sup3r.preprocessing.data_handling library.
base_handler_kwargs : dict | None
Optional kwargs to send to the initialization of the base_handler
class
bias_handler_kwargs : dict | None
Optional kwargs to send to the initialization of the bias_handler
class
Expand All @@ -92,36 +101,49 @@ def __init__(self,
self.target = target
self.shape = shape
self.decimals = decimals
bias_handler_kwargs = bias_handler_kwargs or {}
self.base_handler_kwargs = base_handler_kwargs or {}
self.bias_handler_kwargs = bias_handler_kwargs or {}
self.bad_bias_gids = []

if isinstance(self.base_fps, str):
self.base_fps = sorted(glob(self.base_fps))
if isinstance(self.bias_fps, str):
self.bias_fps = sorted(glob(self.bias_fps))

self.base_handler = getattr(rex, base_handler)
base_sup3r_h = getattr(sup3r.preprocessing.data_handling,
base_handler, None)
base_rex_h = getattr(rex, base_handler, None)
msg = f'Could not retrieve "{base_handler}" from sup3r or rex!'
assert base_sup3r_h is not None or base_rex_h is not None, msg
self.base_handler = base_rex_h or base_sup3r_h

self.bias_handler = getattr(sup3r.preprocessing.data_handling,
bias_handler)

with self.base_handler(self.base_fps[0]) as res:
self.base_meta = res.meta
self.base_tree = KDTree(self.base_meta[['latitude', 'longitude']])
self.base_dh = self.base_handler(self.base_fps[0],
**self.base_handler_kwargs)
self.base_meta = self.base_dh.meta

self.bias_dh = self.bias_handler(self.bias_fps, [self.bias_feature],
target=self.target,
shape=self.shape,
val_split=0.0,
**bias_handler_kwargs)
**self.bias_handler_kwargs)
lats = self.bias_dh.lat_lon[..., 0].flatten()
lons = self.bias_dh.lat_lon[..., 1].flatten()
self.bias_meta = pd.DataFrame({'latitude': lats, 'longitude': lons})
self.bias_meta = self.bias_dh.meta
self.bias_ti = self.bias_dh.time_index

raster_shape = self.bias_dh.lat_lon[..., 0].shape
self.bias_tree = KDTree(self.bias_meta[['latitude', 'longitude']])
bias_lat_lon = self.bias_meta[['latitude', 'longitude']].values
self.bias_tree = KDTree(bias_lat_lon)
self.bias_gid_raster = np.arange(lats.size)
self.bias_gid_raster = self.bias_gid_raster.reshape(raster_shape)

out = self.bias_tree.query(self.base_meta[['latitude', 'longitude']],
k=1,
distance_upper_bound=distance_upper_bound)
self.nn_dist, self.nn_ind = out

self.out = None
self._init_out()
logger.info('Finished initializing DataRetrievalBase.')
Expand Down Expand Up @@ -238,7 +260,7 @@ def get_bias_gid(self, coord):
bias_gid = self.bias_gid_raster.flatten()[i]
return bias_gid, d

def get_base_gid(self, bias_gid, knn):
def get_base_gid(self, bias_gid):
"""Get one or more base gid(s) corresponding to a bias gid.

Parameters
Expand All @@ -247,32 +269,29 @@ def get_base_gid(self, bias_gid, knn):
gid of the data to retrieve in the bias data source raster data.
The gids for this data source are the enumerated indices of the
flattened coordinate array.
knn : int
Number of nearest neighbors to aggregate from the base data when
comparing to a single site from the bias data.

Returns
-------
dist : np.ndarray
Array of nearest neighbor distances with length == knn
Array of nearest neighbor distances with length equal to the number
of high-resolution baseline gids that map to the low resolution
bias gid pixel.
base_gid : np.ndarray
Array of base gids that are the nearest neighbors of bias_gid with
length == knn
length equal to the number of high-resolution baseline gids that
map to the low resolution bias gid pixel.
"""
coord = self.bias_meta.loc[bias_gid, ['latitude', 'longitude']]
dist, base_gid = self.base_tree.query(coord, k=knn)
base_gid = np.where(self.nn_ind == bias_gid)[0]
dist = self.nn_dist[base_gid]
return dist, base_gid

def get_data_pair(self, coord, knn, daily_reduction='avg'):
def get_data_pair(self, coord, daily_reduction='avg'):
"""Get base and bias data observations based on a single bias gid.

Parameters
----------
coord : tuple
(lat, lon) to get data for.
knn : int
Number of nearest neighbors to aggregate from the base data when
comparing to a single site from the bias data.
daily_reduction : None | str
Option to do a reduction of the hourly+ source base data to daily
data. Can be None (no reduction, keep source time frequency), "avg"
Expand All @@ -287,12 +306,13 @@ def get_data_pair(self, coord, knn, daily_reduction='avg'):
1D array of temporal data at the requested gid.
base_dist : np.ndarray
Array of nearest neighbor distances from coord to the base data
sites with length == knn
sites with length equal to the number of high-resolution baseline
gids that map to the low resolution bias gid pixel.
bias_dist : float
Nearest neighbor distance from coord to the bias data site
"""
bias_gid, bias_dist = self.get_bias_gid(coord)
base_dist, base_gid = self.get_base_gid(bias_gid, knn)
base_dist, base_gid = self.get_base_gid(bias_gid)
bias_data = self.get_bias_data(bias_gid)
base_data = self.get_base_data(self.base_fps,
self.base_dset,
Expand Down Expand Up @@ -651,6 +671,10 @@ def fill_and_smooth(self,
like: bias_data * scalar + adder. Each value is of shape
(lat, lon, time).
"""
if len(self.bad_bias_gids) > 0:
logger.info('Found {} bias gids that are out of bounds: {}'
.format(len(self.bad_bias_gids), self.bad_bias_gids))

for key, arr in out.items():
nan_mask = np.isnan(arr[..., 0])
for idt in range(arr.shape[-1]):
Expand All @@ -661,6 +685,9 @@ def fill_and_smooth(self,
and fill_extend) or smooth_interior > 0

if needs_fill:
logger.info('Filling NaN values outside of valid spatial '
'extent for dataset "{}" for timestep {}'
.format(key, idt))
arr_smooth = nn_fill_array(arr_smooth)

arr_smooth_int = arr_smooth_ext = arr_smooth
Expand Down Expand Up @@ -714,8 +741,6 @@ def write_outputs(self, fp_out, out):
'Wrote scalar adder factors to file: {}'.format(fp_out))

def run(self,
knn,
threshold=0.6,
fp_out=None,
max_workers=None,
daily_reduction='avg',
Expand All @@ -727,14 +752,6 @@ def run(self,

Parameters
----------
knn : int
Number of nearest neighbors to aggregate from the base data when
comparing to a single site from the bias data.
threshold : float
If the bias data coordinate is on average further from the base
data coordinates than this threshold, no bias correction factors
will be calculated directly and will just be filled from nearest
neighbor (if fill_extend=True, else it will be nan).
fp_out : str | None
Optional .h5 output file to write scalar and adder arrays.
max_workers : int
Expand All @@ -745,13 +762,13 @@ def run(self,
data. Can be None (no reduction, keep source time frequency), "avg"
(daily average), "max" (daily max), or "min" (daily min)
fill_extend : bool
Flag to fill data past threshold using spatial nearest neighbor. If
False, the extended domain will be left as NaN.
Flag to fill data past distance_upper_bound using spatial nearest
neighbor. If False, the extended domain will be left as NaN.
smooth_extend : float
Option to smooth the scalar/adder data outside of the spatial
domain set by the threshold input. This alleviates the weird seams
far from the domain of interest. This value is the standard
deviation for the gaussian_filter kernel
domain set by the distance_upper_bound input. This alleviates the
weird seams far from the domain of interest. This value is the
standard deviation for the gaussian_filter kernel
smooth_interior : float
Option to smooth the scalar/adder data within the valid spatial
domain. This can reduce the effect of extreme values within
Expand All @@ -770,14 +787,17 @@ def run(self,
logger.info('Initialized scalar / adder with shape: {}'
.format(self.bias_gid_raster.shape))

self.bad_bias_gids = []

if max_workers == 1:
logger.debug('Running serial calculation.')
for i, (bias_gid, row) in enumerate(self.bias_meta.iterrows()):
for i, bias_gid in enumerate(self.bias_meta.index):
raster_loc = np.where(self.bias_gid_raster == bias_gid)
coord = row[['latitude', 'longitude']]
dist, base_gid = self.base_tree.query(coord, k=knn)
dist, base_gid = self.get_base_gid(bias_gid)

if np.mean(dist) < threshold:
if not base_gid.any():
self.bad_bias_gids.append(bias_gid)
else:
bias_data = self.get_bias_data(bias_gid)
single_out = self._run_single(
bias_data,
Expand All @@ -804,12 +824,12 @@ def run(self,
futures = {}
for bias_gid, bias_row in self.bias_meta.iterrows():
raster_loc = np.where(self.bias_gid_raster == bias_gid)
coord = bias_row[['latitude', 'longitude']]
dist, base_gid = self.base_tree.query(coord, k=knn)
dist, base_gid = self.get_base_gid(bias_gid)

if np.mean(dist) < threshold:
if not base_gid.any():
self.bad_bias_gids.append(bias_gid)
else:
bias_data = self.get_bias_data(bias_gid)

future = exe.submit(
self._run_single,
bias_data,
Expand Down