Skip to content

Commit

Permalink
in progress ...
Browse files Browse the repository at this point in the history
  • Loading branch information
lee1043 committed Apr 30, 2024
1 parent 5aac6f7 commit 680dd7c
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 47 deletions.
6 changes: 4 additions & 2 deletions pcmdi_metrics/mjo/lib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@
decorate_2d_array_axes,
generate_axes_and_decorate,
get_daily_ano_segment,
interp2commonGrid,
#interp2commonGrid,
interp2commonGrid_xcdat,
mjo_metrics_to_json,
output_power_spectra,
space_time_spectrum,
subSliceSegment,
# subSliceSegment,
subSliceSegment_xcdat,
taper,
unit_conversion,
write_netcdf_output,
Expand Down
39 changes: 38 additions & 1 deletion pcmdi_metrics/mjo/lib/lib_mjo.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@
from scipy import signal

import pcmdi_metrics
import xarray as xr

from typing import Union
from pcmdi_metrics.io import get_time_key
from pcmdi_metrics.utils import create_target_grid, regrid
from pcmdi_metrics.io import select_subset


def interp2commonGrid(d, dlat, debug=False):
Expand All @@ -34,7 +40,20 @@ def interp2commonGrid(d, dlat, debug=False):
return d2


def subSliceSegment(d, year, mon, day, length):
def interp2commonGrid_xcdat(ds, data_var, dlat, dlon=None, debug=False):
if dlon is None:
dlon = dlat
nlat = int(180/dlat)
nlon = int(360/dlon)
grid = create_target_grid(target_grid_resolution=f"{dlat}x{dlon}")
ds_regrid = regrid(ds, data_var, grid)
ds_regrid_subset = select_subset(ds, lat=(-10, 10))
if debug:
print("debug: ds_regrid_subset[data_var] shape:", ds_regrid_subset[data_var].shape)
return ds_regrid_subset


def subSliceSegment(ds, data_var, year, mon, day, length):
"""
Note: From given cdms array (3D: time and spatial 2D)
Subslice to get segment with given length starting from given time.
Expand All @@ -56,6 +75,24 @@ def subSliceSegment(d, year, mon, day, length):
return d2


def subSliceSegment_xcdat(ds: Union[xr.Dataset, xr.DataArray], year: int, mon: int, day:int, length: int) -> Union[xr.Dataset, xr.DataArray]:
"""
Note: From given cdms array (3D: time and spatial 2D)
Subslice to get segment with given length starting from given time.
input
- ds: xarray dataset or dataArray
- year: segment starting year (integer)
- mon: segment starting month (integer)
- day: segement starting day (integer)
- length: segment length (integer)
"""

time_key = get_time_key(ds)
n = list(ds[time_key].values).index(ds.sel(time=f'{year:04}-{mon:02}-{day:02}')[time_key])

return ds.isel(time=slice(n, n + length)) # slie 180 time steps starting from above index


def Remove_dailySeasonalCycle(d, d_cyc):
"""
Note: Remove daily seasonal cycle
Expand Down
126 changes: 82 additions & 44 deletions pcmdi_metrics/mjo/lib/mjo_metric_calc.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,33 @@
import os

import cdms2
import cdtime
#import cdms2
#import cdtime
import MV2
import numpy as np

from datetime import datetime

from pcmdi_metrics.mjo.lib import (
Remove_dailySeasonalCycle,
calculate_ewr,
generate_axes_and_decorate,
get_daily_ano_segment,
interp2commonGrid,
#interp2commonGrid,
interp2commonGrid_xcdat,
output_power_spectra,
space_time_spectrum,
subSliceSegment,
unit_conversion,
#subSliceSegment,
subSliceSegment_xcdat,
#unit_conversion,
write_netcdf_output,
)

from .debug_chk_plot import debug_chk_plot
from .plot_wavenumber_frequency_power import plot_power

from pcmdi_metrics.io import xcdat_open, get_time, get_latitude, get_longitude, get_time_key
from pcmdi_metrics.utils import adjust_units


