diff --git a/pygmt/session_management.py b/pygmt/session_management.py index 4bb829835b9..750157679d4 100644 --- a/pygmt/session_management.py +++ b/pygmt/session_management.py @@ -1,7 +1,11 @@ """ Modern mode session management modules. """ +import os +import sys + from pygmt.clib import Session +from pygmt.helpers import unique_name def begin(): @@ -12,6 +16,10 @@ def begin(): Only meant to be used once for creating the global session. """ + # On Windows, need to set GMT_SESSION_NAME to a unique value + if sys.platform == "win32": + os.environ["GMT_SESSION_NAME"] = unique_name() + prefix = "pygmt-session" with Session() as lib: lib.call_module(module="begin", args=prefix) diff --git a/pygmt/src/meca.py b/pygmt/src/meca.py index dbe13d088af..98c8172717c 100644 --- a/pygmt/src/meca.py +++ b/pygmt/src/meca.py @@ -168,7 +168,7 @@ def convention_params(convention): ], "mt": ["mrr", "mtt", "mff", "mrt", "mrf", "mtf", "exponent"], "partial": ["strike1", "dip1", "strike2", "fault_type", "magnitude"], - "pricipal_axis": [ + "principal_axis": [ "t_value", "t_azimuth", "t_plunge", @@ -401,7 +401,7 @@ def meca( # noqa: PLR0912, PLR0913, PLR0915 # Convert spec to pandas.DataFrame unless it's a file if isinstance(spec, (dict, pd.DataFrame)): # spec is a dict or pd.DataFrame # determine convention from dict keys or pd.DataFrame column names - for conv in ["aki", "gcmt", "mt", "partial", "pricipal_axis"]: + for conv in ["aki", "gcmt", "mt", "partial", "principal_axis"]: if set(convention_params(conv)).issubset(set(spec.keys())): convention = conv break diff --git a/pygmt/tests/test_clib_virtualfiles.py b/pygmt/tests/test_clib_virtualfiles.py index 5c0e39d3467..669484a391d 100644 --- a/pygmt/tests/test_clib_virtualfiles.py +++ b/pygmt/tests/test_clib_virtualfiles.py @@ -34,6 +34,19 @@ def fixture_dtypes(): return "int8 int16 int32 int64 uint8 uint16 uint32 uint64 float32 float64".split() +@pytest.fixture(scope="module", name="dtypes_pandas") +def fixture_dtypes_pandas(dtypes): + """ + List of supported pandas dtypes. + """ + dtypes_pandas = dtypes.copy() + + if find_spec("pyarrow") is not None: + dtypes_pandas.extend([f"{dtype}[pyarrow]" for dtype in dtypes_pandas]) + + return tuple(dtypes_pandas) + + def test_virtual_file(dtypes): """ Test passing in data via a virtual file with a Dataset. @@ -248,11 +261,10 @@ def test_virtualfile_from_vectors_two_string_or_object_columns(dtype): assert output == expected -def test_virtualfile_from_vectors_transpose(): +def test_virtualfile_from_vectors_transpose(dtypes): """ Test transforming matrix columns to virtual file dataset. """ - dtypes = "float32 float64 int32 int64 uint32 uint64".split() shape = (7, 5) for dtype in dtypes: data = np.arange(shape[0] * shape[1], dtype=dtype).reshape(shape) @@ -315,16 +327,14 @@ def test_virtualfile_from_matrix_slice(dtypes): assert output == expected -def test_virtualfile_from_vectors_pandas(dtypes): +def test_virtualfile_from_vectors_pandas(dtypes_pandas): """ Pass vectors to a dataset using pandas.Series, checking both numpy and pyarrow dtypes. """ size = 13 - if find_spec("pyarrow") is not None: - dtypes.extend([f"{dtype}[pyarrow]" for dtype in dtypes]) - for dtype in dtypes: + for dtype in dtypes_pandas: data = pd.DataFrame( data={ "x": np.arange(size), diff --git a/pygmt/tests/test_session_management.py b/pygmt/tests/test_session_management.py index 079c2c4e02c..544c5f037de 100644 --- a/pygmt/tests/test_session_management.py +++ b/pygmt/tests/test_session_management.py @@ -1,7 +1,10 @@ """ Test the session management modules. """ +import multiprocessing as mp import os +from importlib import reload +from pathlib import Path import pytest from pygmt.clib import Session @@ -57,3 +60,29 @@ def test_gmt_compat_6_is_applied(capsys): # Make sure no global "gmt.conf" in the current directory assert not os.path.exists("gmt.conf") begin() # Restart the global session + + +def _gmt_func_wrapper(figname): + """ + A wrapper for running PyGMT scripts with multiprocessing. + + Currently, we have to import pygmt and reload it in each process. Workaround from + https://github.com/GenericMappingTools/pygmt/issues/217#issuecomment-754774875. + """ + import pygmt + + reload(pygmt) + fig = pygmt.Figure() + fig.basemap(region=[10, 70, -3, 8], projection="X8c/6c", frame="afg") + fig.savefig(figname) + + +def test_session_multiprocessing(): + """ + Make sure that multiprocessing is supported if pygmt is re-imported. + """ + prefix = "test_session_multiprocessing" + with mp.Pool(2) as p: + p.map(_gmt_func_wrapper, [f"{prefix}-1.png", f"{prefix}-2.png"]) + Path(f"{prefix}-1.png").unlink() + Path(f"{prefix}-2.png").unlink()