Skip to content

Commit

Permalink
clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
lee1043 committed May 2, 2024
1 parent 1c97c33 commit 953f1a5
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 29 deletions.
32 changes: 18 additions & 14 deletions pcmdi_metrics/mjo/lib/lib_mjo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
Code written by Jiwoo Lee, LLNL. Feb. 2019
Inspired by Daehyun Kim and Min-Seop Ahn's MJO metrics.
Code update history
2024-05 converted to use xcdat as base building block (Jiwoo Lee)
Reference:
Ahn, MS., Kim, D., Sperber, K.R. et al. Clim Dyn (2017) 49: 4023.
https://doi.org/10.1007/s00382-017-3558-4
Expand Down Expand Up @@ -36,7 +39,7 @@ def subSliceSegment(
ds: Union[xr.Dataset, xr.DataArray], year: int, mon: int, day: int, length: int
) -> Union[xr.Dataset, xr.DataArray]:
"""
Note: From given cdms array (3D: time and spatial 2D)
Note: From given array (3D: time and spatial 2D)
Subslice to get segment with given length starting from given time.
input
- ds: xarray dataset or dataArray
Expand All @@ -56,7 +59,7 @@ def subSliceSegment(
) # slie 180 time steps starting from above index


def get_daily_ano_segment(d_seg, data_var):
def get_daily_ano_segment(d_seg: xr.Dataset, data_var: str) -> xr.Dataset:
"""
Note: 1. Get daily time series (3D: time and spatial 2D)
2. Meridionally average (2D: time and spatial, i.e., longitude)
Expand All @@ -83,12 +86,13 @@ def get_daily_ano_segment(d_seg, data_var):
return d_seg_x_ano


def space_time_spectrum(d_seg_x_ano, data_var):
def space_time_spectrum(d_seg_x_ano: xr.Dataset, data_var: str) -> np.ndarray:
"""
input
- d: 2d cdms MV2 array (t (time), n (space))
- d: xarray dataset that contains 2d DataArray (t (time), n (space)) named as `data_var`
- data_var: name of the 2d DataArray
output
- p: 2d array for power
- p: 2d numpy array for power
NOTE: Below code taken from
https://github.com/CDAT/wk/blob/2b953281c7a4c5d0ac2d79fcc3523113e31613d5/WK/process.py#L188
"""
Expand Down Expand Up @@ -123,7 +127,7 @@ def taper(data):
"""
Note: taper first and last 45 days with cosine window, using scipy.signal function
input
- data: cdms 2d array (t, n) t: time, n: space (meridionally averaged)
- data: 2d array (t, n) t: time, n: space (meridionally averaged)
output:
- data: tapered data
"""
Expand All @@ -134,17 +138,15 @@ def taper(data):
return data2


def generate_axes_and_decorate(Power, NT, NL):
def generate_axes_and_decorate(Power, NT: int, NL: int) -> xr.DataArray:
"""
Note: Generates axes for the decoration
input
- Power: 2d numpy array
- NT: integer, number of time step
- NL: integer, number of spatial grid
output
- Power: decorated 2d cdms array
- ff: frequency axis
- ss: wavenumber axis
- xr.DataArray that contains Power 2d DataArray that has frequency and zonalwavenumber axes
"""
# frequency
ff = []
Expand Down Expand Up @@ -177,7 +179,7 @@ def generate_axes_and_decorate(Power, NT, NL):
return da


def output_power_spectra(NL, NT, Power):
def output_power_spectra(NL: int, NT: int, Power):
"""
Below code taken and modified from Daehyun Kim's Fortran code (MSD/level_2/sample/stps/stps.sea.f.sample)
"""
Expand Down Expand Up @@ -225,12 +227,14 @@ def output_power_spectra(NL, NT, Power):
# return OEE


def write_netcdf_output(da, fname):
def write_netcdf_output(da: xr.DataArray, fname):
"""
Note: write array in a netcdf file
input
- d: array
- fname: string. directory path and name of the netcd file, without .nc
- d: xr.DataArray object
- fname: string of filename. Directory path that includes file name without .nc
output
- None
"""
ds = xr.Dataset({da.name: da})
ds.to_netcdf(fname + ".nc")
Expand Down
9 changes: 5 additions & 4 deletions pcmdi_metrics/mjo/lib/plot_wavenumber_frequency_power.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import copy
import os

import cdms2
import matplotlib.cm
import matplotlib.pyplot as plt
import xarray as xr
from matplotlib.patches import Rectangle


def plot_power(d, title, fout, ewr=None):
def plot_power(d: xr.DataArray, title: str, fout: str, ewr=None):
x = d["frequency"]
y = d["zonalwavenumber"]

Expand Down Expand Up @@ -132,8 +132,9 @@ def plot_power(d, title, fout, ewr=None):

imgdir = "."

f = cdms2.open(os.path.join(datadir, ncfile))
d = f("power")
ds = xr.open_dataset(os.path.join(datadir, ncfile))
d = ds["power"]

fout = os.path.join(imgdir, pngfilename)

plot_power(d, title, fout, ewr=ewr)
8 changes: 4 additions & 4 deletions pcmdi_metrics/mjo/scripts/post_process_plot.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import glob
import os

import cdms2
import xarray as xr
from lib_mjo import calculate_ewr
from plot_wavenumber_frequency_power import plot_power

Expand Down Expand Up @@ -48,10 +48,9 @@ def main():
ncfile = (
"_".join([mip, model, exp, run, "mjo", period, "cmmGrid"]) + ".nc"
)
f = cdms2.open(os.path.join(datadir, ncfile))
d = f("power")
ds = xr.open_dataset(os.path.join(datadir, ncfile))
d = ds["power"]
d_runs.append(d)
f.close()
title = (
mip.upper()
+ ": "
Expand All @@ -69,6 +68,7 @@ def main():
fout = os.path.join(imgdir, pngfilename)
# plot
plot_power(d, title, fout, ewr)
ds.close()
except Exception:
print(model, run, "cannnot load")
pass
Expand Down
17 changes: 10 additions & 7 deletions pcmdi_metrics/mjo/scripts/post_process_plot_ensemble_mean.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import glob
import os

import cdms2
import MV2
import numpy as np
import xarray as xr
from lib_mjo import calculate_ewr
from plot_wavenumber_frequency_power import plot_power

Expand Down Expand Up @@ -62,18 +62,21 @@ def main():
)
+ ".nc"
)
f = cdms2.open(os.path.join(datadir, ncfile))
d = f("power")

ds = xr.open_dataset(os.path.join(datadir, ncfile))
d = ds["power"]

d_runs.append(d)
f.close()

except Exception as err:
print(model, run, "cannnot load:", err)
pass

if run == runs_list[-1]:
num_runs = len(d_runs)
# ensemble mean
d_avg = MV2.average(d_runs, axis=0)
d_avg.setAxisList(d.getAxisList())
d_avg = np.average(d_runs, axis=0)
# d_avg.setAxisList(d.getAxisList())
title = (
mip.upper()
+ ": "
Expand Down

0 comments on commit 953f1a5

Please sign in to comment.