Skip to content

Commit

Permalink
netcdf4 dep removal from era downloader
Browse files Browse the repository at this point in the history
  • Loading branch information
bnb32 committed Apr 25, 2024
1 parent 00b52cf commit a4a5992
Showing 1 changed file with 82 additions and 124 deletions.
206 changes: 82 additions & 124 deletions sup3r/utilities/era_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import numpy as np
import pandas as pd
import xarray as xr
from netCDF4 import Dataset

from sup3r.utilities.interpolate_log_profile import LogLinInterpolator

Expand Down Expand Up @@ -199,31 +198,6 @@ def level_file(self):
basename += f'{str(self.month).zfill(2)}.nc'
return os.path.join(basedir, basename)

@classmethod
def init_dims(cls, old_ds, new_ds, dims):
"""Initialize dimensions in new dataset from old dataset
Parameters
----------
old_ds : Dataset
Dataset() object from old file
new_ds : Dataset
Dataset() object for new file
dims : tuple
Tuple of dimensions. e.g. ('time', 'latitude', 'longitude')
Returns
-------
new_ds : Dataset
Dataset() object for new file with dimensions initialized.
"""
for var in dims:
new_ds.createDimension(var, len(old_ds[var]))
_ = new_ds.createVariable(var, old_ds[var].dtype, dimensions=var)
new_ds[var][:] = old_ds[var][:]
new_ds[var].units = old_ds[var].units
return new_ds

@classmethod
def get_tmp_file(cls, file):
"""Get temp file for given file. Then only needed variables will be
Expand Down Expand Up @@ -355,100 +329,107 @@ def download_file(cls, variables, time_dict, area, out_file, level_type,

def process_surface_file(self):
"""Rename variables and convert geopotential to geopotential height."""
dims = ('time', 'latitude', 'longitude')
tmp_file = self.get_tmp_file(self.surface_file)
with Dataset(self.surface_file, "r") as old_ds:
with Dataset(tmp_file, "w") as ds:
ds = self.init_dims(old_ds, ds, dims)

ds = self.convert_z('orog', 'Orography', old_ds, ds)

ds = self.map_vars(old_ds, ds)
with xr.open_dataset(self.surface_file) as ds:
ds = self.convert_z(ds, name='orog')
ds = self.map_vars(ds)
ds.to_netcdf(tmp_file)
os.system(f'mv {tmp_file} {self.surface_file}')
logger.info(f'Finished processing {self.surface_file}. Moved '
f'{tmp_file} to {self.surface_file}.')

def map_vars(self, old_ds, ds):
def map_vars(self, ds):
"""Map variables from old dataset to new dataset
Parameters
----------
old_ds : Dataset
Dataset() object from old file
ds : Dataset
Dataset() object for new file
xr.Dataset() object for which to rename variables
Returns
-------
ds : Dataset
Dataset() object for new file with new variables written.
xr.Dataset() object with new variables written.
"""
for old_name in old_ds.variables:
for old_name in ds.data_vars:
new_name = self.NAME_MAP.get(old_name, old_name)
if new_name not in ds.variables:
_ = ds.createVariable(new_name,
np.float32,
dimensions=old_ds[old_name].dimensions,
)
vals = old_ds.variables[old_name][:]
if 'temperature' in new_name:
vals -= 273.15
ds.variables[new_name][:] = vals
ds.rename({old_name: new_name})
if 'temperature' in new_name:
ds[new_name] = (ds[new_name].dims,
ds[new_name].values - 273.15)
return ds

def convert_z(self, standard_name, long_name, old_ds, ds):
"""Convert z to given height variable
def shift_temp(self, ds):
"""Shift temperature to celsius
Parameters
----------
standard_name : str
New variable name. e.g. 'zg' or 'orog'
long_name : str
Long name for new variable. e.g. 'Geopotential Height' or
'Orography'
old_ds : Dataset
Dataset() object from tmp file
ds : Dataset
Dataset() object for new file
xr.Dataset() object for which to shift temperature
Returns
-------
ds : Dataset
Dataset() object for new file with new height variable written.
"""
_ = ds.createVariable(standard_name,
np.float32,
dimensions=old_ds['z'].dimensions)
ds.variables[standard_name][:] = old_ds['z'][:] / 9.81
ds.variables[standard_name].long_name = long_name
ds.variables[standard_name].standard_name = 'zg'
ds.variables[standard_name].units = 'm'
for var in ds.data_vars:
if 'temperature' in var:
ds[var] = (ds[var].dims, ds[var].values - 273.15)
return ds

