interact script for conditional samples
README.md (15 lines changed)

@@ -18,21 +18,28 @@ Install python packages:
 pip3 install -r requirements.txt
 ```
 
-## Sample generation
+## Unconditional sample generation
 
 | WARNING: Samples are unfiltered and may contain offensive content. |
 | --- |
 
 To generate unconditional samples from the small model:
 ```
-python3 src/main.py | tee samples
+python3 src/generate_unconditional_samples.py | tee samples
 ```
 There are various flags for controlling the samples:
 ```
-python3 src/main.py --top_k 40 --temperature 0.7 | tee samples
+python3 src/generate_unconditional_samples.py --top_k 40 --temperature 0.7 | tee samples
 ```
 
-While we have not yet released GPT-2 itself, you can see some unconditional samples (with default settings of temperature 1 and no truncation) in `gpt2-samples.txt`.
+While we have not yet released GPT-2 itself, you can see some unconditional samples from it (with default settings of temperature 1 and no truncation) in `gpt2-samples.txt`.
 
+## Conditional sample generation
+
+To give the model custom prompts, you can use:
+```
+python3 src/interactive_conditional_samples.py
+```
+
 ## Future work
 
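Note: the `--top_k` and `--temperature` flags in the README are passed through to the sampler. As a rough illustration of what they control (a schematic numpy sketch, not the repository's actual `src/sample.py`): temperature divides the logits before the softmax, and top_k restricts sampling to the k most likely tokens, with 0 meaning no truncation.

```
import numpy as np

def sample_token(logits, temperature=1.0, top_k=0):
    # Illustrative only: lower temperature sharpens the distribution,
    # higher temperature flattens it.
    logits = np.asarray(logits, dtype=np.float64) / temperature
    if top_k > 0:
        # Keep only the k largest logits; mask the rest with a large
        # negative value so they get (effectively) zero probability.
        cutoff = np.sort(logits)[-top_k]
        logits = np.where(logits < cutoff, -1e10, logits)
    probs = np.exp(logits - logits.max())  # numerically stable softmax
    probs /= probs.sum()
    return np.random.choice(len(probs), p=probs)
```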
src/encoder.py
@@ -93,16 +93,13 @@ class Encoder:
         self.cache[token] = word
         return word
 
-    def encode_text(self, text):
+    def encode(self, text):
         bpe_tokens = []
         for token in re.findall(self.pat, text):
             token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
             bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
         return bpe_tokens
 
-    def encode(self, texts):
-        return [self.encode_text(text) for text in texts]
-
     def decode(self, tokens):
         text = ''.join([self.decoder[token] for token in tokens])
         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
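With this rename, `Encoder.encode` maps a single string to a flat list of token ids, and `decode` inverts it; callers that relied on the removed batch wrapper must now loop over texts themselves. A minimal round-trip sketch, assuming the 117M model files have already been downloaded under `models/117M`:

```
from src import encoder

enc = encoder.get_encoder('117M')      # load the BPE vocab for the named model
tokens = enc.encode("Hello, world!")   # str -> list of token ids (was encode_text)
print(tokens)
assert enc.decode(tokens) == "Hello, world!"  # decode inverts encode
```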
src/interactive_conditional_samples.py (new executable file, 66 lines)

@@ -0,0 +1,66 @@
+#!/usr/bin/env python3
+
+import fire
+import json
+import os
+import numpy as np
+import tensorflow as tf
+
+from src import model, sample, encoder
+
+
+def interact_model(
+    model_name='117M',
+    seed=None,
+    nsamples=1,
+    batch_size=None,
+    length=None,
+    temperature=1,
+    top_k=0,
+):
+    if batch_size is None:
+        batch_size = 1
+    assert nsamples % batch_size == 0
+    np.random.seed(seed)
+    tf.set_random_seed(seed)
+
+    enc = encoder.get_encoder(model_name)
+    hparams = model.default_hparams()
+    with open(os.path.join('models', model_name, 'hparams.json')) as f:
+        hparams.override_from_dict(json.load(f))
+
+    if length is None:
+        length = hparams.n_ctx // 2
+    elif length > hparams.n_ctx:
+        raise ValueError(f"can't get samples longer than window size: {hparams.n_ctx}")
+
+    with tf.Session(graph=tf.Graph()) as sess:
+        context = tf.placeholder(tf.int32, [batch_size, None])
+        output = sample.sample_sequence(
+            hparams=hparams, length=length,
+            context=context,
+            batch_size=batch_size,
+            temperature=temperature, top_k=top_k
+        )[:, 1:]
+
+        saver = tf.train.Saver()
+        ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name))
+        saver.restore(sess, ckpt)
+
+        while True:
+            raw_text = input("Model prompt >>> ")
+            context_tokens = enc.encode(raw_text)
+            generated = 0
+            for _ in range(nsamples // batch_size):
+                out = sess.run(output, feed_dict={
+                    context: [context_tokens for _ in range(batch_size)]
+                })
+                for i in range(batch_size):
+                    generated += 1
+                    text = enc.decode(out[i])
+                    print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
+                    print(f"{text}")
+                    print("=" * 80)
+
+
+if __name__ == '__main__':
+    fire.Fire(interact_model)
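Since the script hands `interact_model` to `fire.Fire`, each keyword argument becomes a command-line flag, so the same sampling controls as the unconditional script apply here, e.g.:

```
python3 src/interactive_conditional_samples.py --top_k 40 --temperature 0.7
```

With no flags, each prompt yields one sample of up to half the model's context window (`length = hparams.n_ctx // 2`).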