fixed seed arg to ensure reproducibility in conditional-samples model
This commit is contained in:
committed by
Jeff Wu
parent
2cf46d997d
commit
946facf551
@ -20,8 +20,6 @@ def interact_model(
|
|||||||
if batch_size is None:
|
if batch_size is None:
|
||||||
batch_size = 1
|
batch_size = 1
|
||||||
assert nsamples % batch_size == 0
|
assert nsamples % batch_size == 0
|
||||||
np.random.seed(seed)
|
|
||||||
tf.set_random_seed(seed)
|
|
||||||
|
|
||||||
enc = encoder.get_encoder(model_name)
|
enc = encoder.get_encoder(model_name)
|
||||||
hparams = model.default_hparams()
|
hparams = model.default_hparams()
|
||||||
@ -35,6 +33,8 @@ def interact_model(
|
|||||||
|
|
||||||
with tf.Session(graph=tf.Graph()) as sess:
|
with tf.Session(graph=tf.Graph()) as sess:
|
||||||
context = tf.placeholder(tf.int32, [batch_size, None])
|
context = tf.placeholder(tf.int32, [batch_size, None])
|
||||||
|
np.random.seed(seed)
|
||||||
|
tf.set_random_seed(seed)
|
||||||
output = sample.sample_sequence(
|
output = sample.sample_sequence(
|
||||||
hparams=hparams, length=length,
|
hparams=hparams, length=length,
|
||||||
context=context,
|
context=context,
|
||||||
|
Reference in New Issue
Block a user