Files
gpt-2/src/generate_unconditional_samples.py
2019-02-14 10:40:41 -08:00

57 lines
1.5 KiB
Python
Executable File

#!/usr/bin/env python3
import fire
import json
import os
import numpy as np
import tensorflow as tf
from src import model, sample, encoder
def sample_model(
    model_name='117M',
    seed=None,
    nsamples=0,
    batch_size=1,
    length=None,
    temperature=1,
    top_k=0,
):
    """Generate unconditional samples from a saved model and print them.

    The model is loaded from ``models/<model_name>/``: ``hparams.json``
    overrides the default hyperparameters and the latest checkpoint in that
    directory is restored before sampling starts.

    Args:
        model_name: Name of the sub-directory of ``models`` holding the
            hyperparameters and checkpoint; also passed to
            ``encoder.get_encoder``.
        seed: Seed handed to both NumPy and TensorFlow. ``None`` leaves the
            run unseeded.
        nsamples: Number of samples to print. ``0`` keeps sampling until the
            process is interrupted.
        batch_size: Number of samples produced by each session run.
        length: Number of tokens per sample. ``None`` uses ``hparams.n_ctx``.
        temperature: Forwarded unchanged to ``sample.sample_sequence``.
        top_k: Forwarded unchanged to ``sample.sample_sequence``.

    Raises:
        ValueError: If ``length`` is larger than ``hparams.n_ctx``.
    """
    np.random.seed(seed)

    enc = encoder.get_encoder(model_name)
    hparams = model.default_hparams()
    with open(os.path.join('models', model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if length is None:
        length = hparams.n_ctx
    elif length > hparams.n_ctx:
        raise ValueError(f"can't get samples longer than window size: {hparams.n_ctx}")

    with tf.Session(graph=tf.Graph()) as sess:
        # The graph-level seed applies to the current default graph, so it has
        # to be set after the session has installed its fresh graph; setting
        # it earlier seeds a graph that the sampling ops are never added to.
        tf.set_random_seed(seed)

        # Column 0 of the result is dropped (presumably the start token --
        # confirm against sample.sample_sequence).
        output = sample.sample_sequence(
            hparams=hparams, length=length,
            start_token=enc.encoder['<|endoftext|>'],
            batch_size=batch_size,
            temperature=temperature, top_k=top_k
        )[:, 1:]

        saver = tf.train.Saver()
        ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name))
        saver.restore(sess, ckpt)

        generated = 0
        while nsamples == 0 or generated < nsamples:
            out = sess.run(output)
            for i in range(batch_size):
                # Stop mid-batch so that exactly `nsamples` samples are
                # printed even when it is not a multiple of `batch_size`.
                if nsamples != 0 and generated >= nsamples:
                    break
                # Count one per sample: the counter both numbers the banner
                # and decides when enough samples have been produced.
                generated += 1
                text = enc.decode(out[i])
                print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
                print(text)
# Script entry point: hand sample_model to fire.Fire so it can be driven from
# the command line; nothing runs when this module is merely imported.
if __name__ == '__main__':
    fire.Fire(sample_model)