def mjo_metric_ewr_calculation(
mip,
Expand All @@ -34,38 +41,56 @@ def mjo_metric_ewr_calculation(
degX,
UnitsAdjust,
inputfile,
var,
startYear,
endYear,
segmentLength,
dir_paths,
season="NDJFMA",
data_var: str,
startYear: int,
endYear: int,
segmentLength: int,
dir_paths: str,
season: str="NDJFMA",
):
# Open file to read daily dataset
if debug:
print("debug: open file")
f = cdms2.open(inputfile)
d = f[var]
tim = d.getTime()
comTim = tim.asComponentTime()
lat = d.getLatitude()
lon = d.getLongitude()
#f = cdms2.open(inputfile)
#d = f[data_var]
ds = xcdat_open(inputfile)

#tim = d.getTime()
#comTim = tim.asComponentTime()
#lat = d.getLatitude()
#lon = d.getLongitude()

#tim = get_time(ds)
lat = get_latitude(ds)
lon = get_longitude(ds)

# Get starting and ending year and month
if debug:
print("debug: check time")
first_time = comTim[0]
last_time = comTim[-1]

#first_time = comTim[0]
#last_time = comTim[-1]

time_key = get_time_key(ds)
first_time = ds.indexes[time_key].to_datetimeindex()[0].to_pydatetime()
last_time = ds.indexes[time_key].to_datetimeindex()[-1].to_pydatetime()

if season == "NDJFMA":
# Adjust years to consider only when continuous NDJFMA is available
"""
if first_time > cdtime.comptime(startYear, 11, 1):
startYear += 1
if last_time < cdtime.comptime(endYear, 4, 30):
endYear -= 1

"""
if first_time > datetime(startYear, 11, 1):
startYear += 1
if last_time < datetime(endYear, 4, 30):
endYear -= 1

# Number of grids for 2d fft input
NL = len(d.getLongitude()) # number of grid in x-axis (longitude)
#NL = len(d.getLongitude()) # number of grid in x-axis (longitude)
NL = len(lon.values) # number of grid in x-axis (longitude)
if cmmGrid:
NL = int(360 / degX)
NT = segmentLength # number of time step for each segment (need to be an even number)
Expand All @@ -88,35 +113,46 @@ def mjo_metric_ewr_calculation(
# Store each year's segment in a dictionary: segment[year]
segment = {}
segment_ano = {}
daSeaCyc = MV2.zeros((NT, d.shape[1], d.shape[2]))
#daSeaCyc = MV2.zeros((NT, d.shape[1], d.shape[2]))
daSeaCyc = np.zeros((NT, ds[data_var].shape[1], ds[data_var].shape[2]))

# Loop over years
for year in range(startYear, endYear):
print(year)
segment[year] = subSliceSegment(d, year, mon, day, NT)
#segment[year] = subSliceSegment(d, year, mon, day, NT)
segment[year] = subSliceSegment_xcdat(ds, year, mon, day, NT)
# units conversion
segment[year] = unit_conversion(segment[year], UnitsAdjust)
#segment[year] = unit_conversion(segment[year], UnitsAdjust)
segment[year][data_var] = adjust_units(segment[year][data_var], UnitsAdjust)
# Get climatology of daily seasonal cycle
daSeaCyc = MV2.add(MV2.divide(segment[year], float(numYear)), daSeaCyc)
#daSeaCyc = MV2.add(MV2.divide(segment[year], float(numYear)), daSeaCyc)
daSeaCyc = np.add(np.divide(segment[year][data_var].values, float(numYear)), daSeaCyc)

# Remove daily seasonal cycle from each segment
if numYear > 1:
# Loop over years
for year in range(startYear, endYear):
segment_ano[year] = Remove_dailySeasonalCycle(segment[year], daSeaCyc)
#segment_ano[year] = Remove_dailySeasonalCycle(segment[year], daSeaCyc)
segment_ano[year] = segment[year] - daSeaCyc
else:
segment_ano[year] = segment[year]

# Assign lat/lon to arrays
daSeaCyc.setAxis(1, lat)
daSeaCyc.setAxis(2, lon)
segment_ano[year].setAxis(1, lat)
segment_ano[year].setAxis(2, lon)

""" Space-time power spectra
Handle each segment (i.e. each year) separately.
1. Get daily time series (3D: time and spatial 2D)
2. Meridionally average (2D: time and spatial, i.e., longitude)
3. Get anomaly by removing time mean of the segment
4. Proceed 2-D FFT to get power.
Then get multi-year averaged power after the year loop.
"""
# daSeaCyc.setAxis(1, lat)
# daSeaCyc.setAxis(2, lon)
# segment_ano[year].setAxis(1, lat)
# segment_ano[year].setAxis(2, lon)

# -----------------------------------------------------------------
# Space-time power spectra
# -----------------------------------------------------------------
# Handle each segment (i.e. each year) separately.
# 1. Get daily time series (3D: time and spatial 2D)
# 2. Meridionally average (2D: time and spatial, i.e., longitude)
# 3. Get anomaly by removing time mean of the segment
# 4. Proceed 2-D FFT to get power.
# Then get multi-year averaged power after the year loop.
# -----------------------------------------------------------------

# Define array for archiving power from each year segment
Power = np.zeros((numYear, NT + 1, NL + 1), np.float)
Expand All @@ -129,7 +165,8 @@ def mjo_metric_ewr_calculation(
d_seg = segment_ano[year]
# Regrid: interpolation to common grid
if cmmGrid:
d_seg = interp2commonGrid(d_seg, degX, debug=debug)
#d_seg = interp2commonGrid(d_seg, degX, debug=debug)
d_seg = interp2commonGrid_xcdat(d_seg, degX, debug=debug)
# Subregion, meridional average, and remove segment time mean
d_seg_x_ano = get_daily_ano_segment(d_seg)
# Compute space-time spectrum
Expand Down Expand Up @@ -166,9 +203,9 @@ def mjo_metric_ewr_calculation(
os.makedirs(dir_paths["graphics"], exist_ok=True)
fout = os.path.join(dir_paths["graphics"], output_filename)
if model == "obs":
title = f"OBS ({run})\n{var.capitalize()}, {season} {startYear}-{endYear}"
title = f"OBS ({run})\n{data_var.capitalize()}, {season} {startYear}-{endYear}"
else:
title = f"{mip.upper()}: {model} ({run})\n{var.capitalize()}, {season} {startYear}-{endYear}"
title = f"{mip.upper()}: {model} ({run})\n{data_var.capitalize()}, {season} {startYear}-{endYear}"

if cmmGrid:
title += ", common grid (2.5x2.5deg)"
Expand All @@ -189,5 +226,6 @@ def mjo_metric_ewr_calculation(
d_seg_x_ano, Power, OEE, segment[year], daSeaCyc, segment_ano[year]
)

f.close()
#f.close()
ds.close()
return metrics_result
2 changes: 2 additions & 0 deletions pcmdi_metrics/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from .adjust_units import adjust_units

from .custom_season import (
custom_season_average,
custom_season_departure,
Expand Down

0 comments on commit 680dd7c

Please sign in to comment.