Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gb/bias nc #181

Merged
merged 15 commits into from
Dec 18, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
refactor: topo extraction class will now sub-grid aggregate high-res exo source data to low-res target grid instead of doing a k nearest neighbor
  • Loading branch information
grantbuster committed Dec 7, 2023
commit 948637d570fb1ce25d1578cf0973d13f8b56781f
4 changes: 2 additions & 2 deletions sup3r/bias/bias_calc.py
Original file line number Diff line number Diff line change
Expand Up @@ -877,7 +877,7 @@ def run(self,
logger.debug('Running serial calculation.')
for i, bias_gid in enumerate(self.bias_meta.index):
raster_loc = np.where(self.bias_gid_raster == bias_gid)
dist, base_gid = self.get_base_gid(bias_gid)
_, base_gid = self.get_base_gid(bias_gid)

if not base_gid.any():
self.bad_bias_gids.append(bias_gid)
Expand Down Expand Up @@ -907,7 +907,7 @@ def run(self,
max_workers))
with ProcessPoolExecutor(max_workers=max_workers) as exe:
futures = {}
for bias_gid, bias_row in self.bias_meta.iterrows():
for bias_gid in self.bias_meta.index:
raster_loc = np.where(self.bias_gid_raster == bias_gid)
dist, base_gid = self.get_base_gid(bias_gid)

Expand Down
148 changes: 72 additions & 76 deletions sup3r/preprocessing/data_handling/exo_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import pickle
import shutil
from abc import ABC, abstractmethod
from warnings import warn

import pandas as pd
import numpy as np
from rex import Resource
from rex.utilities.solar_position import SolarPosition
Expand All @@ -15,7 +17,8 @@
from sup3r.postprocessing.file_handling import OutputHandler
from sup3r.preprocessing.data_handling.h5_data_handling import DataHandlerH5
from sup3r.preprocessing.data_handling.nc_data_handling import DataHandlerNC
from sup3r.utilities.utilities import generate_random_string, get_source_type
from sup3r.utilities.utilities import (generate_random_string, get_source_type,
nn_fill_array)

logger = logging.getLogger(__name__)

