fixed unconditional sampling reproducibility issue
This commit is contained in:
committed by
Jeff Wu
parent
99af6d7092
commit
2cf46d997d
@ -17,9 +17,6 @@ def sample_model(
|
||||
temperature=1,
|
||||
top_k=0,
|
||||
):
|
||||
np.random.seed(seed)
|
||||
tf.set_random_seed(seed)
|
||||
|
||||
enc = encoder.get_encoder(model_name)
|
||||
hparams = model.default_hparams()
|
||||
with open(os.path.join('models', model_name, 'hparams.json')) as f:
|
||||
@ -31,6 +28,9 @@ def sample_model(
|
||||
raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)
|
||||
|
||||
with tf.Session(graph=tf.Graph()) as sess:
|
||||
np.random.seed(seed)
|
||||
tf.set_random_seed(seed)
|
||||
|
||||
output = sample.sample_sequence(
|
||||
hparams=hparams, length=length,
|
||||
start_token=enc.encoder['<|endoftext|>'],
|
||||
|
Reference in New Issue
Block a user