fixed unconditional sampling reproducibility issue

This commit is contained in:
Ignacio Lopez-Francos
2019-02-20 09:15:53 -08:00
committed by Jeff Wu
parent 99af6d7092
commit 2cf46d997d

View File

@ -17,9 +17,6 @@ def sample_model(
temperature=1,
top_k=0,
):
np.random.seed(seed)
tf.set_random_seed(seed)
enc = encoder.get_encoder(model_name)
hparams = model.default_hparams()
with open(os.path.join('models', model_name, 'hparams.json')) as f:
@ -31,6 +28,9 @@ def sample_model(
raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)
with tf.Session(graph=tf.Graph()) as sess:
np.random.seed(seed)
tf.set_random_seed(seed)
output = sample.sample_sequence(
hparams=hparams, length=length,
start_token=enc.encoder['<|endoftext|>'],