From 00b52cf7e4f6c627c10ca80865df30d575477437 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Thu, 25 Apr 2024 09:56:28 -0600 Subject: [PATCH 01/16] removing sole netCDF4 dependency. Can use xarray for everything --- pyproject.toml | 1 - sup3r/utilities/interpolate_log_profile.py | 62 ++++------------------ 2 files changed, 9 insertions(+), 54 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f35dfa25f..792a6d3d1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,6 @@ dependencies = [ "google-auth-oauthlib==0.5.3", "matplotlib>=3.1", "numpy>=1.7.0", - "netCDF4==1.5.8", "pandas>=2.0", "pillow>=10.0", "pytest>=5.2", diff --git a/sup3r/utilities/interpolate_log_profile.py b/sup3r/utilities/interpolate_log_profile.py index 5b854f1df..1660f94f3 100644 --- a/sup3r/utilities/interpolate_log_profile.py +++ b/sup3r/utilities/interpolate_log_profile.py @@ -13,7 +13,6 @@ import numpy as np import xarray as xr -from netCDF4 import Dataset from rex import init_logger from scipy.interpolate import interp1d from scipy.optimize import curve_fit @@ -177,61 +176,18 @@ def save_output(self): """Save interpolated data to outfile""" dirname = os.path.dirname(self.outfile) os.makedirs(dirname, exist_ok=True) - os.system(f'cp {self.infile} {self.outfile}') - ds = Dataset(self.outfile, 'a') logger.info(f'Creating {self.outfile}.') - for var, data in self.new_data.items(): - for i, height in enumerate(self.new_heights[var]): - name = f'{var}_{height}m' - logger.info(f'Adding {name} to {self.outfile}.') - if name not in ds.variables: - _ = ds.createVariable( - name, - np.float32, - dimensions=('time', 'latitude', 'longitude'), - ) - ds.variables[name][:] = data[i, ...] - ds.variables[name].long_name = f'{height} meter {var}' - - units = None - if 'u_' in var or 'v_' in var: - units = 'm s**-1' - if 'pressure' in var: - units = 'Pa' - if 'temperature' in var: - units = 'C' - - if units is not None: - ds.variables[name].units = units - - ds.close() + with xr.open_dataset(self.infile) as ds: + for var, data in self.new_data.items(): + for i, height in enumerate(self.new_heights[var]): + name = f'{var}_{height}m' + logger.info(f'Adding {name} to {self.outfile}.') + if name not in ds.data_vars: + ds[name] = (('time', 'latitude', 'longitude'), data) + + ds.to_netcdf(self.outfile) logger.info(f'Saved interpolated output to {self.outfile}.') - @classmethod - def init_dims(cls, old_ds, new_ds, dims): - """Initialize dimensions in new dataset from old dataset - - Parameters - ---------- - old_ds : Dataset - Dataset() object from old file - new_ds : Dataset - Dataset() object for new file - dims : tuple - Tuple of dimensions. e.g. ('time', 'latitude', 'longitude') - - Returns - ------- - new_ds : Dataset - Dataset() object for new file with dimensions initialized. - """ - for var in dims: - new_ds.createDimension(var, len(old_ds[var])) - _ = new_ds.createVariable(var, old_ds[var].dtype, dimensions=var) - new_ds[var][:] = old_ds[var][:] - new_ds[var].units = old_ds[var].units - return new_ds - @classmethod def get_tmp_file(cls, file): """Get temp file for given file. 
Then only needed variables will be From a4a59922ce4fef71837d550463b612833b70f5e7 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Thu, 4 Apr 2024 12:17:40 -0600 Subject: [PATCH 02/16] netcdf4 dep removal from era downloader --- sup3r/utilities/era_downloader.py | 206 ++++++++++++------------------ 1 file changed, 82 insertions(+), 124 deletions(-) diff --git a/sup3r/utilities/era_downloader.py b/sup3r/utilities/era_downloader.py index 107fe4fb3..e6dbaed03 100644 --- a/sup3r/utilities/era_downloader.py +++ b/sup3r/utilities/era_downloader.py @@ -20,7 +20,6 @@ import numpy as np import pandas as pd import xarray as xr -from netCDF4 import Dataset from sup3r.utilities.interpolate_log_profile import LogLinInterpolator @@ -199,31 +198,6 @@ def level_file(self): basename += f'{str(self.month).zfill(2)}.nc' return os.path.join(basedir, basename) - @classmethod - def init_dims(cls, old_ds, new_ds, dims): - """Initialize dimensions in new dataset from old dataset - - Parameters - ---------- - old_ds : Dataset - Dataset() object from old file - new_ds : Dataset - Dataset() object for new file - dims : tuple - Tuple of dimensions. e.g. ('time', 'latitude', 'longitude') - - Returns - ------- - new_ds : Dataset - Dataset() object for new file with dimensions initialized. - """ - for var in dims: - new_ds.createDimension(var, len(old_ds[var])) - _ = new_ds.createVariable(var, old_ds[var].dtype, dimensions=var) - new_ds[var][:] = old_ds[var][:] - new_ds[var].units = old_ds[var].units - return new_ds - @classmethod def get_tmp_file(cls, file): """Get temp file for given file. Then only needed variables will be @@ -355,100 +329,107 @@ def download_file(cls, variables, time_dict, area, out_file, level_type, def process_surface_file(self): """Rename variables and convert geopotential to geopotential height.""" - dims = ('time', 'latitude', 'longitude') tmp_file = self.get_tmp_file(self.surface_file) - with Dataset(self.surface_file, "r") as old_ds: - with Dataset(tmp_file, "w") as ds: - ds = self.init_dims(old_ds, ds, dims) - - ds = self.convert_z('orog', 'Orography', old_ds, ds) - - ds = self.map_vars(old_ds, ds) + with xr.open_dataset(self.surface_file) as ds: + ds = self.convert_z(ds, name='orog') + ds = self.map_vars(ds) + ds.to_netcdf(tmp_file) os.system(f'mv {tmp_file} {self.surface_file}') logger.info(f'Finished processing {self.surface_file}. Moved ' f'{tmp_file} to {self.surface_file}.') - def map_vars(self, old_ds, ds): + def map_vars(self, ds): """Map variables from old dataset to new dataset Parameters ---------- - old_ds : Dataset - Dataset() object from old file ds : Dataset - Dataset() object for new file + xr.Dataset() object for which to rename variables Returns ------- ds : Dataset - Dataset() object for new file with new variables written. + xr.Dataset() object with new variables written. 
""" - for old_name in old_ds.variables: + for old_name in ds.data_vars: new_name = self.NAME_MAP.get(old_name, old_name) - if new_name not in ds.variables: - _ = ds.createVariable(new_name, - np.float32, - dimensions=old_ds[old_name].dimensions, - ) - vals = old_ds.variables[old_name][:] - if 'temperature' in new_name: - vals -= 273.15 - ds.variables[new_name][:] = vals + ds.rename({old_name: new_name}) + if 'temperature' in new_name: + ds[new_name] = (ds[new_name].dims, + ds[new_name].values - 273.15) return ds - def convert_z(self, standard_name, long_name, old_ds, ds): - """Convert z to given height variable + def shift_temp(self, ds): + """Shift temperature to celsius Parameters ---------- - standard_name : str - New variable name. e.g. 'zg' or 'orog' - long_name : str - Long name for new variable. e.g. 'Geopotential Height' or - 'Orography' - old_ds : Dataset - Dataset() object from tmp file ds : Dataset - Dataset() object for new file + xr.Dataset() object for which to shift temperature Returns ------- ds : Dataset - Dataset() object for new file with new height variable written. """ - _ = ds.createVariable(standard_name, - np.float32, - dimensions=old_ds['z'].dimensions) - ds.variables[standard_name][:] = old_ds['z'][:] / 9.81 - ds.variables[standard_name].long_name = long_name - ds.variables[standard_name].standard_name = 'zg' - ds.variables[standard_name].units = 'm' + for var in ds.data_vars: + if 'temperature' in var: + ds[var] = (ds[var].dims, ds[var].values - 273.15) return ds - def process_level_file(self): - """Convert geopotential to geopotential height.""" - dims = ('time', 'level', 'latitude', 'longitude') - tmp_file = self.get_tmp_file(self.level_file) - with Dataset(self.level_file, "r") as old_ds: - with Dataset(tmp_file, "w") as ds: - ds = self.init_dims(old_ds, ds, dims) + def add_pressure(self, ds): + """Add pressure to dataset - ds = self.convert_z('zg', 'Geopotential Height', old_ds, ds) + Parameters + ---------- + ds : Dataset + xr.Dataset() object for which to add pressure - ds = self.map_vars(old_ds, ds) + Returns + ------- + ds : Dataset + """ + if ('pressure' in self.variables + and 'pressure' not in ds.data_vars): + tmp = np.zeros(ds['zg'].shape) + + if 'number' in ds.dimensions: + tmp[:] = 100 * ds['level'].values[ + None, None, :, None, None] + else: + tmp[:] = 100 * ds['level'].values[ + None, :, None, None] + + ds['pressure'] = (ds['zg'].dims, tmp) + return ds + + def convert_z(self, ds, name): + """Convert z to given height variable + + Parameters + ---------- + ds : Dataset + xr.Dataset() object for new file + name : str + Variable name. e.g. zg or orog, typically - if ('pressure' in self.variables - and 'pressure' not in ds.variables): - tmp = np.zeros(ds.variables['zg'].shape) - for i in range(tmp.shape[1]): - tmp[:, i, :, :] = ds.variables['level'][i] * 100 + Returns + ------- + ds : Dataset + xr.Dataset() object for new file with new height variable written. + """ + ds['z'] = (ds['z'].dims, ds['z'].values / 9.81) + ds.rename({'z': name}) + return ds - _ = ds.createVariable('pressure', - np.float32, - dimensions=dims) - ds.variables['pressure'][:] = tmp[...] 
- ds.variables['pressure'].long_name = 'Pressure' - ds.variables['pressure'].units = 'Pa' + def process_level_file(self): + """Convert geopotential to geopotential height.""" + tmp_file = self.get_tmp_file(self.level_file) + with xr.open_dataset(self.level_file) as ds: + ds = self.convert_z(ds, name='zg') + ds = self.map_vars(ds) + ds = self.shift_temp(ds) + ds = self.add_pressure(ds) + ds.to_netcdf(tmp_file) os.system(f'mv {tmp_file} {self.level_file}') logger.info(f'Finished processing {self.level_file}. Moved ' @@ -629,17 +610,10 @@ def already_pruned(cls, infile, prune_variables): if prune_variables is None: logger.info('Received prune_variables=None. Skipping pruning.') return - else: - logger.info(f'Received prune_variables={prune_variables}.') - - pruned = True - with Dataset(infile, 'r') as ds: - variables = [var for var in ds.variables - if var not in ('time', 'latitude', 'longitude')] - for var in variables: - if not any(name in var for name in prune_variables): - logger.info(f'Pruning {var} in {infile}.') - pruned = False + with xr.open_dataset(infile) as ds: + check_variables = [var for var in ds.data_vars + if 'level' in ds[var].dims] + pruned = len(check_variables) == 0 return pruned @classmethod @@ -649,32 +623,16 @@ def prune_output(cls, infile, prune_variables=None): logger.info('Received prune_variables=None. Skipping pruning.') return else: - logger.info(f'Received prune_variables={prune_variables}.') - - logger.info(f'Pruning {infile}.') - tmp_file = cls.get_tmp_file(infile) - with Dataset(infile, 'r') as old_ds: - keep_vars = [var for var in old_ds.variables - if var not in prune_variables and var not - in ('time', 'latitude', 'longitude', 'level')] - with Dataset(tmp_file, 'w') as new_ds: - new_ds = cls.init_dims(old_ds, new_ds, - ('time', 'latitude', 'longitude')) - for var in keep_vars: - old_var = old_ds[var] - vals = old_var[:] - logger.info(f'Creating variable {var}.') - _ = new_ds.createVariable( - var, old_var.dtype, dimensions=old_var.dimensions) - new_ds[var][:] = vals - if hasattr(old_var, 'units'): - new_ds[var].units = old_var.units - if hasattr(old_var, 'standard_name'): - standard_name = old_var.standard_name - new_ds[var].standard_name = standard_name - if hasattr(old_var, 'long_name'): - new_ds[var].long_name = old_var.long_name - os.system(f'mv {tmp_file} {infile}') + logger.info(f'Pruning {infile}.') + tmp_file = cls.get_tmp_file(infile) + with xr.Dataset(infile) as ds: + keep_vars = {k:v for k, v in dict(ds.data_vars) + if 'level' not in ds[k].dims} + new_coords = {k:v for k, v in dict(ds.coords).items() + if 'level' not in k} + new_ds = xr.Dataset(coords=new_coords, data_vars=keep_vars) + new_ds.to_netcdf(tmp_file) + os.system(f'mv {tmp_file} {infile}') logger.info(f'Finished pruning variables in {infile}. 
Moved ' f'{tmp_file} to {infile}.') From 519f577bf00e99c5f8778df31973300cb5f841c9 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Thu, 25 Apr 2024 10:05:56 -0600 Subject: [PATCH 03/16] unused var i --- sup3r/utilities/interpolate_log_profile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sup3r/utilities/interpolate_log_profile.py b/sup3r/utilities/interpolate_log_profile.py index 1660f94f3..cef200de7 100644 --- a/sup3r/utilities/interpolate_log_profile.py +++ b/sup3r/utilities/interpolate_log_profile.py @@ -179,7 +179,7 @@ def save_output(self): logger.info(f'Creating {self.outfile}.') with xr.open_dataset(self.infile) as ds: for var, data in self.new_data.items(): - for i, height in enumerate(self.new_heights[var]): + for height in self.new_heights[var]: name = f'{var}_{height}m' logger.info(f'Adding {name} to {self.outfile}.') if name not in ds.data_vars: From ff057d6d071dca0745c34fc225aa1d3673831c4b Mon Sep 17 00:00:00 2001 From: bnb32 Date: Thu, 25 Apr 2024 10:19:20 -0600 Subject: [PATCH 04/16] removed netcdf4 from version record --- sup3r/utilities/__init__.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/sup3r/utilities/__init__.py b/sup3r/utilities/__init__.py index 5ed90b511..306b68550 100644 --- a/sup3r/utilities/__init__.py +++ b/sup3r/utilities/__init__.py @@ -1,20 +1,18 @@ """Sup3r utilities""" import sys -import pandas as pd -import numpy as np -import tensorflow as tf -import sklearn -import dask -import xarray -import netCDF4 from enum import Enum +import dask +import numpy as np +import pandas as pd import phygnn import rex +import sklearn +import tensorflow as tf +import xarray from sup3r import __version__ - VERSION_RECORD = {'sup3r': __version__, 'tensorflow': tf.__version__, 'sklearn': sklearn.__version__, @@ -24,7 +22,6 @@ 'nrel-rex': rex.__version__, 'python': sys.version, 'xarray': xarray.__version__, - 'netCDF4': netCDF4.__version__, 'dask': dask.__version__, } @@ -56,6 +53,7 @@ def __format__(self, format_spec): @classmethod def all_names(cls): """All module names. + Returns ------- set
From 77a686570711fa5daf44f0193e50fa6b621e5e96 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Thu, 25 Apr 2024 10:24:57 -0600 Subject: [PATCH 06/16] uh. still need netcdf for xarray to read netcdf files.
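xarray only wraps a netCDF backend, so dropping the netCDF4 package means h5netcdf must be installed for xarray to keep reading and writing netCDF4-format files, with cftime handling non-standard CF calendars. A minimal check of that assumption (file names here are hypothetical; engine and use_cftime are standard xarray arguments):

    import xarray as xr

    # h5netcdf stands in for the removed netCDF4 backend
    ds = xr.open_dataset('sample.nc', engine='h5netcdf', use_cftime=True)
    ds.to_netcdf('copy.nc', engine='h5netcdf')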
--- pyproject.toml | 2 ++ sup3r/models/base.py | 6 ++++-- sup3r/utilities/__init__.py | 2 ++ tests/bias/test_bias_correction.py | 10 +++++++--- 4 files changed, 15 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 792a6d3d1..5a82b19d2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,8 @@ dependencies = [ "NREL-farms>=1.0.4", "dask>=2022.0", "google-auth-oauthlib==0.5.3", + "h5netcdf", + "cftime", "matplotlib>=3.1", "numpy>=1.7.0", "pandas>=2.0", diff --git a/sup3r/models/base.py b/sup3r/models/base.py index 2988e3314..b7ee1ce26 100644 --- a/sup3r/models/base.py +++ b/sup3r/models/base.py @@ -688,7 +688,8 @@ def train_epoch(self, if only_gen or (train_gen and not gen_too_good): trained_gen = True - b_loss_details = self.run_gradient_descent( + b_loss_details = self.timer( + self.run_gradient_descent, batch.low_res, batch.high_res, self.generator_weights, @@ -700,7 +701,8 @@ def train_epoch(self, if only_disc or (train_disc and not disc_too_good): trained_disc = True - b_loss_details = self.run_gradient_descent( + b_loss_details = self.timer( + self.run_gradient_descent, batch.low_res, batch.high_res, self.discriminator_weights, diff --git a/sup3r/utilities/__init__.py b/sup3r/utilities/__init__.py index 306b68550..bf12ba719 100644 --- a/sup3r/utilities/__init__.py +++ b/sup3r/utilities/__init__.py @@ -3,6 +3,7 @@ from enum import Enum import dask +import h5netcdf import numpy as np import pandas as pd import phygnn @@ -22,6 +23,7 @@ 'nrel-rex': rex.__version__, 'python': sys.version, 'xarray': xarray.__version__, + 'h5netcdf': h5netcdf.__version__, 'dask': dask.__version__, } diff --git a/tests/bias/test_bias_correction.py b/tests/bias/test_bias_correction.py index 76eae6642..97b9e4cb4 100644 --- a/tests/bias/test_bias_correction.py +++ b/tests/bias/test_bias_correction.py @@ -11,8 +11,11 @@ from scipy import stats from sup3r import CONFIG_DIR, TEST_DATA_DIR -from sup3r.bias.bias_calc import (LinearCorrection, MonthlyLinearCorrection, - SkillAssessment) +from sup3r.bias.bias_calc import ( + LinearCorrection, + MonthlyLinearCorrection, + SkillAssessment, +) from sup3r.bias.bias_transforms import local_linear_bc, monthly_local_linear_bc from sup3r.models import Sup3rGan from sup3r.pipeline.forward_pass import ForwardPass, ForwardPassStrategy @@ -327,7 +330,8 @@ def test_fwp_integration(): os.path.join(TEST_DATA_DIR, 'zg_test.nc')] lat_lon = DataHandlerNCforCC(input_files, features=[], target=target, - shape=shape).lat_lon + shape=shape, + worker_kwargs={'max_workers': 1}).lat_lon Sup3rGan.seed() model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) From 48f5b70fa7133836b7d5e0b7701fd46e85507cb7 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Fri, 26 Apr 2024 07:22:07 -0600 Subject: [PATCH 07/16] log interp data output index fix --- sup3r/utilities/interpolate_log_profile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sup3r/utilities/interpolate_log_profile.py b/sup3r/utilities/interpolate_log_profile.py index cef200de7..92980b514 100644 --- a/sup3r/utilities/interpolate_log_profile.py +++ b/sup3r/utilities/interpolate_log_profile.py @@ -183,7 +183,7 @@ def save_output(self): name = f'{var}_{height}m' logger.info(f'Adding {name} to {self.outfile}.') if name not in ds.data_vars: - ds[name] = (('time', 'latitude', 'longitude'), data) + ds[name] = (('time', 'latitude', 'longitude'), data[0]) ds.to_netcdf(self.outfile) logger.info(f'Saved interpolated output to {self.outfile}.') From 0d862a63f3857cfb08c9c087f70eaa6813c4dae4 Mon Sep 17 
00:00:00 2001 From: bnb32 Date: Fri, 26 Apr 2024 07:49:11 -0600 Subject: [PATCH 08/16] pypi workflow update - need fetch-depth 0 to get correct version --- .github/workflows/publish_to_pypi.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/publish_to_pypi.yml b/.github/workflows/publish_to_pypi.yml index d0b445ac2..448d966ea 100644 --- a/.github/workflows/publish_to_pypi.yml +++ b/.github/workflows/publish_to_pypi.yml @@ -13,6 +13,10 @@ jobs: id-token: write steps: - uses: actions/checkout@v4 + with: + ref: ${{ github.event.pull_request.head.ref }} + fetch-depth: 0 + fetch-tags: true - name: Set up Python uses: actions/setup-python@v4 with: From 1a44fd861cce717e79bae8f155d87dd8de4a1ae1 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Mon, 29 Apr 2024 18:05:17 -0600 Subject: [PATCH 09/16] pr updates --- pyproject.toml | 4 +-- sup3r/utilities/era_downloader.py | 53 +++++++++++++++---------------- 2 files changed, 27 insertions(+), 30 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5a82b19d2..707e96a33 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,8 +32,8 @@ dependencies = [ "NREL-farms>=1.0.4", "dask>=2022.0", "google-auth-oauthlib==0.5.3", - "h5netcdf", - "cftime", + "h5netcdf>=1.1.0", + "cftime>=1.6.2", "matplotlib>=3.1", "numpy>=1.7.0", "pandas>=2.0", diff --git a/sup3r/utilities/era_downloader.py b/sup3r/utilities/era_downloader.py index e6dbaed03..7c9d6824a 100644 --- a/sup3r/utilities/era_downloader.py +++ b/sup3r/utilities/era_downloader.py @@ -330,10 +330,10 @@ def download_file(cls, variables, time_dict, area, out_file, level_type, def process_surface_file(self): """Rename variables and convert geopotential to geopotential height.""" tmp_file = self.get_tmp_file(self.surface_file) - with xr.open_dataset(self.surface_file) as ds: - ds = self.convert_z(ds, name='orog') - ds = self.map_vars(ds) - ds.to_netcdf(tmp_file) + with xr.open_dataset(self.surface_file, mode='a') as ds: + new_ds = self.convert_z(ds, name='orog') + new_ds = self.map_vars(new_ds) + new_ds.to_netcdf(tmp_file) os.system(f'mv {tmp_file} {self.surface_file}') logger.info(f'Finished processing {self.surface_file}. Moved ' f'{tmp_file} to {self.surface_file}.') @@ -348,15 +348,12 @@ def map_vars(self, ds): Returns ------- - ds : Dataset + new_ds : Dataset xr.Dataset() object with new variables written. 
""" for old_name in ds.data_vars: new_name = self.NAME_MAP.get(old_name, old_name) - ds.rename({old_name: new_name}) - if 'temperature' in new_name: - ds[new_name] = (ds[new_name].dims, - ds[new_name].values - 273.15) + ds = ds.rename({old_name: new_name}) return ds def shift_temp(self, ds): @@ -372,8 +369,9 @@ def shift_temp(self, ds): ds : Dataset """ for var in ds.data_vars: - if 'temperature' in var: + if 'units' in ds[var].attrs and ds[var].attrs['units'] == 'K': ds[var] = (ds[var].dims, ds[var].values - 273.15) + ds[var].attrs['units'] = 'C' return ds def add_pressure(self, ds): @@ -390,16 +388,14 @@ def add_pressure(self, ds): """ if ('pressure' in self.variables and 'pressure' not in ds.data_vars): - tmp = np.zeros(ds['zg'].shape) - - if 'number' in ds.dimensions: - tmp[:] = 100 * ds['level'].values[ - None, None, :, None, None] - else: - tmp[:] = 100 * ds['level'].values[ - None, :, None, None] - - ds['pressure'] = (ds['zg'].dims, tmp) + expand_axes = (0, 2, 3) + pres = np.zeros(ds['zg'].values.shape) + if 'number' in ds.dims: + expand_axes = (0, 1, 3, 4) + pres[:] = np.expand_dims(100 * ds['level'].values, + axis=expand_axes) + ds['pressure'] = (ds['zg'].dims, pres) + ds['pressure'].attrs['units'] = 'Pa' return ds def convert_z(self, ds, name): @@ -417,19 +413,20 @@ def convert_z(self, ds, name): ds : Dataset xr.Dataset() object for new file with new height variable written. """ - ds['z'] = (ds['z'].dims, ds['z'].values / 9.81) - ds.rename({'z': name}) + if name not in ds.data_vars: + ds['z'] = (ds['z'].dims, ds['z'].values / 9.81) + ds = ds.rename({'z': name}) return ds def process_level_file(self): """Convert geopotential to geopotential height.""" tmp_file = self.get_tmp_file(self.level_file) - with xr.open_dataset(self.level_file) as ds: - ds = self.convert_z(ds, name='zg') - ds = self.map_vars(ds) - ds = self.shift_temp(ds) - ds = self.add_pressure(ds) - ds.to_netcdf(tmp_file) + with xr.open_dataset(self.level_file, mode='a') as ds: + new_ds = self.convert_z(ds, name='zg') + new_ds = self.map_vars(new_ds) + new_ds = self.shift_temp(new_ds) + new_ds = self.add_pressure(new_ds) + new_ds.to_netcdf(tmp_file) os.system(f'mv {tmp_file} {self.level_file}') logger.info(f'Finished processing {self.level_file}. Moved ' From bd972c67abb064ebeb8015dcea3531b70f807e04 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Mon, 29 Apr 2024 18:28:35 -0600 Subject: [PATCH 10/16] linting --- sup3r/utilities/era_downloader.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sup3r/utilities/era_downloader.py b/sup3r/utilities/era_downloader.py index 7c9d6824a..abe725655 100644 --- a/sup3r/utilities/era_downloader.py +++ b/sup3r/utilities/era_downloader.py @@ -623,9 +623,9 @@ def prune_output(cls, infile, prune_variables=None): logger.info(f'Pruning {infile}.') tmp_file = cls.get_tmp_file(infile) with xr.Dataset(infile) as ds: - keep_vars = {k:v for k, v in dict(ds.data_vars) + keep_vars = {k: v for k, v in dict(ds.data_vars) if 'level' not in ds[k].dims} - new_coords = {k:v for k, v in dict(ds.coords).items() + new_coords = {k: v for k, v in dict(ds.coords).items() if 'level' not in k} new_ds = xr.Dataset(coords=new_coords, data_vars=keep_vars) new_ds.to_netcdf(tmp_file) @@ -676,7 +676,7 @@ def run_month(cls, variables : list | None Variables to download. If None this defaults to just gepotential and wind components. - prune_variables : list | None + prune_variables : bool | None Variables to remove from final files. 
This is usually the multi pressure level array of a variable which has since been interpolated to specific heights. From 6ffc4facb67060bb567387863af54049aa1ef992 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Thu, 4 Apr 2024 06:15:38 -0600 Subject: [PATCH 11/16] some arg cleaning in era_downloader --- pyproject.toml | 2 +- requirements.txt | 14 ------ sup3r/utilities/era_downloader.py | 77 ++++++++++++++++++++++--------- 3 files changed, 56 insertions(+), 37 deletions(-) delete mode 100644 requirements.txt diff --git a/pyproject.toml b/pyproject.toml index 707e96a33..3db5e6812 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,7 @@ dependencies = [ "pytest>=5.2", "scipy>=1.0.0", "sphinx>=7.0", - "tensorflow>2.4,<2.16", + "tensorflow>2.4,<2.10", "xarray>=2023.0", ] diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 0a9d02624..000000000 --- a/requirements.txt +++ /dev/null @@ -1,14 +0,0 @@ -matplotlib>=3.1 -NREL-rex>=0.2.84 -NREL-phygnn>=0.0.23 -NREL-gaps>=0.6.0 -NREL-farms>=1.0.4 -google-auth-oauthlib==0.5.3 -pytest>=5.2 -pillow>=10.0 -tensorflow>2.4,<2.16 -xarray>=2023.0 -netCDF4==1.5.8 -dask>=2022.0 -sphinx>=7.0 -pandas>=2.0 diff --git a/sup3r/utilities/era_downloader.py b/sup3r/utilities/era_downloader.py index abe725655..455da426e 100644 --- a/sup3r/utilities/era_downloader.py +++ b/sup3r/utilities/era_downloader.py @@ -94,7 +94,8 @@ def __init__(self, run_interp=True, overwrite=False, variables=None, - check_files=False): + check_files=False, + product_type='reanalysis'): """Initialize the class. Parameters @@ -123,6 +124,9 @@ def __init__(self, and wind components. check_files : bool Check existing files. Remove and redownload if checks fail. + product_type : str + Can be 'reanalysis', 'ensemble_mean', 'ensemble_spread', + 'ensemble_members' """ self.year = year self.month = month @@ -141,12 +145,23 @@ def __init__(self, self.sfc_file_variables = ['geopotential'] self.level_file_variables = ['geopotential'] self.prep_var_lists(self.variables) + self.product_type = product_type + self.hours = self.get_hours() msg = ('Initialized EraDownloader with: ' f'year={self.year}, month={self.month}, area={self.area}, ' f'levels={self.levels}, variables={self.variables}') logger.info(msg) + def get_hours(self): + """ERA5 is hourly and EDA is 3-hourly. 
Check and warn for incompatible + requests.""" + if self.product_type == 'reanalysis': + hours = [str(n).zfill(2) + ":00" for n in range(0, 24)] + else: + hours = [str(n).zfill(2) + ":00" for n in range(0, 24, 3)] + return hours + @property def variables(self): """Get list of requested variables""" @@ -275,18 +290,20 @@ def download_process_combine(self): if sfc_check: self.download_file(self.sfc_file_variables, time_dict=time_dict, area=self.area, out_file=self.surface_file, - level_type='single', overwrite=self.overwrite) + level_type='single', overwrite=self.overwrite, + product_type=self.product_type) if level_check: self.download_file(self.level_file_variables, time_dict=time_dict, area=self.area, out_file=self.level_file, level_type='pressure', levels=self.levels, - overwrite=self.overwrite) + overwrite=self.overwrite, + product_type=self.product_type) if sfc_check or level_check: self.process_and_combine() @classmethod def download_file(cls, variables, time_dict, area, out_file, level_type, - levels=None, overwrite=False): + levels=None, product_type='reanalysis', overwrite=False): """Download either single-level or pressure-level file Parameters @@ -304,6 +321,9 @@ def download_file(cls, variables, time_dict, area, out_file, level_type, Either 'single' or 'pressure' levels : list List of pressure levels to download, if level_type == 'pressure' + product_type : str + Can be 'reanalysis', 'ensemble_mean', 'ensemble_spread', + 'ensemble_members' overwrite : bool Whether to overwrite existing file """ @@ -555,7 +575,7 @@ def run_interpolation(self, max_workers=None, **kwargs): overwrite=self.overwrite, **kwargs) - def get_monthly_file(self, interp_workers=None, prune_variables=None, + def get_monthly_file(self, interp_workers=None, prune_variables=False, **interp_kwargs): """Download level and surface files, process variables, and combine processed files. Includes checks for shape and variables and option to @@ -604,8 +624,8 @@ def all_months_exist(cls, year, file_pattern): @classmethod def already_pruned(cls, infile, prune_variables): """Check if file has been pruned already.""" - if prune_variables is None: - logger.info('Received prune_variables=None. Skipping pruning.') + if not prune_variables: + logger.info('Received prune_variables=False. Skipping pruning.') return with xr.open_dataset(infile) as ds: check_variables = [var for var in ds.data_vars @@ -614,15 +634,15 @@ def already_pruned(cls, infile, prune_variables): return pruned @classmethod - def prune_output(cls, infile, prune_variables=None): + def prune_output(cls, infile, prune_variables=False): """Prune output file to keep just single level variables""" - if prune_variables is None: - logger.info('Received prune_variables=None. Skipping pruning.') + if not prune_variables: + logger.info('Received prune_variables=False. Skipping pruning.') return else: logger.info(f'Pruning {infile}.') tmp_file = cls.get_tmp_file(infile) - with xr.Dataset(infile) as ds: + with xr.open_dataset(infile) as ds: keep_vars = {k: v for k, v in dict(ds.data_vars) if 'level' not in ds[k].dims} new_coords = {k: v for k, v in dict(ds.coords).items() @@ -645,8 +665,9 @@ def run_month(cls, overwrite=False, interp_workers=None, variables=None, - prune_variables=None, + prune_variables=False, check_files=False, + product_type='reanalysis', **interp_kwargs): """Run routine for all months in the requested year. @@ -676,13 +697,16 @@ def run_month(cls, variables : list | None Variables to download. 
If None this defaults to just gepotential and wind components. - prune_variables : bool | None - Variables to remove from final files. This is usually the multi - pressure level array of a variable which has since been - interpolated to specific heights. - pruned. + prune_variables : bool + Whether to remove 4D variables from data after interpolation. e.g. + height interpolation could give u_10m, u_100m, u_120m from a 4D u + array. If we only need these heights we could remove the 4D u array + from the final data file. check_files : bool Check existing files. Remove and redownload if checks fail. + product_type : str + Can be 'reanalysis', 'ensemble_mean', 'ensemble_spread', + 'ensemble_members' **interp_kwargs : dict Keyword args for LogLinInterpolator.run() """ @@ -695,7 +719,8 @@ def run_month(cls, run_interp=run_interp, overwrite=overwrite, variables=variables, - check_files=check_files) + check_files=check_files, + product_type=product_type) downloader.get_monthly_file(interp_workers=interp_workers, prune_variables=prune_variables, **interp_kwargs) @@ -714,8 +739,9 @@ def run_year(cls, max_workers=None, interp_workers=None, variables=None, - prune_variables=None, + prune_variables=False, check_files=False, + product_type='reanalysis', **interp_kwargs): """Run routine for all months in the requested year. @@ -750,11 +776,16 @@ def run_year(cls, variables : list | None Variables to download. If None this defaults to just gepotential and wind components. - prune_variables : list | None - Variables to keep in final files. All other variables will be - pruned. + prune_variables : bool + Whether to remove 4D variables from data after interpolation. e.g. + height interpolation could give u_10m, u_100m, u_120m from a 4D u + array. If we only need these heights we could remove the 4D u array + from the final data file. check_files : bool Check existing files. Remove and redownload if checks fail. 
+ product_type : str + Can be 'reanalysis', 'ensemble_mean', 'ensemble_spread', + 'ensemble_members' **interp_kwargs : dict Keyword args for LogLinInterpolator.run() """ @@ -772,6 +803,7 @@ def run_year(cls, variables=variables, prune_variables=prune_variables, check_files=check_files, + product_type=product_type, **interp_kwargs) else: futures = {} @@ -791,6 +823,7 @@ def run_year(cls, prune_variables=prune_variables, variables=variables, check_files=check_files, + product_type=product_type, **interp_kwargs) futures[future] = {'year': year, 'month': month} logger.info(f'Submitted future for year {year} and month ' From 7bb248a9ef020eca12cce227bff6c8fcc97dbb25 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Tue, 30 Apr 2024 04:39:04 -0600 Subject: [PATCH 12/16] wrong max tf version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 3db5e6812..707e96a33 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,7 @@ dependencies = [ "pytest>=5.2", "scipy>=1.0.0", "sphinx>=7.0", - "tensorflow>2.4,<2.10", + "tensorflow>2.4,<2.16", "xarray>=2023.0", ] From f430ff8eccf2a4b3c915412ff4d11ef1ffc9d543 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Wed, 1 May 2024 09:28:01 -0600 Subject: [PATCH 13/16] adding expand_paths call to base data handler --- sup3r/bias/bias_calc.py | 43 ++++------------------ sup3r/preprocessing/data_handling/mixin.py | 17 +++------ sup3r/utilities/era_downloader.py | 13 ++++--- sup3r/utilities/regridder.py | 1 + sup3r/utilities/utilities.py | 29 +++++++++++++++ 5 files changed, 50 insertions(+), 53 deletions(-) diff --git a/sup3r/bias/bias_calc.py b/sup3r/bias/bias_calc.py index 76f43fced..098c8005c 100644 --- a/sup3r/bias/bias_calc.py +++ b/sup3r/bias/bias_calc.py @@ -8,19 +8,17 @@ import os from abc import abstractmethod from concurrent.futures import ProcessPoolExecutor, as_completed -from glob import glob -from pathlib import Path import h5py import numpy as np import pandas as pd import rex -from rex.utilities.fun_utils import get_fun_call_str from rex.utilities.bc_utils import ( + sample_q_invlog, sample_q_linear, sample_q_log, - sample_q_invlog, ) +from rex.utilities.fun_utils import get_fun_call_str from scipy import stats from scipy.ndimage import gaussian_filter from scipy.spatial import KDTree @@ -29,39 +27,11 @@ from sup3r.preprocessing.data_handling.base import DataHandler from sup3r.utilities import VERSION_RECORD, ModuleName from sup3r.utilities.cli import BaseCLI -from sup3r.utilities.utilities import nn_fill_array +from sup3r.utilities.utilities import expand_paths, nn_fill_array logger = logging.getLogger(__name__) -def _expand_paths(fps): - """Expand path(s) - - Parameter - --------- - fps : str or pathlib.Path or any Sequence of those - One or multiple paths to file - - Returns - ------- - list[str] - A list of expanded unique and sorted paths as str - - Examples - -------- - >>> _expand_paths("myfile.h5") - - >>> _expand_paths(["myfile.h5", "*.hdf"]) - """ - if isinstance(fps, (str, Path)): - fps = (fps, ) - - out = [] - for f in fps: - out.extend(glob(f)) - return sorted(set(out)) - - class DataRetrievalBase: """Base class to handle data retrieval for the biased data and the baseline data @@ -163,8 +133,8 @@ class to be retrieved from the rex/sup3r library. 
If a self._distance_upper_bound = distance_upper_bound self.match_zero_rate = match_zero_rate - self.base_fps = _expand_paths(self.base_fps) - self.bias_fps = _expand_paths(self.bias_fps) + self.base_fps = expand_paths(self.base_fps) + self.bias_fps = expand_paths(self.bias_fps) base_sup3r_handler = getattr(sup3r.preprocessing.data_handling, base_handler, None) @@ -1224,6 +1194,7 @@ class QuantileDeltaMappingCorrection(DataRetrievalBase): :func:`~sup3r.bias.bias_transforms.local_qdm_bc` to actually correct a dataset. """ + def __init__(self, base_fps, bias_fps, @@ -1308,7 +1279,7 @@ def __init__(self, self.bias_fut_fps = bias_fut_fps - self.bias_fut_fps = _expand_paths(self.bias_fut_fps) + self.bias_fut_fps = expand_paths(self.bias_fut_fps) self.bias_fut_dh = self.bias_handler(self.bias_fut_fps, [self.bias_feature], diff --git a/sup3r/preprocessing/data_handling/mixin.py b/sup3r/preprocessing/data_handling/mixin.py index 5c5a69136..091897b33 100644 --- a/sup3r/preprocessing/data_handling/mixin.py +++ b/sup3r/preprocessing/data_handling/mixin.py @@ -2,7 +2,6 @@ @author: bbenton """ -import glob import logging import os import pickle @@ -18,6 +17,7 @@ from sup3r.utilities.utilities import ( estimate_max_workers, + expand_paths, get_source_type, ignore_case_path_fetch, uniform_box_sampler, @@ -644,22 +644,15 @@ def file_paths(self, file_paths): ---------- file_paths : str | list A list of files to extract raster data from. Each file must have - the same number of timesteps. Can also pass a string with a - unix-style file path which will be passed through glob.glob + the same number of timesteps. Can also pass a string or list of + strings with a unix-style file path which will be passed through + glob.glob """ - self._file_paths = file_paths - if isinstance(self._file_paths, str): - if '*' in file_paths: - self._file_paths = glob.glob(self._file_paths) - else: - self._file_paths = [self._file_paths] - + self._file_paths = expand_paths(file_paths) msg = ('No valid files provided to DataHandler. ' f'Received file_paths={file_paths}. 
Aborting.') assert file_paths is not None and len(self._file_paths) > 0, msg - self._file_paths = sorted(self._file_paths) - @property def ti_workers(self): """Get max number of workers for computing time index""" diff --git a/sup3r/utilities/era_downloader.py b/sup3r/utilities/era_downloader.py index 455da426e..8606ac769 100644 --- a/sup3r/utilities/era_downloader.py +++ b/sup3r/utilities/era_downloader.py @@ -731,7 +731,7 @@ def run_year(cls, area, levels, combined_out_pattern, - combined_yearly_file, + combined_yearly_file=None, interp_out_pattern=None, interp_yearly_file=None, run_interp=True, @@ -834,10 +834,13 @@ def run_year(cls, logger.info(f'Finished future for year {v["year"]} and month ' f'{v["month"]}.') - cls.make_yearly_file(year, combined_out_pattern, combined_yearly_file) + if combined_yearly_file is not None: + cls.make_yearly_file(year, combined_out_pattern, + combined_yearly_file) - if run_interp: - cls.make_yearly_file(year, interp_out_pattern, interp_yearly_file) + if run_interp and interp_yearly_file is not None: + cls.make_yearly_file(year, interp_out_pattern, + interp_yearly_file) @classmethod def make_yearly_file(cls, year, file_pattern, yearly_file): @@ -863,7 +866,7 @@ def make_yearly_file(cls, year, file_pattern, yearly_file): ] if not os.path.exists(yearly_file): - with xr.open_mfdataset(files) as res: + with xr.open_mfdataset(files, parallel=True) as res: logger.info(f'Combining {files}') os.makedirs(os.path.dirname(yearly_file), exist_ok=True) res.to_netcdf(yearly_file) diff --git a/sup3r/utilities/regridder.py b/sup3r/utilities/regridder.py index fa086a230..c50eede76 100644 --- a/sup3r/utilities/regridder.py +++ b/sup3r/utilities/regridder.py @@ -199,6 +199,7 @@ def get_all_queries(self, max_workers=None): else: logger.info('Querying all coordinates in parallel.') self._parallel_queries(max_workers=max_workers) + logger.info('Finished querying all coordinates.') def _serial_queries(self): """Get indices and distances for all points in target_meta, in diff --git a/sup3r/utilities/utilities.py b/sup3r/utilities/utilities.py index b0ac20a62..0479c88f3 100644 --- a/sup3r/utilities/utilities.py +++ b/sup3r/utilities/utilities.py @@ -10,6 +10,7 @@ import string import time from fnmatch import fnmatch +from pathlib import Path from warnings import warn import numpy as np @@ -55,6 +56,34 @@ def __call__(self, fun, *args, **kwargs): return out +def expand_paths(fps): + """Expand path(s) + + Parameter + --------- + fps : str or pathlib.Path or any Sequence of those + One or multiple paths to file + + Returns + ------- + list[str] + A list of expanded unique and sorted paths as str + + Examples + -------- + >>> expand_paths("myfile.h5") + + >>> expand_paths(["myfile.h5", "*.hdf"]) + """ + if isinstance(fps, (str, Path)): + fps = (fps, ) + + out = [] + for f in fps: + out.extend(glob(f)) + return sorted(set(out)) + + def generate_random_string(length): """Generate random string with given length. 
Used for naming temporary files to avoid collisions.""" From ad43906f4f1df95666572dd74b2678707a573522 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Wed, 1 May 2024 09:41:38 -0600 Subject: [PATCH 14/16] glob -> glob.glob --- sup3r/utilities/utilities.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sup3r/utilities/utilities.py b/sup3r/utilities/utilities.py index 0479c88f3..3e8cc3f4e 100644 --- a/sup3r/utilities/utilities.py +++ b/sup3r/utilities/utilities.py @@ -80,7 +80,7 @@ def expand_paths(fps): out = [] for f in fps: - out.extend(glob(f)) + out.extend(glob.glob(f)) return sorted(set(out)) From f0840f242b103c31fb919bd821deb8be527aeb0f Mon Sep 17 00:00:00 2001 From: bnb32 Date: Thu, 2 May 2024 05:54:40 -0600 Subject: [PATCH 15/16] lr_padded_slice arg needed in local_qdm_bc method --- sup3r/bias/bias_transforms.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/sup3r/bias/bias_transforms.py b/sup3r/bias/bias_transforms.py index 09c31d3d6..a4bcb9173 100644 --- a/sup3r/bias/bias_transforms.py +++ b/sup3r/bias/bias_transforms.py @@ -403,6 +403,7 @@ def local_qdm_bc(data: np.array, base_dset: str, feature_name: str, bias_fp, + lr_padded_slice, threshold=0.1, relative=True, no_trend=False): @@ -433,6 +434,13 @@ "bias_fut_{feature_name}_params", and "base_{base_dset}_params" that are the parameters to define the statistical distributions to be used to correct the given `data`. + lr_padded_slice : tuple | None + Tuple of length four that slices (spatial_1, spatial_2, temporal, + features) where each tuple entry is a slice object for that axes. + Note that if this method is called as part of a sup3r forward pass, the + lr_padded_slice will be included automatically in the kwargs for the + active chunk. If this is None, no slicing will be done and the full + bias correction source shape will be used. no_trend: bool, default=False An option to ignore the trend component of the correction, thus resulting in an ordinary Quantile Mapping, i.e. corrects the bias by @@ -485,6 +493,11 @@ feature_name, bias_fp, threshold) + if lr_padded_slice is not None: + spatial_slice = (lr_padded_slice[0], lr_padded_slice[1]) + base = base[spatial_slice] + bias = bias[spatial_slice] + bias_fut = bias_fut[spatial_slice] if no_trend: mf = None From 7c29d0b8f65958d6c224dfeecb5843f0739dfa69 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Thu, 2 May 2024 06:57:03 -0600 Subject: [PATCH 16/16] lr_padded_slice=None default for bias_transforms --- sup3r/bias/bias_transforms.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/sup3r/bias/bias_transforms.py b/sup3r/bias/bias_transforms.py index a4bcb9173..153e35a1c 100644 --- a/sup3r/bias/bias_transforms.py +++ b/sup3r/bias/bias_transforms.py @@ -203,7 +203,7 @@ def local_linear_bc(input, lat_lon, feature_name, bias_fp, - lr_padded_slice, + lr_padded_slice=None, out_range=None, smoothing=0, ): @@ -292,8 +292,8 @@ def monthly_local_linear_bc(input, lat_lon, feature_name, bias_fp, - lr_padded_slice, time_index, + lr_padded_slice=None, temporal_avg=True, out_range=None, smoothing=0, @@ -318,6 +318,11 @@ datasets "{feature_name}_scalar" and "{feature_name}_adder" that are the full low-resolution shape of the forward pass input that will be sliced using lr_padded_slice for the current chunk. + time_index : pd.DatetimeIndex + DatetimeIndex object associated with the input data temporal axis + (assumed 3rd axis e.g. axis=2). 
Note that if this method is called as + part of a sup3r resolution forward pass, the time_index will be + included automatically for the current chunk. lr_padded_slice : tuple | None Tuple of length four that slices (spatial_1, spatial_2, temporal, features) where each tuple entry is a slice object for that axes. @@ -325,11 +330,6 @@ def monthly_local_linear_bc(input, lr_padded_slice will be included automatically in the kwargs for the active chunk. If this is None, no slicing will be done and the full bias correction source shape will be used. - time_index : pd.DatetimeIndex - DatetimeIndex object associated with the input data temporal axis - (assumed 3rd axis e.g. axis=2). Note that if this method is called as - part of a sup3r resolution forward pass, the time_index will be - included automatically for the current chunk. temporal_avg : bool Take the average scalars and adders for the chunk's time index, this will smooth the transition of scalars/adders from month to month if @@ -403,7 +403,7 @@ def local_qdm_bc(data: np.array, base_dset: str, feature_name: str, bias_fp, - lr_padded_slice, + lr_padded_slice=None, threshold=0.1, relative=True, no_trend=False):
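With lr_padded_slice now defaulting to None, these transforms can be called directly on a full array outside of a forward pass. A minimal sketch against the local_linear_bc signature above (the shapes, feature name, and params file are hypothetical; bias_fp must already contain the "{feature_name}_scalar" and "{feature_name}_adder" datasets described in the docstrings):

    import numpy as np
    from sup3r.bias.bias_transforms import local_linear_bc

    data = np.random.rand(20, 20, 24)    # (spatial_1, spatial_2, temporal)
    lat_lon = np.random.rand(20, 20, 2)  # (spatial_1, spatial_2, [lat, lon])
    out = local_linear_bc(data, lat_lon, 'windspeed_100m',
                          bias_fp='./bc_params.h5')  # no lr_padded_slice needed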