fix bug and remove f strings

Jeff Wu
2019-02-14 11:34:14 -08:00
parent e33295b4b5
commit 7cdac144c3
2 changed files with 9 additions and 6 deletions

src/generate_unconditional_samples.py

@@ -28,7 +28,7 @@ def sample_model(
     if length is None:
         length = hparams.n_ctx
     elif length > hparams.n_ctx:
-        raise ValueError(f"can't get samples longer than window size: {hparams.n_ctx}")
+        raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)
 
     with tf.Session(graph=tf.Graph()) as sess:
         output = sample.sample_sequence(
@@ -49,7 +49,7 @@ def sample_model(
                 generated += batch_size
                 text = enc.decode(out[i])
                 print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
-                print(f"{text}")
+                print(text)
 
 if __name__ == '__main__':
     fire.Fire(sample_model)
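
Both files make the same substitution: f-strings are replaced by %-formatting. The commit message gives no motivation, but f-strings require Python 3.6+, while %-formatting also runs on Python 3.5, so this reads as a compatibility change. A minimal sketch of the equivalence (the variable and its value are illustrative, not taken from the diff):

n_ctx = 1024  # illustrative; in the scripts this value comes from hparams.n_ctx

# f-string form (Python 3.6+ only), removed by this commit:
msg_old = f"can't get samples longer than window size: {n_ctx}"

# %-formatting form (also valid on Python 3.5), introduced by this commit:
msg_new = "Can't get samples longer than window size: %s" % n_ctx

# Identical text apart from the capitalization tweak:
assert msg_old.lower() == msg_new.lower()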

src/interactive_conditional_samples.py

@@ -31,7 +31,7 @@ def interact_model(
     if length is None:
         length = hparams.n_ctx // 2
     elif length > hparams.n_ctx:
-        raise ValueError(f"can't get samples longer than window size: {hparams.n_ctx}")
+        raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)
 
     with tf.Session(graph=tf.Graph()) as sess:
         context = tf.placeholder(tf.int32, [batch_size, None])
@@ -40,7 +40,7 @@ def interact_model(
             context=context,
             batch_size=batch_size,
             temperature=temperature, top_k=top_k
-        )[:, 1:]
+        )
 
         saver = tf.train.Saver()
         ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name))
@@ -48,17 +48,20 @@ def interact_model(
         while True:
             raw_text = input("Model prompt >>> ")
+            while not raw_text:
+                print('Prompt should not be empty!')
+                raw_text = input("Model prompt >>> ")
             context_tokens = enc.encode(raw_text)
             generated = 0
             for _ in range(nsamples // batch_size):
                 out = sess.run(output, feed_dict={
                     context: [context_tokens for _ in range(batch_size)]
-                })
+                })[:, len(context_tokens):]
                 for i in range(batch_size):
                     generated += 1
                     text = enc.decode(out[i])
                     print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
-                    print(f"{text}")
+                    print(text)
             print("=" * 80)
 
 if __name__ == '__main__':
     fire.Fire(interact_model)
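
The "fix bug" half of the commit is the slicing change in interact_model. The added [:, len(context_tokens):] implies that sample_sequence returns the context tokens followed by the generated ones, so the old graph-time slice [:, 1:] stripped only a single leading token and each printed sample echoed almost the whole prompt. The slice also cannot stay in the graph: the context placeholder is [batch_size, None] and the prompt length is only known once the user types it, so the cut has to happen on the numpy result of sess.run inside the loop. A minimal numpy sketch of the two behaviours (the token values are made up):

import numpy as np

context_tokens = [50, 51, 52]                  # hypothetical encoded prompt
generated = [7, 8]                             # hypothetical continuation
out = np.array([context_tokens + generated])   # sample_sequence output: prompt + sample

buggy = out[:, 1:]                    # old behaviour: drops a single token
fixed = out[:, len(context_tokens):]  # new behaviour: drops the whole prompt

print(buggy)  # [[51 52  7  8]] -- the prompt leaks into the printed sample
print(fixed)  # [[7 8]]         -- only the generated tokens remain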