diff --git a/README.md b/README.md
index 9df579e..5690072 100644
--- a/README.md
+++ b/README.md
@@ -18,21 +18,28 @@ Install python packages:
 pip3 install -r requirements.txt
 ```
 
-## Sample generation
+## Unconditional sample generation
 
 | WARNING: Samples are unfiltered and may contain offensive content. |
 | --- |
 
 To generate unconditional samples from the small model:
 ```
-python3 src/main.py | tee samples
+python3 src/generate_unconditional_samples.py | tee samples
 ```
 There are various flags for controlling the samples:
 ```
-python3 src/main.py --top_k 40 --temperature 0.7 | tee samples
+python3 src/generate_unconditional_samples.py --top_k 40 --temperature 0.7 | tee samples
 ```
 
-While we have not yet released GPT-2 itself, you can see some unconditional samples (with default settings of temperature 1 and no truncation) in `gpt2-samples.txt`.
+While we have not yet released GPT-2 itself, you can see some unconditional samples from it (with default settings of temperature 1 and no truncation) in `gpt2-samples.txt`.
+
+## Conditional sample generation
+
+To give the model custom prompts, you can use:
+```
+python3 src/interactive_conditional_samples.py
+```
 
 ## Future work
 
diff --git a/src/encoder.py b/src/encoder.py
index 285d643..5068cc6 100644
--- a/src/encoder.py
+++ b/src/encoder.py
@@ -93,16 +93,13 @@ class Encoder:
         self.cache[token] = word
         return word
 
-    def encode_text(self, text):
+    def encode(self, text):
         bpe_tokens = []
         for token in re.findall(self.pat, text):
             token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
             bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
         return bpe_tokens
 
-    def encode(self, texts):
-        return [self.encode_text(text) for text in texts]
-
     def decode(self, tokens):
         text = ''.join([self.decoder[token] for token in tokens])
         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
diff --git a/src/main.py b/src/generate_unconditional_samples.py
similarity index 100%
rename from src/main.py
rename to src/generate_unconditional_samples.py
diff --git a/src/interactive_conditional_samples.py b/src/interactive_conditional_samples.py
new file mode 100755
index 0000000..5038117
--- /dev/null
+++ b/src/interactive_conditional_samples.py
@@ -0,0 +1,66 @@
+#!/usr/bin/env python3
+
+import fire
+import json
+import os
+import numpy as np
+import tensorflow as tf
+
+from src import model, sample, encoder
+
+def interact_model(
+    model_name='117M',
+    seed=None,
+    nsamples=1,
+    batch_size=None,
+    length=None,
+    temperature=1,
+    top_k=0,
+):
+    if batch_size is None:
+        batch_size = 1
+    assert nsamples % batch_size == 0
+    np.random.seed(seed)
+    tf.set_random_seed(seed)
+
+    enc = encoder.get_encoder(model_name)
+    hparams = model.default_hparams()
+    with open(os.path.join('models', model_name, 'hparams.json')) as f:
+        hparams.override_from_dict(json.load(f))
+
+    if length is None:
+        length = hparams.n_ctx // 2
+    elif length > hparams.n_ctx:
+        raise ValueError(f"can't get samples longer than window size: {hparams.n_ctx}")
+
+    with tf.Session(graph=tf.Graph()) as sess:
+        context = tf.placeholder(tf.int32, [batch_size, None])
+        output = sample.sample_sequence(
+            hparams=hparams, length=length,
+            context=context,
+            batch_size=batch_size,
+            temperature=temperature, top_k=top_k
+        )[:, 1:]
+
+        saver = tf.train.Saver()
+        ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name))
+        saver.restore(sess, ckpt)
+
+        while True:
+            raw_text = input("Model prompt >>> ")
+            context_tokens = enc.encode(raw_text)
+            generated = 0
+            for _ in range(nsamples // batch_size):
+                out = sess.run(output, feed_dict={
+                    context: [context_tokens for _ in range(batch_size)]
+                })
+                for i in range(batch_size):
+                    generated += 1
+                    text = enc.decode(out[i])
+                    print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
+                    print(f"{text}")
+                    print("=" * 80)
+
+if __name__ == '__main__':
+    fire.Fire(interact_model)
+
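For reference, a minimal sketch of the `Encoder` interface after this patch, as it is used by the interactive script above: `encode` now takes a single string and returns a flat list of BPE token ids, and `decode` maps those ids back to text. This assumes the 117M model files are under `models/117M` and that `src` is importable, as in the scripts above; it is an illustration, not part of the patch.
```python
# Illustrative usage of the renamed Encoder.encode (single string -> token ids).
# Assumes models/117M has been downloaded and src is on the import path.
from src import encoder

enc = encoder.get_encoder('117M')

tokens = enc.encode("Hello world")   # flat list of BPE token ids for one string
text = enc.decode(tokens)            # byte-level BPE decoding is lossless
assert text == "Hello world"
```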