Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stats function moved to be better shared across metrics #1014

Merged
merged 20 commits into from
Jan 10, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Prev Previous commit
Next Next commit
update
  • Loading branch information
lee1043 committed Jan 4, 2024
commit 8b44ec4a4de57b9bfe6a56edf8c1b77275be3698
12 changes: 9 additions & 3 deletions pcmdi_metrics/io/select_subset.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
def select_subset(ds: xr.Dataset, lat: tuple = None, lon: tuple = None, time: tuple = None) -> xr.Dataset:
import xarray as xr
import xcdat as xc


def select_subset(
ds: xr.Dataset, lat: tuple = None, lon: tuple = None, time: tuple = None
) -> xr.Dataset:
"""
Selects a subset of the given xarray dataset based on specified latitude, longitude, and time ranges.

Expand All @@ -21,13 +27,13 @@ def select_subset(ds: xr.Dataset, lat: tuple = None, lon: tuple = None, time: tu
lon_tuple = (110, 130) # Longitude range
time_tuple = (cftime.datetime(1850, 1, 1, 0, 0, 0, 0),
cftime.datetime(1851, 12, 31, 23, 59, 59, 0)) # Time range

# Load your xarray dataset (ds) here

# Select subset based on specified ranges
ds_subset = select_subset(ds, lat=lat_tuple, lon=lon_tuple, time=time_tuple)
```
"""
"""
sel_keys = {}
if lat is not None:
lat_key = xc.axis.get_dim_keys(ds, axis="Y")
Expand Down
148 changes: 148 additions & 0 deletions pcmdi_metrics/io/xarray_datasets_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
from typing import Union

import xarray as xr
import xcdat as xc

# Retrieve coordinate key names


def get_axis_list(ds: Union[xr.Dataset, xr.DataArray]) -> list[str]:
axes = list(ds.coords.keys())
return axes


def get_time_key(ds: Union[xr.Dataset, xr.DataArray]) -> str:
try:
time_key = xc.get_dim_keys(ds, "T")
except Exception:
axes = get_axis_list(ds)
time_key = [k for k in axes if k.lower() in ["time"]][0]
return time_key


def get_lat_key(ds: Union[xr.Dataset, xr.DataArray]) -> str:
try:
lat_key = xc.get_dim_keys(ds, "Y")
except Exception:
axes = get_axis_list(ds)
lat_key = [k for k in axes if k.lower() in ["lat", "latitude"]][0]
return lat_key


def get_lon_key(ds: Union[xr.Dataset, xr.DataArray]) -> str:
try:
lon_key = xc.get_dim_keys(ds, "X")
except Exception:
axes = get_axis_list(ds)
lon_key = [k for k in axes if k.lower() in ["lon", "longitude"]][0]
return lon_key


# Retrieve bounds key names


def get_time_bounds_key(ds: Union[xr.Dataset, xr.DataArray]) -> str:
lat_key = get_time_key(ds)
return ds[lat_key].attrs["bounds"]


def get_lat_bounds_key(ds: Union[xr.Dataset, xr.DataArray]) -> str:
lat_key = get_lat_key(ds)
return ds[lat_key].attrs["bounds"]


def get_lon_bounds_key(ds: Union[xr.Dataset, xr.DataArray]) -> str:
lon_key = get_lon_key(ds)
return ds[lon_key].attrs["bounds"]


# Extract coordinate data


def get_time(ds: Union[xr.Dataset, xr.DataArray]) -> xr.DataArray:
time_key = get_time_key(ds)
time = ds[time_key]
return time


def get_longitude(ds: Union[xr.Dataset, xr.DataArray]) -> xr.DataArray:
lon_key = get_lon_key(ds)
lon = ds[lon_key]
return lon


def get_latitude(ds: Union[xr.Dataset, xr.DataArray]) -> xr.DataArray:
lat_key = get_lat_key(ds)
lat = ds[lat_key]
return lat


# Extract coordinate bounds data


def get_time_bounds(ds: Union[xr.Dataset, xr.DataArray]) -> xr.DataArray:
time_bounds_key = get_time_bounds_key(ds)
time_bounds = ds[time_bounds_key]
return time_bounds


def get_longitude_bounds(ds: Union[xr.Dataset, xr.DataArray]) -> xr.DataArray:
lon_bounds_key = get_lon_bounds_key(ds)
lon_bounds = ds[lon_bounds_key]
return lon_bounds


def get_latitude_bounds(ds: Union[xr.Dataset, xr.DataArray]) -> xr.DataArray:
lat_bounds_key = get_lat_bounds_key(ds)
lat_bounds = ds[lat_bounds_key]
return lat_bounds


# Select subset


def select_subset(
ds: xr.Dataset, lat: tuple = None, lon: tuple = None, time: tuple = None
) -> xr.Dataset:
"""
Selects a subset of the given xarray dataset based on specified latitude, longitude, and time ranges.

Parameters:
- ds (xr.Dataset): The input xarray dataset.
- lat (tuple, optional): Latitude range in the form of (min, max).
- lon (tuple, optional): Longitude range in the form of (min, max).
- time (tuple, optional): Time range. If time is specified, it should be in the form of (start_time, end_time),
where start_time and end_time can be integers, floats, or cftime.DatetimeProlepticGregorian objects.

Returns:
- xr.Dataset: Subset of the input dataset based on the specified latitude, longitude, and time ranges.

Example Usage:
```
import cftime

# Define latitude, longitude, and time ranges
lat_tuple = (30, 50) # Latitude range
lon_tuple = (110, 130) # Longitude range
time_tuple = (cftime.datetime(1850, 1, 1, 0, 0, 0, 0),
cftime.datetime(1851, 12, 31, 23, 59, 59, 0)) # Time range

# Load your xarray dataset (ds) here

# Select subset based on specified ranges
ds_subset = select_subset(ds, lat=lat_tuple, lon=lon_tuple, time=time_tuple)
```
"""
sel_keys = {}
if lat is not None:
lat_key = get_lat_key(ds)
sel_keys[lat_key] = slice(*lat)
if lon is not None:
lon_key = get_lon_key(ds)
sel_keys[lon_key] = slice(*lon)
if time is not None:
time_key = get_time_key(ds)
sel_keys[time_key] = slice(*time)

ds = ds.sel(**sel_keys)
return ds