interact script for conditional samples
This commit is contained in:
56
src/generate_unconditional_samples.py
Executable file
56
src/generate_unconditional_samples.py
Executable file
@@ -0,0 +1,56 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import fire
|
||||
import json
|
||||
import os
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from src import model, sample, encoder
|
||||
|
||||
def sample_model(
|
||||
model_name='117M',
|
||||
seed=None,
|
||||
nsamples=0,
|
||||
batch_size=1,
|
||||
length=None,
|
||||
temperature=1,
|
||||
top_k=0,
|
||||
):
|
||||
np.random.seed(seed)
|
||||
tf.set_random_seed(seed)
|
||||
|
||||
enc = encoder.get_encoder(model_name)
|
||||
hparams = model.default_hparams()
|
||||
with open(os.path.join('models', model_name, 'hparams.json')) as f:
|
||||
hparams.override_from_dict(json.load(f))
|
||||
|
||||
if length is None:
|
||||
length = hparams.n_ctx
|
||||
elif length > hparams.n_ctx:
|
||||
raise ValueError(f"can't get samples longer than window size: {hparams.n_ctx}")
|
||||
|
||||
with tf.Session(graph=tf.Graph()) as sess:
|
||||
output = sample.sample_sequence(
|
||||
hparams=hparams, length=length,
|
||||
start_token=enc.encoder['<|endoftext|>'],
|
||||
batch_size=batch_size,
|
||||
temperature=temperature, top_k=top_k
|
||||
)[:, 1:]
|
||||
|
||||
saver = tf.train.Saver()
|
||||
ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name))
|
||||
saver.restore(sess, ckpt)
|
||||
|
||||
generated = 0
|
||||
while nsamples == 0 or generated < nsamples:
|
||||
out = sess.run(output)
|
||||
for i in range(batch_size):
|
||||
generated += batch_size
|
||||
text = enc.decode(out[i])
|
||||
print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
|
||||
print(f"{text}")
|
||||
|
||||
if __name__ == '__main__':
|
||||
fire.Fire(sample_model)
|
||||
|
Reference in New Issue
Block a user