#!/usr/bin/env python3

import fire
import json
import os
import numpy as np
import tensorflow as tf

from src import model, sample, encoder


def sample_model(
    model_name='117M',
    seed=None,
    nsamples=0,
    batch_size=1,
    length=None,
    temperature=1,
    top_k=0,
):
    """Print unconditional samples generated by a pretrained model.

    Generation is seeded only with the ``<|endoftext|>`` token, so samples
    are not conditioned on any user-supplied prompt. Each sample is printed
    after a numbered ``SAMPLE`` banner.

    Args:
        model_name: Name of the subdirectory of ``models/`` that holds
            ``hparams.json`` and the checkpoint; also passed to
            ``encoder.get_encoder``.
        seed: Seed passed to both ``np.random.seed`` and
            ``tf.set_random_seed``; ``None`` leaves them unseeded.
        nsamples: Number of samples to print. ``0`` means keep sampling
            until the process is interrupted.
        batch_size: Number of sequences produced per ``sess.run`` call.
        length: Number of tokens requested from ``sample.sample_sequence``.
            Defaults to ``hparams.n_ctx``.
        temperature: Forwarded unchanged to ``sample.sample_sequence``.
        top_k: Forwarded unchanged to ``sample.sample_sequence``.

    Raises:
        ValueError: If ``batch_size`` is less than 1, or ``length`` exceeds
            the model's context window ``hparams.n_ctx``.
    """
    if batch_size < 1:
        # Without at least one sequence per batch the sampling loop would
        # spin forever without printing anything.
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")

    np.random.seed(seed)
    tf.set_random_seed(seed)
    enc = encoder.get_encoder(model_name)
    hparams = model.default_hparams()
    with open(os.path.join('models', model_name, 'hparams.json'),
              encoding='utf-8') as f:
        hparams.override_from_dict(json.load(f))

    if length is None:
        length = hparams.n_ctx
    elif length > hparams.n_ctx:
        raise ValueError(f"can't get samples longer than window size: {hparams.n_ctx}")

    with tf.Session(graph=tf.Graph()) as sess:
        # The trailing [:, 1:] drops the first column of each sequence --
        # presumably the <|endoftext|> start token; confirm against
        # sample.sample_sequence.
        output = sample.sample_sequence(
            hparams=hparams, length=length,
            start_token=enc.encoder['<|endoftext|>'],
            batch_size=batch_size,
            temperature=temperature, top_k=top_k
        )[:, 1:]

        saver = tf.train.Saver()
        ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name))
        saver.restore(sess, ckpt)

        generated = 0
        while nsamples == 0 or generated < nsamples:
            out = sess.run(output)
            for i in range(batch_size):
                if nsamples and generated >= nsamples:
                    # Only part of the final batch is needed to reach nsamples.
                    break
                # Count one sample at a time so banners are numbered 1, 2, 3...
                # and the nsamples limit is honored exactly.
                generated += 1
                text = enc.decode(out[i])
                print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
                print(text)


if __name__ == '__main__':
    fire.Fire(sample_model)