
Commit

some arg cleaning in era_downloader
bnb32 committed Apr 30, 2024
1 parent bd972c6 commit 6ffc4fa
Showing 3 changed files with 56 additions and 37 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -41,7 +41,7 @@ dependencies = [
"pytest>=5.2",
"scipy>=1.0.0",
"sphinx>=7.0",
"tensorflow>2.4,<2.16",
"tensorflow>2.4,<2.10",
"xarray>=2023.0",
]

14 changes: 0 additions & 14 deletions requirements.txt

This file was deleted.

77 changes: 55 additions & 22 deletions sup3r/utilities/era_downloader.py
@@ -94,7 +94,8 @@ def __init__(self,
run_interp=True,
overwrite=False,
variables=None,
check_files=False):
check_files=False,
product_type='reanalysis'):
"""Initialize the class.
Parameters
@@ -123,6 +124,9 @@ def __init__(self,
and wind components.
check_files : bool
Check existing files. Remove and redownload if checks fail.
product_type : str
Can be 'reanalysis', 'ensemble_mean', 'ensemble_spread',
'ensemble_members'
"""
self.year = year
self.month = month
@@ -141,12 +145,23 @@ def __init__(self,
self.sfc_file_variables = ['geopotential']
self.level_file_variables = ['geopotential']
self.prep_var_lists(self.variables)
self.product_type = product_type
self.hours = self.get_hours()

msg = ('Initialized EraDownloader with: '
f'year={self.year}, month={self.month}, area={self.area}, '
f'levels={self.levels}, variables={self.variables}')
logger.info(msg)

def get_hours(self):
"""ERA5 is hourly and EDA is 3-hourly. Check and warn for incompatible
requests."""
if self.product_type == 'reanalysis':
hours = [str(n).zfill(2) + ":00" for n in range(0, 24)]
else:
hours = [str(n).zfill(2) + ":00" for n in range(0, 24, 3)]
return hours

@property
def variables(self):
"""Get list of requested variables"""
@@ -275,18 +290,20 @@ def download_process_combine(self):
if sfc_check:
self.download_file(self.sfc_file_variables, time_dict=time_dict,
area=self.area, out_file=self.surface_file,
level_type='single', overwrite=self.overwrite)
level_type='single', overwrite=self.overwrite,
product_type=self.product_type)
if level_check:
self.download_file(self.level_file_variables, time_dict=time_dict,
area=self.area, out_file=self.level_file,
level_type='pressure', levels=self.levels,
overwrite=self.overwrite)
overwrite=self.overwrite,
product_type=self.product_type)
if sfc_check or level_check:
self.process_and_combine()

@classmethod
def download_file(cls, variables, time_dict, area, out_file, level_type,
levels=None, overwrite=False):
levels=None, product_type='reanalysis', overwrite=False):
"""Download either single-level or pressure-level file
Parameters
@@ -304,6 +321,9 @@ def download_file(cls, variables, time_dict, area, out_file, level_type,
Either 'single' or 'pressure'
levels : list
List of pressure levels to download, if level_type == 'pressure'
product_type : str
Can be 'reanalysis', 'ensemble_mean', 'ensemble_spread',
'ensemble_members'
overwrite : bool
Whether to overwrite existing file
"""
@@ -555,7 +575,7 @@ def run_interpolation(self, max_workers=None, **kwargs):
overwrite=self.overwrite,
**kwargs)

def get_monthly_file(self, interp_workers=None, prune_variables=None,
def get_monthly_file(self, interp_workers=None, prune_variables=False,
**interp_kwargs):
"""Download level and surface files, process variables, and combine
processed files. Includes checks for shape and variables and option to
@@ -604,8 +624,8 @@ def all_months_exist(cls, year, file_pattern):
@classmethod
def already_pruned(cls, infile, prune_variables):
"""Check if file has been pruned already."""
if prune_variables is None:
logger.info('Received prune_variables=None. Skipping pruning.')
if not prune_variables:
logger.info('Received prune_variables=False. Skipping pruning.')
return
with xr.open_dataset(infile) as ds:
check_variables = [var for var in ds.data_vars
@@ -614,15 +634,15 @@ def already_pruned(cls, infile, prune_variables):
return pruned

@classmethod
def prune_output(cls, infile, prune_variables=None):
def prune_output(cls, infile, prune_variables=False):
"""Prune output file to keep just single level variables"""
if prune_variables is None:
logger.info('Received prune_variables=None. Skipping pruning.')
if not prune_variables:
logger.info('Received prune_variables=False. Skipping pruning.')
return
else:
logger.info(f'Pruning {infile}.')
tmp_file = cls.get_tmp_file(infile)
with xr.Dataset(infile) as ds:
with xr.open_dataset(infile) as ds:
keep_vars = {k: v for k, v in dict(ds.data_vars).items()
if 'level' not in ds[k].dims}
new_coords = {k: v for k, v in dict(ds.coords).items()
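
A minimal xarray sketch of the same pruning idea, keeping only variables without a pressure-level dimension; the function and file names here are illustrative, not from the repository:

import xarray as xr

def drop_pressure_level_vars(infile, outfile):
    # Keep only variables that do not carry a 'level' dimension,
    # e.g. keep u_10m / u_100m but drop the 4D pressure-level u array.
    with xr.open_dataset(infile) as ds:
        keep = [name for name in ds.data_vars if 'level' not in ds[name].dims]
        ds[keep].to_netcdf(outfile)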
@@ -645,8 +665,9 @@ def run_month(cls,
overwrite=False,
interp_workers=None,
variables=None,
prune_variables=None,
prune_variables=False,
check_files=False,
product_type='reanalysis',
**interp_kwargs):
"""Run routine for all months in the requested year.
@@ -676,13 +697,16 @@
variables : list | None
Variables to download. If None this defaults to just geopotential
and wind components.
prune_variables : bool | None
Variables to remove from final files. This is usually the multi
pressure level array of a variable which has since been
interpolated to specific heights.
pruned.
prune_variables : bool
Whether to remove 4D variables from data after interpolation. e.g.
height interpolation could give u_10m, u_100m, u_120m from a 4D u
array. If we only need these heights we could remove the 4D u array
from the final data file.
check_files : bool
Check existing files. Remove and redownload if checks fail.
product_type : str
Can be 'reanalysis', 'ensemble_mean', 'ensemble_spread',
'ensemble_members'
**interp_kwargs : dict
Keyword args for LogLinInterpolator.run()
"""
@@ -695,7 +719,8 @@
run_interp=run_interp,
overwrite=overwrite,
variables=variables,
check_files=check_files)
check_files=check_files,
product_type=product_type)
downloader.get_monthly_file(interp_workers=interp_workers,
prune_variables=prune_variables,
**interp_kwargs)
@@ -714,8 +739,9 @@ def run_year(cls,
max_workers=None,
interp_workers=None,
variables=None,
prune_variables=None,
prune_variables=False,
check_files=False,
product_type='reanalysis',
**interp_kwargs):
"""Run routine for all months in the requested year.
@@ -750,11 +776,16 @@
variables : list | None
Variables to download. If None this defaults to just geopotential
and wind components.
prune_variables : list | None
Variables to keep in final files. All other variables will be
pruned.
prune_variables : bool
Whether to remove 4D variables from data after interpolation. e.g.
height interpolation could give u_10m, u_100m, u_120m from a 4D u
array. If we only need these heights we could remove the 4D u array
from the final data file.
check_files : bool
Check existing files. Remove and redownload if checks fail.
product_type : str
Can be 'reanalysis', 'ensemble_mean', 'ensemble_spread',
'ensemble_members'
**interp_kwargs : dict
Keyword args for LogLinInterpolator.run()
"""
@@ -772,6 +803,7 @@
variables=variables,
prune_variables=prune_variables,
check_files=check_files,
product_type=product_type,
**interp_kwargs)
else:
futures = {}
@@ -791,6 +823,7 @@
prune_variables=prune_variables,
variables=variables,
check_files=check_files,
product_type=product_type,
**interp_kwargs)
futures[future] = {'year': year, 'month': month}
logger.info(f'Submitted future for year {year} and month '