fixed seed arg to ensure reproducibility in conditional-samples model

This commit is contained in:
Ignacio Lopez-Francos
2019-02-20 12:18:19 -08:00
committed by Jeff Wu
parent 2cf46d997d
commit 946facf551

View File

@ -20,8 +20,6 @@ def interact_model(
if batch_size is None: if batch_size is None:
batch_size = 1 batch_size = 1
assert nsamples % batch_size == 0 assert nsamples % batch_size == 0
np.random.seed(seed)
tf.set_random_seed(seed)
enc = encoder.get_encoder(model_name) enc = encoder.get_encoder(model_name)
hparams = model.default_hparams() hparams = model.default_hparams()
@ -35,6 +33,8 @@ def interact_model(
with tf.Session(graph=tf.Graph()) as sess: with tf.Session(graph=tf.Graph()) as sess:
context = tf.placeholder(tf.int32, [batch_size, None]) context = tf.placeholder(tf.int32, [batch_size, None])
np.random.seed(seed)
tf.set_random_seed(seed)
output = sample.sample_sequence( output = sample.sample_sequence(
hparams=hparams, length=length, hparams=hparams, length=length,
context=context, context=context,