def process_level_file(self):
"""Convert geopotential to geopotential height."""
dims = ('time', 'level', 'latitude', 'longitude')
tmp_file = self.get_tmp_file(self.level_file)
with Dataset(self.level_file, "r") as old_ds:
with Dataset(tmp_file, "w") as ds:
ds = self.init_dims(old_ds, ds, dims)
def add_pressure(self, ds):
"""Add pressure to dataset
ds = self.convert_z('zg', 'Geopotential Height', old_ds, ds)
Parameters
----------
ds : Dataset
xr.Dataset() object for which to add pressure
ds = self.map_vars(old_ds, ds)
Returns
-------
ds : Dataset
"""
if ('pressure' in self.variables
and 'pressure' not in ds.data_vars):
tmp = np.zeros(ds['zg'].shape)

if 'number' in ds.dimensions:
tmp[:] = 100 * ds['level'].values[
None, None, :, None, None]
else:
tmp[:] = 100 * ds['level'].values[
None, :, None, None]

ds['pressure'] = (ds['zg'].dims, tmp)
return ds

def convert_z(self, ds, name):
"""Convert z to given height variable
Parameters
----------
ds : Dataset
xr.Dataset() object for new file
name : str
Variable name. e.g. zg or orog, typically
if ('pressure' in self.variables
and 'pressure' not in ds.variables):
tmp = np.zeros(ds.variables['zg'].shape)
for i in range(tmp.shape[1]):
tmp[:, i, :, :] = ds.variables['level'][i] * 100
Returns
-------
ds : Dataset
xr.Dataset() object for new file with new height variable written.
"""
ds['z'] = (ds['z'].dims, ds['z'].values / 9.81)
ds.rename({'z': name})
return ds

_ = ds.createVariable('pressure',
np.float32,
dimensions=dims)
ds.variables['pressure'][:] = tmp[...]
ds.variables['pressure'].long_name = 'Pressure'
ds.variables['pressure'].units = 'Pa'
def process_level_file(self):
"""Convert geopotential to geopotential height."""
tmp_file = self.get_tmp_file(self.level_file)
with xr.open_dataset(self.level_file) as ds:
ds = self.convert_z(ds, name='zg')
ds = self.map_vars(ds)
ds = self.shift_temp(ds)
ds = self.add_pressure(ds)
ds.to_netcdf(tmp_file)

os.system(f'mv {tmp_file} {self.level_file}')
logger.info(f'Finished processing {self.level_file}. Moved '
Expand Down Expand Up @@ -629,17 +610,10 @@ def already_pruned(cls, infile, prune_variables):
if prune_variables is None:
logger.info('Received prune_variables=None. Skipping pruning.')
return
else:
logger.info(f'Received prune_variables={prune_variables}.')

pruned = True
with Dataset(infile, 'r') as ds:
variables = [var for var in ds.variables
if var not in ('time', 'latitude', 'longitude')]
for var in variables:
if not any(name in var for name in prune_variables):
logger.info(f'Pruning {var} in {infile}.')
pruned = False
with xr.open_dataset(infile) as ds:
check_variables = [var for var in ds.data_vars
if 'level' in ds[var].dims]
pruned = len(check_variables) == 0
return pruned

@classmethod
Expand All @@ -649,32 +623,16 @@ def prune_output(cls, infile, prune_variables=None):
logger.info('Received prune_variables=None. Skipping pruning.')
return
else:
logger.info(f'Received prune_variables={prune_variables}.')

logger.info(f'Pruning {infile}.')
tmp_file = cls.get_tmp_file(infile)
with Dataset(infile, 'r') as old_ds:
keep_vars = [var for var in old_ds.variables
if var not in prune_variables and var not
in ('time', 'latitude', 'longitude', 'level')]
with Dataset(tmp_file, 'w') as new_ds:
new_ds = cls.init_dims(old_ds, new_ds,
('time', 'latitude', 'longitude'))
for var in keep_vars:
old_var = old_ds[var]
vals = old_var[:]
logger.info(f'Creating variable {var}.')
_ = new_ds.createVariable(
var, old_var.dtype, dimensions=old_var.dimensions)
new_ds[var][:] = vals
if hasattr(old_var, 'units'):
new_ds[var].units = old_var.units
if hasattr(old_var, 'standard_name'):
standard_name = old_var.standard_name
new_ds[var].standard_name = standard_name
if hasattr(old_var, 'long_name'):
new_ds[var].long_name = old_var.long_name
os.system(f'mv {tmp_file} {infile}')
logger.info(f'Pruning {infile}.')
tmp_file = cls.get_tmp_file(infile)
with xr.Dataset(infile) as ds:
keep_vars = {k:v for k, v in dict(ds.data_vars)
if 'level' not in ds[k].dims}
new_coords = {k:v for k, v in dict(ds.coords).items()
if 'level' not in k}
new_ds = xr.Dataset(coords=new_coords, data_vars=keep_vars)
new_ds.to_netcdf(tmp_file)
os.system(f'mv {tmp_file} {infile}')
logger.info(f'Finished pruning variables in {infile}. Moved '
f'{tmp_file} to {infile}.')

Expand Down

0 comments on commit a4a5992

Please sign in to comment.