Skip to content

Commit

Permalink
fixed unconditional sampling reproducibility issue
Browse files Browse the repository at this point in the history
  • Loading branch information
Ignacio Lopez-Francos authored and WuTheFWasThat committed Feb 20, 2019
1 parent 99af6d7 commit 2cf46d9
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/generate_unconditional_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,6 @@ def sample_model(
temperature=1,
top_k=0,
):
np.random.seed(seed)
tf.set_random_seed(seed)

enc = encoder.get_encoder(model_name)
hparams = model.default_hparams()
with open(os.path.join('models', model_name, 'hparams.json')) as f:
Expand All @@ -31,6 +28,9 @@ def sample_model(
raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)

with tf.Session(graph=tf.Graph()) as sess:
np.random.seed(seed)
tf.set_random_seed(seed)

output = sample.sample_sequence(
hparams=hparams, length=length,
start_token=enc.encoder['<|endoftext|>'],
Expand Down

0 comments on commit 2cf46d9

Please sign in to comment.