Skip to content

Commit

Permalink
add check that weights have changed during training
Browse files Browse the repository at this point in the history
  • Loading branch information
LarsKue committed Jun 25, 2024
1 parent 5bab62e commit 49bbe79
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 7 deletions.
7 changes: 0 additions & 7 deletions tests/test_two_moons/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,6 @@ def batch_size():
return 128


@pytest.fixture()
def inference_network():
from bayesflow.networks import CouplingFlow

return CouplingFlow()


@pytest.fixture()
def random_samples(batch_size, simulator):
return simulator.sample((batch_size,))
Expand Down
6 changes: 6 additions & 0 deletions tests/test_two_moons/test_two_moons.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import keras
import pytest

Expand All @@ -19,12 +20,17 @@ def test_fit(approximator, train_dataset, validation_dataset, batch_size):

approximator.build_from_data(train_dataset[0])

untrained_weights = copy.deepcopy(approximator.weights)
untrained_metrics = approximator.evaluate(validation_dataset, return_dict=True)

approximator.fit(train_dataset, epochs=20)

trained_weights = approximator.weights
trained_metrics = approximator.evaluate(validation_dataset, return_dict=True)

# check weights have changed during training
assert any([keras.ops.any(~keras.ops.isclose(u, t)) for u, t in zip(untrained_weights, trained_weights)])

assert isinstance(untrained_metrics, dict)
assert isinstance(trained_metrics, dict)

Expand Down

0 comments on commit 49bbe79

Please sign in to comment.