Expand All @@ -32,7 +35,6 @@ def __init__(self,
exo_source,
s_enhance,
t_enhance,
s_agg_factor,
t_agg_factor,
target=None,
shape=None,
Expand All @@ -43,6 +45,7 @@ def __init__(self,
cache_data=True,
cache_dir='./exo_cache/',
ti_workers=1,
distance_upper_bound=None,
res_kwargs=None):
"""Parameters
----------
Expand All @@ -68,13 +71,6 @@ def __init__(self,
example, if getting sza data, file_paths has hourly data, and
t_enhance is 4, this class will output a sza raster
corresponding to the file_paths temporally enhanced 4x to 15 min
s_agg_factor : int
Factor by which to aggregate the exo_source data to the resolution
of the file_paths input enhanced by s_enhance. For example, if
getting topography data, file_paths have 100km data, and s_enhance
is 4 resulting in a desired resolution of ~25km and topo_source_h5
has a resolution of 4km, the s_agg_factor should be 36 so that 6x6
4km cells are averaged to the ~25km enhanced grid.
t_agg_factor : int
Factor by which to aggregate the exo_source data to the resolution
of the file_paths input enhanced by t_enhance. For example, if
Expand Down Expand Up @@ -118,6 +114,10 @@ def __init__(self,
parallel and then concatenated to get the full time index. If input
files do not all have time indices or if there are few input files
this should be set to one.
distance_upper_bound : float | None
Maximum distance to map high-resolution data from exo_source to the
low-resolution file_paths input. None (default) will calculate this
based on the median distance between points in exo_source
res_kwargs : dict | None
Dictionary of kwargs passed to lowest level resource handler. e.g.
xr.open_dataset(file_paths, **res_kwargs)
Expand All @@ -128,13 +128,13 @@ def __init__(self,
self._exo_source = exo_source
self._s_enhance = s_enhance
self._t_enhance = t_enhance
self._s_agg_factor = s_agg_factor
self._t_agg_factor = t_agg_factor
self._tree = None
self._hr_lat_lon = None
self._source_lat_lon = None
self._hr_time_index = None
self._src_time_index = None
self._distance_upper_bound = None
self.cache_data = cache_data
self.cache_dir = cache_dir
self.temporal_slice = temporal_slice
Expand Down Expand Up @@ -179,8 +179,7 @@ def __init__(self,
def source_data(self):
"""Get the 1D array of source data from the exo_source_h5"""

def get_cache_file(self, feature, s_enhance, t_enhance, s_agg_factor,
t_agg_factor):
def get_cache_file(self, feature, s_enhance, t_enhance, t_agg_factor):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't we want the cache identified by the spatial agg factor?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I actually removed s_agg_factor from the whole module altogether, because the topo aggregation "factor" is now determined by the actual pixel mapping.

"""Get cache file name

Parameters
Expand All @@ -193,9 +192,6 @@ def get_cache_file(self, feature, s_enhance, t_enhance, s_agg_factor,
t_enhance : int
Temporal enhancement for this exogeneous data step (cumulative for
all model steps up to the current step).
s_agg_factor : int
Factor by which to aggregate the exo_source data to the spatial
resolution of the file_paths input enhanced by s_enhance.
t_agg_factor : int
Factor by which to aggregate the exo_source data to the temporal
resolution of the file_paths input enhanced by t_enhance.
Expand All @@ -210,7 +206,7 @@ def get_cache_file(self, feature, s_enhance, t_enhance, s_agg_factor,
or self.temporal_slice.stop is None
else self.temporal_slice.stop - self.temporal_slice.start)
fn = f'exo_{feature}_{self.target}_{self.shape},{tsteps}'
fn += f'_sagg{s_agg_factor}_tagg{t_agg_factor}_{s_enhance}x_'
fn += f'_tagg{t_agg_factor}_{s_enhance}x_'
fn += f'{t_enhance}x.pkl'
fn = fn.replace('(', '').replace(')', '')
fn = fn.replace('[', '').replace(']', '')
Expand All @@ -233,14 +229,10 @@ def source_temporal_slice(self):

@property
def source_lat_lon(self):
"""Get the 2D array (n, 2) of lat, lon data for the exo source"""
if self._source_lat_lon is None:
src_enhance = int(np.sqrt(self._s_agg_factor))
src_shape = (self.hr_shape[0] * src_enhance,
self.hr_shape[1] * src_enhance)
self._source_lat_lon = OutputHandler.get_lat_lon(
self.lr_lat_lon, src_shape).reshape((-1, 2))
return self._source_lat_lon
"""Get the 2D array (n, 2) of lat, lon data from the exo_source_h5"""
with Resource(self._exo_source) as res:
source_lat_lon = res.lat_lon
return source_lat_lon

@property
def lr_shape(self):
Expand All @@ -259,7 +251,7 @@ def hr_shape(self):
def lr_lat_lon(self):
"""Lat lon grid for data in format (spatial_1, spatial_2, 2) Lat/Lon
array with same ordering in last dimension. This corresponds to the raw
low-resolution meta data from the file_paths input.
meta data from the file_paths input.

Returns
-------
Expand All @@ -271,7 +263,7 @@ def lr_lat_lon(self):
def hr_lat_lon(self):
"""Lat lon grid for data in format (spatial_1, spatial_2, 2) Lat/Lon
array with same ordering in last dimension. This corresponds to the
enhanced high-res meta data from the file_paths input * s_enhance.
enhanced meta data from the file_paths input * s_enhance.

Returns
-------
Expand Down Expand Up @@ -308,22 +300,32 @@ def hr_time_index(self):
self._hr_time_index = self.input_handler.time_index
return self._hr_time_index

@property
def distance_upper_bound(self):
    """Maximum distance (float) to map high-resolution data from exo_source
    to the low-resolution file_paths input.

    If no value is cached in ``self._distance_upper_bound``, the bound is
    estimated once from ``self.source_lat_lon`` as the larger of the median
    absolute step between consecutive rows in each coordinate column, and
    the result is cached for subsequent calls.

    Returns
    -------
    float
        Distance upper bound, in the same units as ``source_lat_lon``.
    """
    if self._distance_upper_bound is None:
        # Take absolute steps: a distance bound must be non-negative, but
        # the signed median is zero or negative when the source
        # coordinates are stored in descending order, which would make the
        # nearest-neighbor query reject every point.
        steps = np.abs(np.diff(self.source_lat_lon, axis=0))
        self._distance_upper_bound = np.max(np.median(steps, axis=0))
    return self._distance_upper_bound

@property
def tree(self):
"""Get the KDTree built on the source lat lon data"""
"""Get the KDTree built on the target lat lon data from the file_paths
input with s_enhance"""
if self._tree is None:
self._tree = KDTree(self.source_lat_lon)
lat = self.hr_lat_lon[..., 0].flatten()
lon = self.hr_lat_lon[..., 1].flatten()
hr_meta = np.vstack((lat, lon)).T
self._tree = KDTree(hr_meta)
return self._tree

@property
def nn(self):
"""Get the nearest neighbor indices"""
ll2 = np.vstack(
(self.hr_lat_lon[:, :, 0].flatten(),
self.hr_lat_lon[:, :, 1].flatten())).T
_, nn = self.tree.query(ll2, k=self._s_agg_factor)
if len(nn.shape) == 1:
nn = np.expand_dims(nn, 1)
_, nn = self.tree.query(self.source_lat_lon, k=1,
distance_upper_bound=self.distance_upper_bound)
return nn

@property
Expand All @@ -335,7 +337,6 @@ def data(self):
cache_fp = self.get_cache_file(feature=self.__class__.__name__,
s_enhance=self._s_enhance,
t_enhance=self._t_enhance,
s_agg_factor=self._s_agg_factor,
t_agg_factor=self._t_agg_factor)
tmp_fp = cache_fp + f'.{generate_random_string(10)}.tmp'
if os.path.exists(cache_fp):
Expand All @@ -355,27 +356,18 @@ def data(self):

return data[..., np.newaxis]

@abstractmethod
def get_data(self):
"""Get a raster of source values corresponding to the
high-resolution grid (the file_paths input grid * s_enhance *
t_enhance). The shape is (lats, lons, temporal, 1)
t_enhance). The shape is (lats, lons, temporal)
"""
nn = self.nn
hr_data = []
for j in range(self._s_agg_factor):
out = self.source_data[nn[:, j], self.source_temporal_slice]
out = out.reshape(self.hr_shape)
hr_data.append(out[..., np.newaxis])
hr_data = np.concatenate(hr_data, axis=-1).mean(axis=-1)
logger.info('Finished mapping raster from {}'.format(self._exo_source))
return hr_data

@classmethod
def get_exo_raster(cls,
file_paths,
s_enhance,
t_enhance,
s_agg_factor,
t_agg_factor,
exo_source=None,
target=None,
Expand Down Expand Up @@ -407,13 +399,6 @@ class will output a topography raster corresponding to the
example, if getting sza data, file_paths has hourly data, and
t_enhance is 4, this class will output a sza raster
corresponding to the file_paths temporally enhanced 4x to 15 min
s_agg_factor : int
Factor by which to aggregate the exo_source data to the resolution
of the file_paths input enhanced by s_enhance. For example, if
getting topography data, file_paths have 100km data, and s_enhance
is 4 resulting in a desired resolution of ~25km and topo_source_h5
has a resolution of 4km, the s_agg_factor should be 36 so that 6x6
4km cells are averaged to the ~25km enhanced grid.
t_agg_factor : int
Factor by which to aggregate the exo_source data to the resolution
of the file_paths input enhanced by t_enhance. For example, if
Expand Down Expand Up @@ -467,7 +452,6 @@ class will output a topography raster corresponding to the
exo = cls(file_paths,
s_enhance,
t_enhance,
s_agg_factor,
t_agg_factor,
exo_source=exo_source,
target=target,
Expand All @@ -491,13 +475,6 @@ def source_data(self):
elev = res.get_meta_arr('elevation')
return elev[:, np.newaxis]

@property
def source_lat_lon(self):
"""Get the 2D array (n, 2) of lat, lon data from the exo_source_h5"""
with Resource(self._exo_source) as res:
source_lat_lon = res.lat_lon
return source_lat_lon

@property
def source_time_index(self):
"""Time index of the source exo data"""
Expand All @@ -509,20 +486,42 @@ def source_time_index(self):
def get_data(self):
"""Get a raster of source values corresponding to the
high-resolution grid (the file_paths input grid * s_enhance *
t_enhance). The shape is (lats, lons, temporal, 1)
t_enhance). The shape is (lats, lons, 1)
"""
nn = self.nn
hr_data = []
for j in range(self._s_agg_factor):
out = self.source_data[nn[:, j]]
out = out.reshape((*self.hr_shape[:-1], -1))
hr_data.append(out[..., np.newaxis])
hr_data = np.concatenate(hr_data, axis=-1).mean(axis=-1)

assert len(self.source_data.shape) == 2
assert self.source_data.shape[1] == 1

df = pd.DataFrame({'topo': self.source_data.flatten(),
'gid_target': self.nn})
n_target = np.product(self.hr_shape[:-1])
df = df[df['gid_target'] != n_target]
df = df.sort_values('gid_target')
df = df.groupby('gid_target').mean()

missing = set(np.arange(n_target)) - set(df.index)
if any(missing):
msg = (f'{len(missing)} target pixels did not have unique '
'high-resolution source data to map from. If there are a '
'lot of target pixels missing source data this probably '
'means the source data is not high enough resolution. '
'Filling raster with NN.')
logger.warning(msg)
warn(msg)
temp_df = pd.DataFrame({'topo': np.nan}, index=sorted(missing))
df = pd.concat((df, temp_df)).sort_index()

hr_data = df['topo'].values.reshape(self.hr_shape[:-1])
if np.isnan(hr_data).any():
hr_data = nn_fill_array(hr_data)

hr_data = np.expand_dims(hr_data, axis=-1)

logger.info('Finished mapping raster from {}'.format(self._exo_source))

return hr_data

def get_cache_file(self, feature, s_enhance, t_enhance, s_agg_factor,
t_agg_factor):
def get_cache_file(self, feature, s_enhance, t_enhance, t_agg_factor):
"""Get cache file name. This uses a time independent naming convention.

Parameters
Expand All @@ -535,9 +534,6 @@ def get_cache_file(self, feature, s_enhance, t_enhance, s_agg_factor,
t_enhance : int
Temporal enhancement for this exogeneous data step (cumulative for
all model steps up to the current step).
s_agg_factor : int
Factor by which to aggregate the exo_source data to the spatial
resolution of the file_paths input enhanced by s_enhance.
t_agg_factor : int
Factor by which to aggregate the exo_source data to the temporal
resolution of the file_paths input enhanced by t_enhance.
Expand All @@ -548,7 +544,7 @@ def get_cache_file(self, feature, s_enhance, t_enhance, s_agg_factor,
Name of cache file
"""
fn = f'exo_{feature}_{self.target}_{self.shape}'
fn += f'_sagg{s_agg_factor}_tagg{t_agg_factor}_{s_enhance}x_'
fn += f'_tagg{t_agg_factor}_{s_enhance}x_'
fn += f'{t_enhance}x.pkl'
fn = fn.replace('(', '').replace(')', '')
fn = fn.replace('[', '').replace(']', '')
Expand Down Expand Up @@ -605,7 +601,7 @@ def source_data(self):
def get_data(self):
"""Get a raster of source values corresponding to the
high-resolution grid (the file_paths input grid * s_enhance *
t_enhance). The shape is (lats, lons, temporal, 1)
t_enhance). The shape is (lats, lons, temporal)
"""
hr_data = self.source_data.reshape(self.hr_shape)
logger.info('Finished computing SZA data')
Expand Down
5 changes: 2 additions & 3 deletions tests/bias/test_bias_correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,9 +512,8 @@ def test_nc_base_file():
assert (calc.nn_dist == 0).all()

with pytest.raises(RuntimeError) as exc:
base_data, _ = calc.get_base_data(calc.base_fps, calc.base_dset,
base_gid, calc.base_handler,
daily_reduction='avg')
calc.get_base_data(calc.base_fps, calc.base_dset, base_gid,
calc.base_handler, daily_reduction='avg')

good_err = 'only to be used with `base_handler` as a `sup3r.DataHandler` '
assert good_err in str(exc.value)
Expand Down
Loading