Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change behaviour of dtype on equivalent sources #516

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Prev Previous commit
Next Next commit
Update tests for dtype in eqs
  • Loading branch information
santisoler committed Jun 19, 2024
commit ebc2eb88a65102c205a324006555c71c000fb04c
29 changes: 14 additions & 15 deletions harmonica/tests/test_eq_sources_cartesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,28 +343,29 @@ def test_dtype(
dtype,
):
"""
Test dtype argument on EquivalentSources
Test if predictions have the dtype passed as argument.
"""
# Define the points argument for EquivalentSources
points = None
if custom_points:
points = vd.grid_coordinates(region, spacing=300, extra_coords=-2e3)
# Define the points argument for EquivalentSources.fit()
if weights_none:
weights = None
points = (
vd.grid_coordinates(region, spacing=300, extra_coords=-2e3)
if custom_points
else None
)
# Define the weights argument for EquivalentSources.fit()
weights = weights if not weights_none else None
# Initialize and fit the equivalent sources
eqs = EquivalentSources(
damping=damping, points=points, block_size=block_size, dtype=dtype
)
eqs.fit(coordinates, data, weights)
# Make some predictions
# Ensure predictions have the expected dtype
prediction = eqs.predict(coordinates)
# Check data type of created objects
for coord in eqs.points_:
assert coord.dtype == np.dtype(dtype)
assert prediction.dtype == np.dtype(dtype)
# Check the data type of the source coefficients
# assert eqs.coefs_.dtype == np.dtype(dtype)
# Locations of sources should be the same dtype as the coordinates
for coord in eqs.points_:
assert coord.dtype == coordinates[0].dtype
# Sources' coefficients should be the same dtype as the coordinates
assert eqs.coefs_.dtype == coordinates[0].dtype


@pytest.mark.use_numba
Expand All @@ -381,8 +382,6 @@ def test_jacobian_dtype(region, dtype):
points = tuple(
p.ravel() for p in vd.grid_coordinates(region, shape=(6, 6), extra_coords=-2e3)
)
# Ravel the coordinates
coordinates = tuple(c.ravel() for c in coordinates)
# Initialize and fit the equivalent sources
eqs = EquivalentSources(points=points, dtype=dtype)
# Build jacobian matrix
Expand Down