From 2cf46d997ded2c8627cb6407785d3fbbb5adedec Mon Sep 17 00:00:00 2001 From: Ignacio Lopez-Francos Date: Wed, 20 Feb 2019 09:15:53 -0800 Subject: [PATCH] fixed unconditional sampling reproducibility issue --- src/generate_unconditional_samples.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/generate_unconditional_samples.py b/src/generate_unconditional_samples.py index 2485149..658ca78 100755 --- a/src/generate_unconditional_samples.py +++ b/src/generate_unconditional_samples.py @@ -17,9 +17,6 @@ def sample_model( temperature=1, top_k=0, ): - np.random.seed(seed) - tf.set_random_seed(seed) - enc = encoder.get_encoder(model_name) hparams = model.default_hparams() with open(os.path.join('models', model_name, 'hparams.json')) as f: @@ -31,6 +28,9 @@ def sample_model( raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx) with tf.Session(graph=tf.Graph()) as sess: + np.random.seed(seed) + tf.set_random_seed(seed) + output = sample.sample_sequence( hparams=hparams, length=length, start_token=enc.encoder['<|endoftext|>'],