Skip to content

Commit

Permalink
Let required_z check work on xarray.Dataset inputs (#1523)
Browse files Browse the repository at this point in the history
Modify if-statement check in `data_kind` helper function
to use `len(data.data_vars)` to check the shape of
xarray.Dataset inputs. Added two regression tests in
test_clib and test_blockm to ensure this works on both the
low-level clib and high-level module APIs.

* Mention required_z parameter in docstring

Co-authored-by: Dongdong Tian <[email protected]>
  • Loading branch information
weiji14 and seisman authored Sep 20, 2021
1 parent c0a8dfa commit bafb8ab
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 6 deletions.
2 changes: 2 additions & 0 deletions pygmt/clib/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -1375,6 +1375,8 @@ def virtualfile_from_data(
extra_arrays : list of 1d arrays
Optional. A list of numpy arrays in addition to x, y and z. All
of these arrays must be of the same size as the x/y/z arrays.
required_z : bool
State whether the 'z' column is required.
Returns
-------
Expand Down
10 changes: 7 additions & 3 deletions pygmt/helpers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@ def data_kind(data, x=None, y=None, z=None, required_z=False):
x/y : 1d arrays or None
x and y columns as numpy arrays.
z : 1d array or None
z column as numpy array. To be used optionally when x and y
are given.
z column as numpy array. To be used optionally when x and y are given.
required_z : bool
State whether the 'z' column is required.
Returns
-------
Expand Down Expand Up @@ -80,7 +81,10 @@ def data_kind(data, x=None, y=None, z=None, required_z=False):
elif hasattr(data, "__geo_interface__"):
kind = "geojson"
elif data is not None:
if required_z and data.shape[1] < 3:
if required_z and (
getattr(data, "shape", (3, 3))[1] < 3 # np.array, pd.DataFrame
or len(getattr(data, "data_vars", (0, 1, 2))) < 3 # xr.Dataset
):
raise GMTInvalidInput("data must provide x, y, and z columns.")
kind = "matrix"
else:
Expand Down
7 changes: 5 additions & 2 deletions pygmt/tests/test_blockm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
"""
import os

import numpy as np
import numpy.testing as npt
import pandas as pd
import pytest
import xarray as xr
from pygmt import blockmean, blockmode
from pygmt.datasets import load_sample_bathymetry
from pygmt.exceptions import GMTInvalidInput
Expand All @@ -31,12 +33,13 @@ def test_blockmean_input_dataframe(dataframe):
npt.assert_allclose(output.iloc[0], [245.888877, 29.978707, -384.0])


def test_blockmean_input_table_matrix(dataframe):
@pytest.mark.parametrize("array_func", [np.array, xr.Dataset])
def test_blockmean_input_table_matrix(array_func, dataframe):
"""
Run blockmean using table input that is not a pandas.DataFrame but still a
matrix.
"""
table = dataframe.values
table = array_func(dataframe)
output = blockmean(table=table, spacing="5m", region=[245, 255, 20, 30])
assert isinstance(output, pd.DataFrame)
assert output.shape == (5849, 3)
Expand Down
31 changes: 30 additions & 1 deletion pygmt/tests/test_clib.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,36 @@ def test_virtual_file_bad_direction():
print("This should have failed")


def test_virtualfile_from_data_required_z_matrix():
@pytest.mark.parametrize(
"array_func,kind",
[(np.array, "matrix"), (pd.DataFrame, "vector"), (xr.Dataset, "vector")],
)
def test_virtualfile_from_data_required_z_matrix(array_func, kind):
"""
Test that function works when third z column in a matrix is needed and
provided.
"""
shape = (5, 3)
dataframe = pd.DataFrame(
data=np.arange(shape[0] * shape[1]).reshape(shape), columns=["x", "y", "z"]
)
data = array_func(dataframe)
with clib.Session() as lib:
with lib.virtualfile_from_data(data=data, required_z=True) as vfile:
with GMTTempFile() as outfile:
lib.call_module("info", f"{vfile} ->{outfile.name}")
output = outfile.read(keep_tabs=True)
bounds = "\t".join(
[
f"<{i.min():.0f}/{i.max():.0f}>"
for i in (dataframe.x, dataframe.y, dataframe.z)
]
)
expected = f"<{kind} memory>: N = {shape[0]}\t{bounds}\n"
assert output == expected


def test_virtualfile_from_data_required_z_matrix_missing():
"""
Test that function fails when third z column in a matrix is needed but not
provided.
Expand Down

0 comments on commit bafb8ab

Please sign in to comment.