Files
gpt-2/src/interactive_conditional_samples.py

70 lines
2.1 KiB
Python
Raw Normal View History

2019-02-10 20:22:00 -08:00
#!/usr/bin/env python3
import fire
import json
import os
import numpy as np
import tensorflow as tf
from src import model, sample, encoder
def interact_model(
2019-02-10 20:22:00 -08:00
model_name='117M',
seed=None,
nsamples=1,
batch_size=None,
2019-02-10 20:22:00 -08:00
length=None,
temperature=1,
top_k=0,
):
if batch_size is None:
batch_size = 1
assert nsamples % batch_size == 0
2019-02-10 20:22:00 -08:00
np.random.seed(seed)
tf.set_random_seed(seed)
enc = encoder.get_encoder(model_name)
hparams = model.default_hparams()
with open(os.path.join('models', model_name, 'hparams.json')) as f:
hparams.override_from_dict(json.load(f))
if length is None:
length = hparams.n_ctx // 2
2019-02-10 20:22:00 -08:00
elif length > hparams.n_ctx:
2019-02-14 11:34:14 -08:00
raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)
2019-02-10 20:22:00 -08:00
with tf.Session(graph=tf.Graph()) as sess:
context = tf.placeholder(tf.int32, [batch_size, None])
2019-02-10 20:22:00 -08:00
output = sample.sample_sequence(
hparams=hparams, length=length,
context=context,
2019-02-10 20:22:00 -08:00
batch_size=batch_size,
temperature=temperature, top_k=top_k
2019-02-14 11:34:14 -08:00
)
2019-02-10 20:22:00 -08:00
saver = tf.train.Saver()
ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name))
saver.restore(sess, ckpt)
while True:
raw_text = input("Model prompt >>> ")
2019-02-14 11:34:14 -08:00
while not raw_text:
print('Prompt should not be empty!')
raw_text = input("Model prompt >>> ")
context_tokens = enc.encode(raw_text)
generated = 0
for _ in range(nsamples // batch_size):
out = sess.run(output, feed_dict={
context: [context_tokens for _ in range(batch_size)]
2019-02-14 11:34:14 -08:00
})[:, len(context_tokens):]
for i in range(batch_size):
generated += 1
text = enc.decode(out[i])
print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
2019-02-14 11:34:14 -08:00
print(text)
print("=" * 80)
2019-02-10 20:22:00 -08:00
if __name__ == "__main__":
    # Script entry point: python-fire exposes interact_model's keyword
    # arguments as command-line flags (e.g. --model_name, --nsamples).
    fire.Fire(interact_model)
2019-02-10 20:22:00 -08:00