Skip to content

Commit

Permalink
Improve performance by avoiding loading the GMT library repeatedly (#…
Browse files Browse the repository at this point in the history
…2930)

Co-authored-by: Wei Ji <[email protected]>
  • Loading branch information
seisman and weiji14 committed Jan 2, 2024
1 parent 88ab1ca commit b561e9d
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 1 deletion.
5 changes: 4 additions & 1 deletion pygmt/clib/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@
np.datetime64: "GMT_DATETIME",
}

# Load the GMT library outside the Session class to avoid repeated loading.
_libgmt = load_libgmt()


class Session:
"""
Expand Down Expand Up @@ -308,7 +311,7 @@ def get_libgmt_func(self, name, argtypes=None, restype=None):
<class 'ctypes.CDLL.__init__.<locals>._FuncPtr'>
"""
if not hasattr(self, "_libgmt"):
self._libgmt = load_libgmt()
self._libgmt = _libgmt
function = getattr(self._libgmt, name)
if argtypes is not None:
function.argtypes = argtypes
Expand Down
39 changes: 39 additions & 0 deletions pygmt/tests/test_clib_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import pytest
from pygmt.clib.loading import check_libgmt, clib_full_names, clib_names, load_libgmt
from pygmt.clib.session import Session
from pygmt.exceptions import GMTCLibError, GMTCLibNotFoundError, GMTOSError


Expand Down Expand Up @@ -207,6 +208,44 @@ def test_brokenlib_brokenlib_workinglib(self):
assert check_libgmt(load_libgmt(lib_fullnames=lib_fullnames)) is None


class TestLibgmtCount:
"""
Test that the GMT library is not repeatedly loaded in every session.
"""

loaded_libgmt = load_libgmt() # Load the GMT library and reuse it when necessary
counter = 0 # Global counter for how many times ctypes.CDLL is called

def _mock_ctypes_cdll_return(self, libname): # noqa: ARG002
"""
Mock ctypes.CDLL to count how many times the function is called.
If ctypes.CDLL is called, the counter increases by one.
"""
self.counter += 1 # Increase the counter
return self.loaded_libgmt

def test_libgmt_load_counter(self, monkeypatch):
"""
Make sure that the GMT library is not loaded in every session.
"""
# Monkeypatch the ctypes.CDLL function
monkeypatch.setattr(ctypes, "CDLL", self._mock_ctypes_cdll_return)

# Create two sessions and check the global counter
with Session() as lib:
_ = lib
with Session() as lib:
_ = lib
assert self.counter == 0 # ctypes.CDLL is not called after two sessions.

# Explicitly calling load_libgmt to make sure the mock function is correct
load_libgmt()
assert self.counter == 1
load_libgmt()
assert self.counter == 2


###############################################################################
# Test clib_full_names
@pytest.fixture(scope="module", name="gmt_lib_names")
Expand Down

0 comments on commit b561e9d

Please sign in to comment.