diff --git a/doc/api/index.rst b/doc/api/index.rst index e08f1c0aacb..11d5317db2f 100644 --- a/doc/api/index.rst +++ b/doc/api/index.rst @@ -89,6 +89,7 @@ Operations on grids: .. autosummary:: :toctree: generated + GrdMathCalc grdclip grdcut grdfill diff --git a/pygmt/__init__.py b/pygmt/__init__.py index ec4934baddf..6b966492926 100644 --- a/pygmt/__init__.py +++ b/pygmt/__init__.py @@ -30,6 +30,7 @@ from pygmt.session_management import begin as _begin from pygmt.session_management import end as _end from pygmt.src import ( + GrdMathCalc, blockmean, blockmedian, blockmode, diff --git a/pygmt/src/__init__.py b/pygmt/src/__init__.py index a7f1e1c74ef..1b181e921a6 100644 --- a/pygmt/src/__init__.py +++ b/pygmt/src/__init__.py @@ -19,6 +19,7 @@ from pygmt.src.grdimage import grdimage from pygmt.src.grdinfo import grdinfo from pygmt.src.grdlandmask import grdlandmask +from pygmt.src.grdmath import GrdMathCalc from pygmt.src.grdproject import grdproject from pygmt.src.grdsample import grdsample from pygmt.src.grdtrack import grdtrack diff --git a/pygmt/src/grdmath.py b/pygmt/src/grdmath.py new file mode 100644 index 00000000000..665bb736a00 --- /dev/null +++ b/pygmt/src/grdmath.py @@ -0,0 +1,134 @@ +""" +grdmath - Raster calculator for grids (element by element) +""" + +from pygmt.clib import Session +from pygmt.helpers import ( + GMTTempFile, + build_arg_string, + dummy_context, + fmt_docstring, + kwargs_to_strings, + use_alias, +) +from pygmt.io import load_dataarray + + +class GrdMathCalc: + """ + Raster calculator for grids (element by element). + """ + + def __init__(self, arg_str=None): + self.arg_str = "" if arg_str is None else arg_str + + def __repr__(self): + return f"gmt grdmath {self.arg_str}" + + def compute(self): + """ + Perform the grdmath computation and returns an xarray.DataArray object. + """ + with Session() as lib: + with GMTTempFile(suffix=".nc") as tmpfile: + outgrid = tmpfile.name + # print(f"Executing gmt grdmath {self.arg_str}") + lib.call_module("grdmath", f"{self.arg_str} = {outgrid}") + return load_dataarray(outgrid) + + @classmethod + @fmt_docstring + @use_alias(R="region", V="verbose") + @kwargs_to_strings(R="sequence") + def grdmath(cls, operator, ingrid=None, outgrid=None, old_arg_str=None, **kwargs): + """ + Raster calculator for grids (element-wise operations). + + Full option list at :gmt-docs:`grdmath.html` + + {aliases} + + Parameters + ---------- + operator : str + The mathematical operator to use. Full list of available + operations is at :gmt-docs:`grdmath.html#operators`. + + ingrid : str or float + + outgrid : str or bool or None + The name of a 2-D grid file that will hold the final result. Set to + True to output to an :class:`xarray.DataArray`. Default is None, + which will save the computation graph, to be computed or appended + to with more operations later. + + old_arg_str : str + + Returns + ------- + ret : pygmt.GrdMathCalc or xarray.DataArray or None + Return type depends on whether the ``outgrid`` parameter is set: + + - :class:`pygmt.GrdMathCalc` if ``outgrid`` is None (computational + graph is created, and more operations can be appended) + - :class:`xarray.DataArray` if ``outgrid`` is True + - None if ``outgrid`` is a str (grid output will be stored in file + set by ``outgrid``) + """ + old_arg_str = old_arg_str or "" # Convert None to empty string + + with Session() as lib: + if isinstance(ingrid, GrdMathCalc): + file_context = dummy_context(ingrid.arg_str) + else: + file_context = lib.virtualfile_from_data( + check_kind="raster", data=ingrid + ) + + with file_context as infile: + arg_str = " ".join( + [old_arg_str, infile, build_arg_string(kwargs)] + ).strip() + arg_str += f" {operator}" + + # If no output is requested, just build computational graph + if outgrid is None: + result = cls(arg_str=arg_str) + + # If output is requested, compute output grid + elif outgrid is not None: + with GMTTempFile(suffix=".nc") as tmpfile: + if outgrid is True: + outgrid = tmpfile.name + arg_str += f" = {outgrid}" + # print(f"Executing gmt grdmath {arg_str}") + lib.call_module("grdmath", arg_str) + result = ( + load_dataarray(outgrid) if outgrid == tmpfile.name else None + ) + + return result + + def sqrt(self, ingrid, outgrid=None, **kwargs): + """ + sqrt (A). 1 input, 1 output. + """ + return self.grdmath(operator="SQRT", ingrid=ingrid, outgrid=outgrid, **kwargs) + + def std(self, ingrid, outgrid=None, **kwargs): + """ + Standard deviation of A. 1 input, 1 output. + """ + return self.grdmath(operator="STD", ingrid=ingrid, outgrid=outgrid, **kwargs) + + def multiply(self, ingrid, outgrid=None, **kwargs): + """ + A * B. 2 inputs, 1 output + """ + return self.grdmath( + operator="MUL", + ingrid=ingrid, + outgrid=outgrid, + old_arg_str=self.arg_str, + **kwargs, + ) diff --git a/pygmt/tests/test_grdmath.py b/pygmt/tests/test_grdmath.py new file mode 100644 index 00000000000..e62da11346e --- /dev/null +++ b/pygmt/tests/test_grdmath.py @@ -0,0 +1,71 @@ +""" +Tests for grdmath. +""" +import numpy as np +import pytest +import xarray as xr +import xarray.testing as xrt +from pygmt import GrdMathCalc +from pygmt.datasets import load_earth_relief + + +@pytest.fixture(scope="module", name="grid") +def fixture_grid(): + """ + Load the grid data from the sample earth_relief file. + """ + return load_earth_relief(resolution="01d", region=[0, 3, 6, 9]) + + +def test_grdmath_sqrt(grid): + """ + Test grdmath SQRT operation. + """ + grdcalc = GrdMathCalc() + actual = grdcalc.sqrt(ingrid="@earth_relief_01d", outgrid=True, region=[0, 3, 6, 9]) + expected = np.sqrt(grid) + xrt.assert_allclose(actual, expected) + + +def test_grdmath_directly(grid): + """ + Test grdmath LOG operation directly using GrdMathCalc.grdmath classmethod. + """ + actual = GrdMathCalc.grdmath( + operator="LOG", ingrid="@earth_relief_01d", outgrid=True, region=[0, 3, 6, 9] + ) + expected = np.log(grid) + xrt.assert_allclose(actual, expected) + + +def test_grdmath_chained_operations(): + """ + Test grdmath chaining several intermediate computations together before + producing final xarray.DataArray grid by calling `.compute()`. + """ + grdcalc = GrdMathCalc() + assert grdcalc.arg_str == "" + + grid1 = grdcalc.sqrt(ingrid="@earth_relief_01d_p") + assert grid1.arg_str == "@earth_relief_01d_p SQRT" + + grid2 = grdcalc.std(ingrid="@earth_relief_01d_g") + assert grid2.arg_str == "@earth_relief_01d_g STD" + + grid3 = grid1.multiply(ingrid=grid2, region=[0, 3, 6, 9]) + assert ( + grid3.arg_str == "@earth_relief_01d_p SQRT @earth_relief_01d_g STD " + "-R0/3/6/9 MUL" + ) + + actual_grid = grid3.compute() + expected_grid = xr.DataArray( + data=[ + [5297.7134, 3454.636, 887.0686], + [9147.278, 6667.7827, 5760.257], + [8469.841, 8661.23, 7892.7505], + ], + coords=dict(lon=[0.5, 1.5, 2.5], lat=[6.5, 7.5, 8.5]), + dims=["lat", "lon"], + ) + xrt.assert_allclose(actual_grid, expected_grid)