fix bug and remove f strings
This commit is contained in:
@ -28,7 +28,7 @@ def sample_model(
|
|||||||
if length is None:
|
if length is None:
|
||||||
length = hparams.n_ctx
|
length = hparams.n_ctx
|
||||||
elif length > hparams.n_ctx:
|
elif length > hparams.n_ctx:
|
||||||
raise ValueError(f"can't get samples longer than window size: {hparams.n_ctx}")
|
raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)
|
||||||
|
|
||||||
with tf.Session(graph=tf.Graph()) as sess:
|
with tf.Session(graph=tf.Graph()) as sess:
|
||||||
output = sample.sample_sequence(
|
output = sample.sample_sequence(
|
||||||
@ -49,7 +49,7 @@ def sample_model(
|
|||||||
generated += batch_size
|
generated += batch_size
|
||||||
text = enc.decode(out[i])
|
text = enc.decode(out[i])
|
||||||
print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
|
print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
|
||||||
print(f"{text}")
|
print(text)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
fire.Fire(sample_model)
|
fire.Fire(sample_model)
|
||||||
|
@ -31,7 +31,7 @@ def interact_model(
|
|||||||
if length is None:
|
if length is None:
|
||||||
length = hparams.n_ctx // 2
|
length = hparams.n_ctx // 2
|
||||||
elif length > hparams.n_ctx:
|
elif length > hparams.n_ctx:
|
||||||
raise ValueError(f"can't get samples longer than window size: {hparams.n_ctx}")
|
raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)
|
||||||
|
|
||||||
with tf.Session(graph=tf.Graph()) as sess:
|
with tf.Session(graph=tf.Graph()) as sess:
|
||||||
context = tf.placeholder(tf.int32, [batch_size, None])
|
context = tf.placeholder(tf.int32, [batch_size, None])
|
||||||
@ -40,7 +40,7 @@ def interact_model(
|
|||||||
context=context,
|
context=context,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
temperature=temperature, top_k=top_k
|
temperature=temperature, top_k=top_k
|
||||||
)[:, 1:]
|
)
|
||||||
|
|
||||||
saver = tf.train.Saver()
|
saver = tf.train.Saver()
|
||||||
ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name))
|
ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name))
|
||||||
@ -48,17 +48,20 @@ def interact_model(
|
|||||||
|
|
||||||
while True:
|
while True:
|
||||||
raw_text = input("Model prompt >>> ")
|
raw_text = input("Model prompt >>> ")
|
||||||
|
while not raw_text:
|
||||||
|
print('Prompt should not be empty!')
|
||||||
|
raw_text = input("Model prompt >>> ")
|
||||||
context_tokens = enc.encode(raw_text)
|
context_tokens = enc.encode(raw_text)
|
||||||
generated = 0
|
generated = 0
|
||||||
for _ in range(nsamples // batch_size):
|
for _ in range(nsamples // batch_size):
|
||||||
out = sess.run(output, feed_dict={
|
out = sess.run(output, feed_dict={
|
||||||
context: [context_tokens for _ in range(batch_size)]
|
context: [context_tokens for _ in range(batch_size)]
|
||||||
})
|
})[:, len(context_tokens):]
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
generated += 1
|
generated += 1
|
||||||
text = enc.decode(out[i])
|
text = enc.decode(out[i])
|
||||||
print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
|
print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
|
||||||
print(f"{text}")
|
print(text)
|
||||||
print("=" * 80)
|
print("=" * 80)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
Reference in New Issue
Block a user