Skip to content

Commit

Permalink
Only do numpy seeding if numpy seed is not fixed, fix regression intr…
Browse files Browse the repository at this point in the history
…oduced by 9f9e12f
  • Loading branch information
samuelbroscheit committed May 16, 2020
1 parent dffcb9e commit bdb078f
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions kge/job/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
SLOTS = [0, 1, 2]
S, P, O = SLOTS


def _worker_init_fn(worker_num):
# ensure that NumPy uses different seeds at each worker
np.random.seed()
Expand Down Expand Up @@ -595,7 +596,9 @@ def _prepare(self):
shuffle=True,
batch_size=self.batch_size,
num_workers=self.config.get("train.num_workers"),
worker_init_fn=_worker_init_fn,
worker_init_fn=_worker_init_fn
if self.config.get("random_seed.numpy") == -1
else None,
pin_memory=self.config.get("train.pin_memory"),
)

Expand Down Expand Up @@ -803,7 +806,9 @@ def _prepare(self):
shuffle=True,
batch_size=self.batch_size,
num_workers=self.config.get("train.num_workers"),
worker_init_fn=_worker_init_fn,
worker_init_fn=_worker_init_fn
if self.config.get("random_seed.numpy") == -1
else None,
pin_memory=self.config.get("train.pin_memory"),
)

Expand Down Expand Up @@ -1038,7 +1043,9 @@ def _prepare(self):
shuffle=True,
batch_size=self.batch_size,
num_workers=self.config.get("train.num_workers"),
worker_init_fn=_worker_init_fn,
worker_init_fn=_worker_init_fn
if self.config.get("random_seed.numpy") == -1
else None,
pin_memory=self.config.get("train.pin_memory"),
)

Expand Down

0 comments on commit bdb078f

Please sign in to comment.