Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor the _load_remote_dataset function to load tiled and non-tiled grids in a consistent way #3120

Merged
merged 4 commits into from
Apr 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 40 additions & 26 deletions pygmt/datasets/load_remote_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@

from __future__ import annotations

from typing import TYPE_CHECKING, ClassVar, NamedTuple
from typing import TYPE_CHECKING, ClassVar, Literal, NamedTuple

from pygmt.clib import Session
from pygmt.exceptions import GMTInvalidInput
from pygmt.helpers import kwargs_to_strings
from pygmt.io import load_dataarray
from pygmt.src import grdcut, which
from pygmt.helpers import build_arg_list, kwargs_to_strings
from pygmt.src import which

if TYPE_CHECKING:
import xarray as xr
Expand Down Expand Up @@ -344,7 +344,7 @@ def _load_remote_dataset(
dataset_prefix: str,
resolution: str,
region: str | list,
registration: str,
registration: Literal["gridline", "pixel", None],
) -> xr.DataArray:
r"""
Load GMT remote datasets.
Expand All @@ -370,54 +370,68 @@ def _load_remote_dataset(

Returns
-------
grid : :class:`xarray.DataArray`
grid
The GMT remote dataset grid.

Note
----
The returned :class:`xarray.DataArray` doesn't support slice operation for tiled
grids.
The registration and coordinate system type of the returned
:class:`xarray.DataArray` grid can be accessed via the GMT accessors (i.e.,
``grid.gmt.registration`` and ``grid.gmt.gtype`` respectively). However, these
properties may be lost after specific grid operations (such as slicing) and will
need to be manually set before passing the grid to any PyGMT data processing or
plotting functions. Refer to :class:`pygmt.GMTDataArrayAccessor` for detailed
explanations and workarounds.
"""
dataset = datasets[dataset_name]

# Check resolution
if resolution not in dataset.resolutions:
raise GMTInvalidInput(
f"Invalid resolution '{resolution}' for {dataset.title} dataset. "
f"Available resolutions are: {', '.join(dataset.resolutions)}."
)
resinfo = dataset.resolutions[resolution]

# check registration
valid_registrations = dataset.resolutions[resolution].registrations
# Check registration
if registration is None:
# use gridline registration unless only pixel registration is available
registration = "gridline" if "gridline" in valid_registrations else "pixel"
# Use gridline registration unless only pixel registration is available
registration = "gridline" if "gridline" in resinfo.registrations else "pixel"
elif registration in ("pixel", "gridline"):
if registration not in valid_registrations:
if registration not in resinfo.registrations:
raise GMTInvalidInput(
f"{registration} registration is not available for the "
f"{resolution} {dataset.title} dataset. Only "
f"{valid_registrations[0]} registration is available."
f"{resinfo.registrations[0]} registration is available."
)
else:
raise GMTInvalidInput(
f"Invalid grid registration: '{registration}', should be either 'pixel', "
"'gridline' or None. Default is None, where a gridline-registered grid is "
"returned unless only the pixel-registered grid is available."
)
reg = f"_{registration[0]}"

# different ways to load tiled and non-tiled grids.
# Known issue: tiled grids don't support slice operation
# See https://github.com/GenericMappingTools/pygmt/issues/524
if region is None:
if dataset.resolutions[resolution].tiled:
raise GMTInvalidInput(
f"'region' is required for {dataset.title} resolution '{resolution}'."
fname = f"@{dataset_prefix}{resolution}_{registration[0]}"
Copy link
Member Author

@seisman seisman Apr 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also propose to

  • Rename dataset_name to name
  • Rename dataset_prefix to prefix
  • Remove the trailing _ from dataset_prefix, e.g., dataset_prefix="earth_relief_" should be prefix="earth_relief".

Please leave your comments before I make the changes.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel it's better to do it in a separate PR to make this PR small for review.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Issue opened at #3190.

if resinfo.tiled and region is None:
raise GMTInvalidInput(
f"'region' is required for {dataset.title} resolution '{resolution}'."
)

# Currently, only grids are supported. Will support images in the future.
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please note that: currently we use -Tg so it only works for grids. For images, we need to use -Ti.

To support images, we need to know whether the remote dataset is a grid or an image. It means we need to add the data kind information to the GMTRemoteDataset class.

kwdict = {"T": "g", "R": region} # region can be None
with Session() as lib:
with lib.virtualfile_out(kind="grid") as voutgrd:
lib.call_module(
module="read",
args=[fname, voutgrd, *build_arg_list(kwdict)],
)
fname = which(f"@{dataset_prefix}{resolution}{reg}", download="a")
grid = load_dataarray(fname, engine="netcdf4")
else:
grid = grdcut(f"@{dataset_prefix}{resolution}{reg}", region=region)
grid = lib.virtualfile_to_raster(outgrid=None, vfname=voutgrd)

# Full path to the grid if not tiled grids.
source = which(fname, download="a") if not resinfo.tiled else None
# Manually add source to xarray.DataArray encoding to make the GMT accessors work.
if source:
grid.encoding["source"] = source
Comment on lines +430 to +434
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, the GMT accessors require the grid.encoding["source"] to work. So, for non-tiled gris, we get the full path to the grid and set grid.encoding["source"]. For tiled grids, grid.encoding["source"] is undefined.


# Add some metadata to the grid
grid.name = dataset.name
Expand Down
5 changes: 2 additions & 3 deletions pygmt/tests/test_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,8 @@ def test_accessor_grid_source_file_not_exist():
# Registration and gtype are correct
assert grid.gmt.registration == 1
assert grid.gmt.gtype == 1
# The source grid file is defined but doesn't exist
assert grid.encoding["source"].endswith(".nc")
assert not Path(grid.encoding["source"]).exists()
# The source grid file is undefined.
assert grid.encoding.get("source") is None

# For a sliced grid, fallback to default registration and gtype,
# because the source grid file doesn't exist.
Expand Down