Skip to content

Commit

Permalink
moved lr/hr normalization to separate methods in dual data handling
Browse files Browse the repository at this point in the history
  • Loading branch information
grantbuster committed Apr 8, 2024
1 parent 6bd1b24 commit dd85e63
Showing 1 changed file with 48 additions and 6 deletions.
54 changes: 48 additions & 6 deletions sup3r/preprocessing/data_handling/dual_data_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,32 +254,74 @@ def normalize(self, means=None, stds=None, max_workers=None):
if stds is None:
stds = self.stds

self._normalize_lr(means, stds)
self._normalize_hr(means, stds)

def _normalize_lr(self, means, stds):
    """Normalize the low-resolution data features including in the
    low-res data handler

    Note that self.lr_data is usually a unique regridded array but if
    regridding was not performed then it is just a sliced *view* of
    self.lr_dh.data and the super().normalize() operation will have applied
    to that data already.

    Parameters
    ----------
    means : dict | None
        Dictionary of means for all features with keys: feature names and
        values: mean values. If this is None, the self.means attribute will
        be used. If this is not None, this DataHandler object means
        attribute will be updated.
    stds : dict | None
        Dictionary of standard deviation values for all features with keys:
        feature names and values: standard deviations. If this is None, the
        self.stds attribute will be used. If this is not None, this
        DataHandler object stds attribute will be updated.
    """

    logger.info('Normalizing low resolution data features='
                f'{self.lr_dh.features}')
    super().normalize(means=means, stds=stds,
                      features=self.lr_dh.features,
                      max_workers=self.lr_dh.norm_workers)

    # Idiomatic identity check (`is not`) instead of comparing id() values;
    # semantics are unchanged, including when lr_data.base is None (i.e.
    # lr_data owns its memory because it was regridded).
    if self.lr_dh.data is not self.lr_data.base:
        # self.lr_data is a unique regridded array, so the low-res
        # handler's own data still needs to be normalized separately.
        self.lr_dh.normalize(means=means, stds=stds,
                             features=self.lr_dh.features,
                             max_workers=self.lr_dh.norm_workers)
    else:
        # self.lr_data is just a sliced view of self.lr_dh.data, so the
        # super().normalize() call above already normalized the handler's
        # data in place; only flag it as done.
        self.lr_dh._is_normalized = True

def _normalize_hr(self, means, stds):
    """Normalize the high-resolution data features including in the
    high-res data handler

    Note that self.hr_data is usually just a sliced *view* of
    self.hr_dh.data but if the *view* is broken then it will have to be
    normalized too

    Parameters
    ----------
    means : dict | None
        Dictionary of means for all features with keys: feature names and
        values: mean values. If this is None, the self.means attribute will
        be used. If this is not None, this DataHandler object means
        attribute will be updated.
    stds : dict | None
        Dictionary of standard deviation values for all features with keys:
        feature names and values: standard deviations. If this is None, the
        self.stds attribute will be used. If this is not None, this
        DataHandler object stds attribute will be updated.
    """

    logger.info('Normalizing high resolution data features='
                f'{self.hr_dh.features}')
    self.hr_dh.normalize(means=means, stds=stds,
                         features=self.hr_dh.features,
                         max_workers=self.hr_dh.norm_workers)

    if id(self.hr_data.base) != id(self.hr_dh.data):
        # self.hr_data is usually just a sliced view of self.hr_dh.data
        # (so the normalize call above already covered it), but if the
        # view is broken then it has to be normalized manually here.
        # NOTE(review): assumes means/stds contain an entry for every
        # feature in self.hr_dh.features — a missing key raises KeyError.
        mean_arr = np.array([means[fn] for fn in self.hr_dh.features])
        std_arr = np.array([stds[fn] for fn in self.hr_dh.features])
        # Broadcasts over the trailing feature axis of hr_data —
        # presumably features are the last dimension; TODO confirm.
        self.hr_data = (self.hr_data - mean_arr) / std_arr
Expand Down

0 comments on commit dd85e63

Please sign in to comment.