fixed unconditional sampling reproducibility issue
This commit is contained in:
committed by
Jeff Wu
parent
99af6d7092
commit
2cf46d997d
@ -17,9 +17,6 @@ def sample_model(
|
|||||||
temperature=1,
|
temperature=1,
|
||||||
top_k=0,
|
top_k=0,
|
||||||
):
|
):
|
||||||
np.random.seed(seed)
|
|
||||||
tf.set_random_seed(seed)
|
|
||||||
|
|
||||||
enc = encoder.get_encoder(model_name)
|
enc = encoder.get_encoder(model_name)
|
||||||
hparams = model.default_hparams()
|
hparams = model.default_hparams()
|
||||||
with open(os.path.join('models', model_name, 'hparams.json')) as f:
|
with open(os.path.join('models', model_name, 'hparams.json')) as f:
|
||||||
@ -31,6 +28,9 @@ def sample_model(
|
|||||||
raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)
|
raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)
|
||||||
|
|
||||||
with tf.Session(graph=tf.Graph()) as sess:
|
with tf.Session(graph=tf.Graph()) as sess:
|
||||||
|
np.random.seed(seed)
|
||||||
|
tf.set_random_seed(seed)
|
||||||
|
|
||||||
output = sample.sample_sequence(
|
output = sample.sample_sequence(
|
||||||
hparams=hparams, length=length,
|
hparams=hparams, length=length,
|
||||||
start_token=enc.encoder['<|endoftext|>'],
|
start_token=enc.encoder['<|endoftext|>'],
|
||||||
|
Reference in New Issue
Block a user