-
Notifications
You must be signed in to change notification settings - Fork 213
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Initial commit for wrapping the grdmath function for #916 which is a "Raster calculator for grids (element by element)". Original GMT `grdmath` documentation is at https://docs.generic-mapping-tools.org/6.2/grdmath.html. Implementation works by building and storing a computational graph of the grid operations in a GrdMathCalc class object. An output NetCDF grid or xarray.DataArray is produced upon calling `.compute()` or by setting `outgrid=True`.
- Loading branch information
Showing
5 changed files
with
208 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -89,6 +89,7 @@ Operations on grids: | |
.. autosummary:: | ||
:toctree: generated | ||
|
||
GrdMathCalc | ||
grdclip | ||
grdcut | ||
grdfill | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
""" | ||
grdmath - Raster calculator for grids (element by element) | ||
""" | ||
|
||
from pygmt.clib import Session | ||
from pygmt.helpers import ( | ||
GMTTempFile, | ||
build_arg_string, | ||
dummy_context, | ||
fmt_docstring, | ||
kwargs_to_strings, | ||
use_alias, | ||
) | ||
from pygmt.io import load_dataarray | ||
|
||
|
||
class GrdMathCalc: | ||
""" | ||
Raster calculator for grids (element by element). | ||
""" | ||
|
||
def __init__(self, arg_str=None): | ||
self.arg_str = "" if arg_str is None else arg_str | ||
|
||
def __repr__(self): | ||
return f"gmt grdmath {self.arg_str}" | ||
|
||
def compute(self): | ||
""" | ||
Perform the grdmath computation and returns an xarray.DataArray object. | ||
""" | ||
with Session() as lib: | ||
with GMTTempFile(suffix=".nc") as tmpfile: | ||
outgrid = tmpfile.name | ||
# print(f"Executing gmt grdmath {self.arg_str}") | ||
lib.call_module("grdmath", f"{self.arg_str} = {outgrid}") | ||
return load_dataarray(outgrid) | ||
|
||
@classmethod | ||
@fmt_docstring | ||
@use_alias(R="region", V="verbose") | ||
@kwargs_to_strings(R="sequence") | ||
def grdmath(cls, operator, ingrid=None, outgrid=None, old_arg_str=None, **kwargs): | ||
""" | ||
Raster calculator for grids (element-wise operations). | ||
Full option list at :gmt-docs:`grdmath.html` | ||
{aliases} | ||
Parameters | ||
---------- | ||
operator : str | ||
The mathematical operator to use. Full list of available | ||
operations is at :gmt-docs:`grdmath.html#operators`. | ||
ingrid : str or float | ||
outgrid : str or bool or None | ||
The name of a 2-D grid file that will hold the final result. Set to | ||
True to output to an :class:`xarray.DataArray`. Default is None, | ||
which will save the computation graph, to be computed or appended | ||
to with more operations later. | ||
old_arg_str : str | ||
Returns | ||
------- | ||
ret : pygmt.GrdMathCalc or xarray.DataArray or None | ||
Return type depends on whether the ``outgrid`` parameter is set: | ||
- :class:`pygmt.GrdMathCalc` if ``outgrid`` is None (computational | ||
graph is created, and more operations can be appended) | ||
- :class:`xarray.DataArray` if ``outgrid`` is True | ||
- None if ``outgrid`` is a str (grid output will be stored in file | ||
set by ``outgrid``) | ||
""" | ||
old_arg_str = old_arg_str or "" # Convert None to empty string | ||
|
||
with Session() as lib: | ||
if isinstance(ingrid, GrdMathCalc): | ||
file_context = dummy_context(ingrid.arg_str) | ||
else: | ||
file_context = lib.virtualfile_from_data( | ||
check_kind="raster", data=ingrid | ||
) | ||
|
||
with file_context as infile: | ||
arg_str = " ".join( | ||
[old_arg_str, infile, build_arg_string(kwargs)] | ||
).strip() | ||
arg_str += f" {operator}" | ||
|
||
# If no output is requested, just build computational graph | ||
if outgrid is None: | ||
result = cls(arg_str=arg_str) | ||
|
||
# If output is requested, compute output grid | ||
elif outgrid is not None: | ||
with GMTTempFile(suffix=".nc") as tmpfile: | ||
if outgrid is True: | ||
outgrid = tmpfile.name | ||
arg_str += f" = {outgrid}" | ||
# print(f"Executing gmt grdmath {arg_str}") | ||
lib.call_module("grdmath", arg_str) | ||
result = ( | ||
load_dataarray(outgrid) if outgrid == tmpfile.name else None | ||
) | ||
|
||
return result | ||
|
||
def sqrt(self, ingrid, outgrid=None, **kwargs): | ||
""" | ||
sqrt (A). 1 input, 1 output. | ||
""" | ||
return self.grdmath(operator="SQRT", ingrid=ingrid, outgrid=outgrid, **kwargs) | ||
|
||
def std(self, ingrid, outgrid=None, **kwargs): | ||
""" | ||
Standard deviation of A. 1 input, 1 output. | ||
""" | ||
return self.grdmath(operator="STD", ingrid=ingrid, outgrid=outgrid, **kwargs) | ||
|
||
def multiply(self, ingrid, outgrid=None, **kwargs): | ||
""" | ||
A * B. 2 inputs, 1 output | ||
""" | ||
return self.grdmath( | ||
operator="MUL", | ||
ingrid=ingrid, | ||
outgrid=outgrid, | ||
old_arg_str=self.arg_str, | ||
**kwargs, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
""" | ||
Tests for grdmath. | ||
""" | ||
import numpy as np | ||
import pytest | ||
import xarray as xr | ||
import xarray.testing as xrt | ||
from pygmt import GrdMathCalc | ||
from pygmt.datasets import load_earth_relief | ||
|
||
|
||
@pytest.fixture(scope="module", name="grid") | ||
def fixture_grid(): | ||
""" | ||
Load the grid data from the sample earth_relief file. | ||
""" | ||
return load_earth_relief(resolution="01d", region=[0, 3, 6, 9]) | ||
|
||
|
||
def test_grdmath_sqrt(grid): | ||
""" | ||
Test grdmath SQRT operation. | ||
""" | ||
grdcalc = GrdMathCalc() | ||
actual = grdcalc.sqrt(ingrid="@earth_relief_01d", outgrid=True, region=[0, 3, 6, 9]) | ||
expected = np.sqrt(grid) | ||
xrt.assert_allclose(actual, expected) | ||
|
||
|
||
def test_grdmath_directly(grid): | ||
""" | ||
Test grdmath LOG operation directly using GrdMathCalc.grdmath classmethod. | ||
""" | ||
actual = GrdMathCalc.grdmath( | ||
operator="LOG", ingrid="@earth_relief_01d", outgrid=True, region=[0, 3, 6, 9] | ||
) | ||
expected = np.log(grid) | ||
xrt.assert_allclose(actual, expected) | ||
|
||
|
||
def test_grdmath_chained_operations(): | ||
""" | ||
Test grdmath chaining several intermediate computations together before | ||
producing final xarray.DataArray grid by calling `.compute()`. | ||
""" | ||
grdcalc = GrdMathCalc() | ||
assert grdcalc.arg_str == "" | ||
|
||
grid1 = grdcalc.sqrt(ingrid="@earth_relief_01d_p") | ||
assert grid1.arg_str == "@earth_relief_01d_p SQRT" | ||
|
||
grid2 = grdcalc.std(ingrid="@earth_relief_01d_g") | ||
assert grid2.arg_str == "@earth_relief_01d_g STD" | ||
|
||
grid3 = grid1.multiply(ingrid=grid2, region=[0, 3, 6, 9]) | ||
assert ( | ||
grid3.arg_str == "@earth_relief_01d_p SQRT @earth_relief_01d_g STD " | ||
"-R0/3/6/9 MUL" | ||
) | ||
|
||
actual_grid = grid3.compute() | ||
expected_grid = xr.DataArray( | ||
data=[ | ||
[5297.7134, 3454.636, 887.0686], | ||
[9147.278, 6667.7827, 5760.257], | ||
[8469.841, 8661.23, 7892.7505], | ||
], | ||
coords=dict(lon=[0.5, 1.5, 2.5], lat=[6.5, 7.5, 8.5]), | ||
dims=["lat", "lon"], | ||
) | ||
xrt.assert_allclose(actual_grid, expected_grid) |