fixed seed arg to ensure reproducibility in conditional-samples model
This commit is contained in:
committed by
Jeff Wu
parent
2cf46d997d
commit
946facf551
@ -20,8 +20,6 @@ def interact_model(
|
||||
if batch_size is None:
|
||||
batch_size = 1
|
||||
assert nsamples % batch_size == 0
|
||||
np.random.seed(seed)
|
||||
tf.set_random_seed(seed)
|
||||
|
||||
enc = encoder.get_encoder(model_name)
|
||||
hparams = model.default_hparams()
|
||||
@ -35,6 +33,8 @@ def interact_model(
|
||||
|
||||
with tf.Session(graph=tf.Graph()) as sess:
|
||||
context = tf.placeholder(tf.int32, [batch_size, None])
|
||||
np.random.seed(seed)
|
||||
tf.set_random_seed(seed)
|
||||
output = sample.sample_sequence(
|
||||
hparams=hparams, length=length,
|
||||
context=context,
|
||||
|
Reference in New Issue
Block a user