Skip to content

Commit

Permalink
pyarrow: Check compatibility of pyarrow-backed pandas objects with nu…
Browse files Browse the repository at this point in the history
…meric dtypes (#2774)

* Ensure that pyarrow backed pandas.Series can be read

Install pyarrow as an optional dependency, and check
that pandas.Series objects backed by pyarrow dtypes
(e.g. 'uint8[pyarrow]') can be read by virtualfile_from_vectors.

* Ensure that pygmt.info can work with pyarrow int64/float64 dtypes

Check that pandas.Series and pandas.DataFrame objects
backed by pyarrow dtypes (e.g. 'int64[pyarrow]' and
'float64[pyarrow]') can be read by pygmt.info.

* Add xfail test for test_geopandas_plot_int_dtypes casting to pyarrow int

Geopandas doesn't support casting to pyarrow dtypes
like 'int32[pyarrow]' and 'int64[pyarrow]' yet, but adding an
xfail test so that we don't forget to test in the future.

* Clarify reason for test_geopandas_plot_int_dtypes xfail

Actually, casting to pyarrow integer dtypes work, but
writing to the temporary OGR_GMT file doesn't.

* Add optional pyarrow dependency to ci_test_dev and ci_tests_legacy

Ensure that previous and future versions of GMT are compatible with PyArrow too.

* Add note about support of PyArrow dtypes to doc/install.rst

Mention that PyGMT does have some initial support of Pandas objects
backed by PyArrow-dtype arrays, but only uint/int/float dtypes for now.

* Use importlib.util.find_spec instead of try-except block

Cleaner way to check if pyarrow is installed or not.

---------

Co-authored-by: Dongdong Tian <[email protected]>
  • Loading branch information
weiji14 and seisman committed Dec 16, 2023
1 parent 66d4d5c commit 25914d8
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 10 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ jobs:
optional-packages: ''
- python-version: '3.12'
numpy-version: '1.26'
optional-packages: ' contextily geopandas ipython rioxarray sphinx-gallery'
optional-packages: ' contextily geopandas ipython pyarrow rioxarray sphinx-gallery'

timeout-minutes: 30
defaults:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/ci_tests_dev.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ jobs:
python -m pip install --pre --prefer-binary \
--extra-index https://pypi.anaconda.org/scientific-python-nightly-wheels/simple \
numpy pandas xarray netCDF4 packaging \
build contextily dvc geopandas ipython rioxarray \
build contextily dvc geopandas ipython pyarrow rioxarray \
'pytest>=6.0' pytest-cov pytest-doctestplus pytest-mpl \
sphinx-gallery
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/ci_tests_legacy.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ jobs:
contextily
geopandas
ipython
pyarrow
rioxarray
sphinx-gallery
build
Expand Down
9 changes: 9 additions & 0 deletions doc/install.rst
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,15 @@ The following are optional dependencies:
* `GeoPandas <https://geopandas.org>`__: For using and plotting GeoDataFrame objects.
* `RioXarray <https://corteva.github.io/rioxarray>`__: For saving multi-band rasters to GeoTIFFs.

.. note::

If you have `PyArrow <https://arrow.apache.org/docs/python/index.html>`__
installed, PyGMT does have some initial support for ``pandas.Series`` and
``pandas.DataFrame`` objects with Apache Arrow-backed arrays. Specifically,
only uint/int/float dtypes are supported for now. Support for datetime and
string Arrow dtypes are still working in progress. For more details, see
`issue #2800 <https://github.com/GenericMappingTools/pygmt/issues/2800>`__.

Installing GMT and other dependencies
-------------------------------------

Expand Down
16 changes: 11 additions & 5 deletions pygmt/tests/test_clib_virtualfiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Test the C API functions related to virtual files.
"""
import os
from importlib.util import find_spec
from itertools import product

import numpy as np
Expand Down Expand Up @@ -321,16 +322,21 @@ def test_virtualfile_from_matrix_slice(dtypes):

def test_virtualfile_from_vectors_pandas(dtypes):
"""
Pass vectors to a dataset using pandas Series.
Pass vectors to a dataset using pandas.Series, checking both numpy and
pyarrow dtypes.
"""
size = 13
if find_spec("pyarrow") is not None:
dtypes.extend([f"{dtype}[pyarrow]" for dtype in dtypes])

for dtype in dtypes:
data = pd.DataFrame(
data={
"x": np.arange(size, dtype=dtype),
"y": np.arange(size, size * 2, 1, dtype=dtype),
"z": np.arange(size * 2, size * 3, 1, dtype=dtype),
}
"x": np.arange(size),
"y": np.arange(size, size * 2, 1),
"z": np.arange(size * 2, size * 3, 1),
},
dtype=dtype,
)
with clib.Session() as lib:
with lib.virtualfile_from_vectors(data.x, data.y, data.z) as vfile:
Expand Down
19 changes: 19 additions & 0 deletions pygmt/tests/test_geopandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pandas as pd
import pytest
from pygmt import Figure, info, makecpt, which
from pygmt.helpers.testing import skip_if_no

gpd = pytest.importorskip("geopandas")
shapely = pytest.importorskip("shapely")
Expand Down Expand Up @@ -161,6 +162,24 @@ def test_geopandas_plot3d_non_default_circle():
"int64",
pd.Int32Dtype(),
pd.Int64Dtype(),
pytest.param(
"int32[pyarrow]",
marks=[
skip_if_no(package="pyarrow"),
pytest.mark.xfail(
reason="geopandas doesn't support writing columns with pyarrow dtypes to OGR_GMT yet."
),
],
),
pytest.param(
"int64[pyarrow]",
marks=[
skip_if_no(package="pyarrow"),
pytest.mark.xfail(
reason="geopandas doesn't support writing columns with pyarrow dtypes to OGR_GMT yet."
),
],
),
],
)
@pytest.mark.mpl_image_compare(filename="test_geopandas_plot_int_dtypes.png")
Expand Down
18 changes: 15 additions & 3 deletions pygmt/tests/test_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import xarray as xr
from pygmt import info
from pygmt.exceptions import GMTInvalidInput
from pygmt.helpers.testing import skip_if_no

TEST_DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
POINTS_DATA = os.path.join(TEST_DATA_DIR, "points.txt")
Expand Down Expand Up @@ -74,16 +75,27 @@ def test_info_2d_list():
assert output == expected_output


def test_info_series():
@pytest.mark.parametrize(
"dtype",
["int64", pytest.param("int64[pyarrow]", marks=skip_if_no(package="pyarrow"))],
)
def test_info_series(dtype):
"""
Make sure info works on a pandas.Series input.
"""
output = info(pd.Series(data=[0, 4, 2, 8, 6]))
output = info(pd.Series(data=[0, 4, 2, 8, 6], dtype=dtype))
expected_output = "<vector memory>: N = 5 <0/8>\n"
assert output == expected_output


def test_info_dataframe():
@pytest.mark.parametrize(
"dtype",
[
"float64",
pytest.param("float64[pyarrow]", marks=skip_if_no(package="pyarrow")),
],
)
def test_info_dataframe(dtype):
"""
Make sure info works on pandas.DataFrame inputs.
"""
Expand Down

0 comments on commit 25914d8

Please sign in to comment.