Skip to content

Commit

Permalink
added kwarg to not regrid lr data in dual dh if you know the coords m…
Browse files Browse the repository at this point in the history
…atch
  • Loading branch information
grantbuster committed Apr 4, 2024
1 parent a446b80 commit 685acd3
Showing 1 changed file with 18 additions and 9 deletions.
27 changes: 18 additions & 9 deletions sup3r/preprocessing/data_handling/dual_data_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(self,
regrid_workers=1,
load_cached=True,
shuffle_time=False,
regrid_lr=True,
s_enhance=1,
t_enhance=1,
val_split=0.0):
Expand All @@ -61,6 +62,10 @@ def __init__(self,
is called.
shuffle_time : bool
Whether to shuffle time indices prior to training/validation split
regrid_lr : bool
Flag to regrid the low-res handler data to the high-res handler
grid. This will take care of any minor inconsistencies in different
projections. Disable this if the grids are known to be the same.
s_enhance : int
Spatial enhancement factor
t_enhance : int
Expand Down Expand Up @@ -95,6 +100,7 @@ def __init__(self,
self._means = None
self._stds = None
self._is_normalized = False
self._regrid_lr = regrid_lr
self._norm_workers = self.lr_dh.norm_workers

if self.try_load and self.load_cached:
Expand Down Expand Up @@ -579,16 +585,19 @@ def get_lr_regridded_data(self):
"""Regrid low_res data for all requested noncached features. Load
cached features if available and overwrite=False"""

logger.info('Regridding low resolution feature data.')
regridder = self.get_regridder()
if self._regrid_lr:
logger.info('Regridding low resolution feature data.')
regridder = self.get_regridder()

fnames = set(self.noncached_features)
fnames = fnames.intersection(set(self.lr_dh.features))
for fname in fnames:
fidx = self.lr_dh.features.index(fname)
tmp = regridder(self.lr_input_data[..., fidx])
tmp = tmp.reshape(self.lr_required_shape)
self.lr_data[..., fidx] = tmp
fnames = set(self.noncached_features)
fnames = fnames.intersection(set(self.lr_dh.features))
for fname in fnames:
fidx = self.lr_dh.features.index(fname)
tmp = regridder(self.lr_input_data[..., fidx])
tmp = tmp.reshape(self.lr_required_shape)
self.lr_data[..., fidx] = tmp
else:
self.lr_data = self.lr_input_data

if self.load_cached:
fnames = set(self.cached_features)
Expand Down

0 comments on commit 685acd3

Please sign in to comment.