Skip to content

Commit

Permalink
Merge branch 'main' into clib/load-libgmt
Browse files Browse the repository at this point in the history
  • Loading branch information
seisman committed Jan 2, 2024
2 parents bf6083f + 88ab1ca commit 11cb76b
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 8 deletions.
8 changes: 8 additions & 0 deletions pygmt/session_management.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
"""
Modern mode session management modules.
"""
import os
import sys

from pygmt.clib import Session
from pygmt.helpers import unique_name


def begin():
Expand All @@ -12,6 +16,10 @@ def begin():
Only meant to be used once for creating the global session.
"""
# On Windows, need to set GMT_SESSION_NAME to a unique value
if sys.platform == "win32":
os.environ["GMT_SESSION_NAME"] = unique_name()

Check warning on line 21 in pygmt/session_management.py

View check run for this annotation

Codecov / codecov/patch

pygmt/session_management.py#L21

Added line #L21 was not covered by tests

prefix = "pygmt-session"
with Session() as lib:
lib.call_module(module="begin", args=prefix)
Expand Down
4 changes: 2 additions & 2 deletions pygmt/src/meca.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def convention_params(convention):
],
"mt": ["mrr", "mtt", "mff", "mrt", "mrf", "mtf", "exponent"],
"partial": ["strike1", "dip1", "strike2", "fault_type", "magnitude"],
"pricipal_axis": [
"principal_axis": [
"t_value",
"t_azimuth",
"t_plunge",
Expand Down Expand Up @@ -401,7 +401,7 @@ def meca( # noqa: PLR0912, PLR0913, PLR0915
# Convert spec to pandas.DataFrame unless it's a file
if isinstance(spec, (dict, pd.DataFrame)): # spec is a dict or pd.DataFrame
# determine convention from dict keys or pd.DataFrame column names
for conv in ["aki", "gcmt", "mt", "partial", "pricipal_axis"]:
for conv in ["aki", "gcmt", "mt", "partial", "principal_axis"]:
if set(convention_params(conv)).issubset(set(spec.keys())):
convention = conv
break
Expand Down
22 changes: 16 additions & 6 deletions pygmt/tests/test_clib_virtualfiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,19 @@ def fixture_dtypes():
return "int8 int16 int32 int64 uint8 uint16 uint32 uint64 float32 float64".split()


@pytest.fixture(scope="module", name="dtypes_pandas")
def fixture_dtypes_pandas(dtypes):
    """
    List of supported pandas dtypes.

    Returns the base dtypes and, when pyarrow is importable, the matching
    ``<dtype>[pyarrow]`` variants as well.
    """
    supported = list(dtypes)

    # Only advertise the pyarrow-backed dtypes if the package is available.
    if find_spec("pyarrow") is not None:
        supported += [f"{dtype}[pyarrow]" for dtype in dtypes]

    return tuple(supported)


def test_virtual_file(dtypes):
"""
Test passing in data via a virtual file with a Dataset.
Expand Down Expand Up @@ -248,11 +261,10 @@ def test_virtualfile_from_vectors_two_string_or_object_columns(dtype):
assert output == expected


def test_virtualfile_from_vectors_transpose():
def test_virtualfile_from_vectors_transpose(dtypes):
"""
Test transforming matrix columns to virtual file dataset.
"""
dtypes = "float32 float64 int32 int64 uint32 uint64".split()
shape = (7, 5)
for dtype in dtypes:
data = np.arange(shape[0] * shape[1], dtype=dtype).reshape(shape)
Expand Down Expand Up @@ -315,16 +327,14 @@ def test_virtualfile_from_matrix_slice(dtypes):
assert output == expected


def test_virtualfile_from_vectors_pandas(dtypes):
def test_virtualfile_from_vectors_pandas(dtypes_pandas):
"""
Pass vectors to a dataset using pandas.Series, checking both numpy and pyarrow
dtypes.
"""
size = 13
if find_spec("pyarrow") is not None:
dtypes.extend([f"{dtype}[pyarrow]" for dtype in dtypes])

for dtype in dtypes:
for dtype in dtypes_pandas:
data = pd.DataFrame(
data={
"x": np.arange(size),
Expand Down
29 changes: 29 additions & 0 deletions pygmt/tests/test_session_management.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
"""
Test the session management modules.
"""
import multiprocessing as mp
import os
from importlib import reload
from pathlib import Path

import pytest
from pygmt.clib import Session
Expand Down Expand Up @@ -57,3 +60,29 @@ def test_gmt_compat_6_is_applied(capsys):
# Make sure no global "gmt.conf" in the current directory
assert not os.path.exists("gmt.conf")
begin() # Restart the global session


def _gmt_func_wrapper(figname):
    """
    A wrapper for running PyGMT scripts with multiprocessing.

    Currently, we have to import pygmt and reload it in each process. Workaround from
    https://github.com/GenericMappingTools/pygmt/issues/217#issuecomment-754774875.
    """
    import pygmt

    # Reloading forces a fresh GMT session in this child process.
    reload(pygmt)

    figure = pygmt.Figure()
    figure.basemap(region=[10, 70, -3, 8], projection="X8c/6c", frame="afg")
    figure.savefig(figname)


def test_session_multiprocessing():
    """
    Make sure that multiprocessing is supported if pygmt is re-imported.
    """
    prefix = "test_session_multiprocessing"
    fignames = [f"{prefix}-{number}.png" for number in (1, 2)]

    # Render one figure per worker process; each worker re-imports pygmt.
    with mp.Pool(2) as pool:
        pool.map(_gmt_func_wrapper, fignames)

    # Clean up the generated figures.
    for figname in fignames:
        Path(figname).unlink()

0 comments on commit 11cb76b

Please sign in to comment.