Compare commits: new_readme...master
19 commits:

a74da5d99a, 0574c5708b, 03fce0a080, 0f97760ebe, ebdba20a19,
d98291d2ae, fbae7db92a, ac5d52295f, f35fa1d920, cb415376c3,
e9378792c4, 41a6793dc6, c0859d7523, e5c5054474, dd75299dfe,
b5ef71a922, 0503b1b249, d14501aade, 86378284e1
.gitignore (vendored, 1 change)
@@ -1,2 +1,3 @@
 __pycache__
 .mypy_cache/
+models/

CONTRIBUTORS.md
@@ -6,7 +6,7 @@
 * **[Margaret Mitchell et al](https://arxiv.org/abs/1810.03993)**
-Our [usage](./readme#usage) writeup was loosely inspired by the paper
+Our [usage](./README.md#usage) writeup was loosely inspired by the paper
 [Model Cards for Model Reporting](https://arxiv.org/abs/1810.03993)
 and related conversations with some of the authors.
DEVELOPERS.md
@@ -27,7 +27,10 @@ pip3 install -r requirements.txt
 
 Download the model data
 ```
-python3 download_model.py 117M
+python3 download_model.py 124M
+python3 download_model.py 355M
+python3 download_model.py 774M
+python3 download_model.py 1558M
 ```
 
 ## Docker Installation

Dockerfile.cpu
@@ -5,4 +5,7 @@ RUN mkdir /gpt-2
 WORKDIR /gpt-2
 ADD . /gpt-2
 RUN pip3 install -r requirements.txt
-RUN python3 download_model.py 117M
+RUN python3 download_model.py 124M
+RUN python3 download_model.py 355M
+RUN python3 download_model.py 774M
+RUN python3 download_model.py 1558M

Dockerfile.gpu
@@ -14,4 +14,7 @@ RUN mkdir /gpt-2
 WORKDIR /gpt-2
 ADD . /gpt-2
 RUN pip3 install -r requirements.txt
-RUN python3 download_model.py 117M
+RUN python3 download_model.py 124M
+RUN python3 download_model.py 355M
+RUN python3 download_model.py 774M
+RUN python3 download_model.py 1558M

LICENSE (37 changes)
@@ -1,21 +1,24 @@
-MIT License
+Modified MIT License
 
-Copyright (c) 2019 OpenAI
+Software Copyright (c) 2019 OpenAI
 
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
+We don’t claim ownership of the content you create with GPT-2, so it is yours to do with as you please.
+We only ask that you use GPT-2 responsibly and clearly indicate your content was created using GPT-2.
 
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
+Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
+associated documentation files (the "Software"), to deal in the Software without restriction,
+including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
+and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so,
+subject to the following conditions:
 
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
+The above copyright notice and this permission notice shall be included
+in all copies or substantial portions of the Software.
+The above copyright notice and this permission notice need not be included
+with content created by the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
+INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
+OR OTHER DEALINGS IN THE SOFTWARE.

README.md (31 changes)
@@ -1,24 +1,30 @@
 **Status:** Archive (code is provided as-is, no updates expected)
 
 # gpt-2
 
-Code and samples from the paper ["Language Models are Unsupervised Multitask Learners"](https://d4mucfpksywv.cloudfront.net/better-language-models/language-models.pdf).
+Code and models from the paper ["Language Models are Unsupervised Multitask Learners"](https://d4mucfpksywv.cloudfront.net/better-language-models/language-models.pdf).
 
-For now, we have only released a smaller (117M parameter) version of GPT-2.
+You can read about GPT-2 and its staged release in our [original blog post](https://blog.openai.com/better-language-models/), [6 month follow-up post](https://openai.com/blog/gpt-2-6-month-follow-up/), and [final post](https://www.openai.com/blog/gpt-2-1-5b-release/).
 
-See more details in our [blog post](https://blog.openai.com/better-language-models/).
+We have also [released a dataset](https://github.com/openai/gpt-2-output-dataset) for researchers to study their behaviors.
+
+<sup>*</sup> *Note that our original parameter counts were wrong due to an error (in our previous blog posts and paper). Thus you may have seen small referred to as 117M and medium referred to as 345M.*
 
 ## Usage
 
-This repository is meant to be a starting point for researchers and engineers to experiment with GPT-2-117M. While GPT-2-117M is less proficient than GPT-2-1.5B, it is useful for a wide range of research and applications which could also apply to larger models.
+This repository is meant to be a starting point for researchers and engineers to experiment with GPT-2.
+
+For basic information, see our [model card](./model_card.md).
 
 ### Some caveats
 
-- GPT-2-117M robustness and worst case behaviors are not well-understood. As with any machine-learned model, carefully evaluate GPT-2-117M for your use case, especially if used without fine-tuning or in safety-critical applications where reliability is important.
-- The dataset our GPT-2-117M was trained on contains many texts with [biases](https://twitter.com/TomerUllman/status/1101485289720242177) and factual inaccuracies, and thus GPT-2-117M is likely to be biased and inaccurate as well.
+- GPT-2 models' robustness and worst case behaviors are not well-understood. As with any machine-learned model, carefully evaluate GPT-2 for your use case, especially if used without fine-tuning or in safety-critical applications where reliability is important.
+- The dataset our GPT-2 models were trained on contains many texts with [biases](https://twitter.com/TomerUllman/status/1101485289720242177) and factual inaccuracies, and thus GPT-2 models are likely to be biased and inaccurate as well.
 - To avoid having samples mistaken as human-written, we recommend clearly labeling samples as synthetic before wide dissemination. Our models are often incoherent or inaccurate in subtle ways, which takes more than a quick read for a human to notice.
 
 ### Work with us
 
-Please [let us know](mailto:languagequestions@openai.com) if you’re doing interesting research with or working on applications of GPT-2-117M! We’re especially interested in hearing from and potentially working with those who are studying
+Please [let us know](mailto:languagequestions@openai.com) if you’re doing interesting research with or working on applications of GPT-2! We’re especially interested in hearing from and potentially working with those who are studying
 - Potential malicious use cases and defenses against them (e.g. the detectability of synthetic text)
 - The extent of problematic content (e.g. bias) being baked into the models and effective mitigations
@@ -30,15 +36,6 @@ See [DEVELOPERS.md](./DEVELOPERS.md)
 
 See [CONTRIBUTORS.md](./CONTRIBUTORS.md)
 
-## GPT-2 samples
-
-| WARNING: Samples are unfiltered and may contain offensive content. |
-| --- |
-
-While we have not yet released GPT-2 itself, you can see some samples from it in the `gpt-2-samples` folder.
-We show unconditional samples with default settings (temperature 1 and no truncation), with temperature 0.7, and with truncation with top_k 40.
-We show conditional samples, with contexts drawn from `WebText`'s test set, with default settings (temperature 1 and no truncation), with temperature 0.7, and with truncation with top_k 40.
-
 ## Citation
 
 Please use the following bibtex entry:
@@ -58,4 +55,4 @@ We are still considering release of the larger models.
 
 ## License
 
-[MIT](./LICENSE)
+[Modified MIT](./LICENSE)

domains.txt (new file, 1000 lines)
File diff suppressed because it is too large.

download_model.py
@@ -4,7 +4,7 @@ import requests
 from tqdm import tqdm
 
 if len(sys.argv) != 2:
-    print('You must enter the model name as a parameter, e.g.: download_model.py 117M')
+    print('You must enter the model name as a parameter, e.g.: download_model.py 124M')
     sys.exit(1)
 
 model = sys.argv[1]
@@ -12,10 +12,11 @@ model = sys.argv[1]
 subdir = os.path.join('models', model)
 if not os.path.exists(subdir):
     os.makedirs(subdir)
+subdir = subdir.replace('\\','/') # needed for Windows
 
 for filename in ['checkpoint','encoder.json','hparams.json','model.ckpt.data-00000-of-00001', 'model.ckpt.index', 'model.ckpt.meta', 'vocab.bpe']:
 
-    r = requests.get("https://storage.googleapis.com/gpt-2/" + subdir + "/" + filename, stream=True)
+    r = requests.get("https://openaipublic.blob.core.windows.net/gpt-2/" + subdir + "/" + filename, stream=True)
 
     with open(os.path.join(subdir, filename), 'wb') as f:
         file_size = int(r.headers["content-length"])
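The hunk ends just after `file_size` is read; the rest of the download loop is not shown. For orientation, here is a minimal self-contained sketch of the streaming pattern this sets up (the function name, chunk size, and progress-bar options are illustrative assumptions, not taken from the diff):

```python
import requests
from tqdm import tqdm

def fetch(url, dest, chunk_size=1000):
    # Stream the response so multi-gigabyte checkpoints never sit fully in memory.
    r = requests.get(url, stream=True)
    file_size = int(r.headers["content-length"])
    with open(dest, 'wb') as f, tqdm(total=file_size, unit_scale=True, desc=dest) as pbar:
        for chunk in r.iter_content(chunk_size=chunk_size):
            f.write(chunk)
            pbar.update(len(chunk))
```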
Six further file diffs suppressed (four because they are too large, two because one or more lines are too long).
model_card.md (new file, 69 lines)
@@ -0,0 +1,69 @@
+# GPT-2 model card
+
+Last updated: November 2019
+
+Inspired by [Model Cards for Model Reporting (Mitchell et al.)](https://arxiv.org/abs/1810.03993), we’re providing some accompanying information about the GPT-2 family of models we're releasing.
+
+## Model Details
+
+This model was developed by researchers at OpenAI to help us understand how the capabilities of language models scale as a function of model size (by parameter count) combined with very large internet-scale datasets (WebText).
+
+### Model date
+
+February 2019, trained on data that cuts off at the end of 2017.
+
+### Model type
+
+Language model
+
+### Model version
+
+1.5 billion parameters: the fourth and largest GPT-2 version. We have also released 124 million, 355 million, and 774 million parameter models.
+
+### Paper or other resource for more information
+
+[Blog post](https://openai.com/blog/better-language-models/) and [paper](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf)
+
+### Where to send questions or comments about the model
+
+Please use this [Google Form](https://forms.gle/A7WBSbTY2EkKdroPA)
+
+## Intended Uses
+
+### Primary intended uses
+
+The primary intended users of these models are *AI researchers and practitioners*.
+
+We primarily imagine these language models will be used by researchers to better understand the behaviors, capabilities, biases, and constraints of large-scale generative language models.
+
+### Secondary uses
+
+Here are some secondary use cases we believe are likely:
+
+- **Writing assistance**: grammar assistance, autocompletion (for normal prose or code)
+- **Creative writing and art**: exploring the generation of creative, fictional texts; aiding creation of poetry and other literary art
+- **Entertainment**: creation of games, chat bots, and amusing generations
+
+### Out-of-scope use cases
+
+Because large-scale language models like GPT-2 do not distinguish fact from fiction, we don’t support use cases that require the generated text to be true.
+
+Additionally, language models like GPT-2 reflect the biases inherent to the systems they were trained on, so we do not recommend that they be deployed into systems that interact with humans unless the deployers first carry out a study of biases relevant to the intended use case. We found no statistically significant difference in gender, race, and religious bias probes between 774M and 1.5B, implying all versions of GPT-2 should be approached with similar levels of caution around use cases that are sensitive to biases around human attributes.
+
+## Evaluation Data
+
+### Datasets
+
+This model was trained on (and evaluated against) WebText, a dataset consisting of the text contents of 45 million links posted by users of the ‘Reddit’ social network. WebText is made of data derived from outbound links from Reddit and does not consist of data taken directly from Reddit itself. Before generating the dataset we used a blocklist to ensure we didn’t sample from a variety of subreddits which contain sexually explicit or otherwise offensive content.
+
+To get a sense of the data that went into GPT-2, we’ve [published a list](domains.txt) of the top 1,000 domains present in WebText and their frequency. The top 15 domains by volume in WebText are: Google, Archive, Blogspot, GitHub, NYTimes, Wordpress, Washington Post, Wikia, BBC, The Guardian, eBay, Pastebin, CNN, Yahoo!, and the Huffington Post.
+
+### Motivation
+
+The motivation behind WebText was to create an internet-scale, heterogeneous dataset that we could use to test large-scale language models against. WebText was (and is) intended to be primarily for research purposes rather than production purposes.
+
+### Caveats and Recommendations
+
+Because GPT-2 is an internet-scale language model, it’s currently difficult to know what disciplined testing procedures can be applied to it to fully understand its capabilities and how the data it is trained on influences its vast range of outputs. We recommend researchers investigate these aspects of the model and share their results.
+
+Additionally, as indicated in our discussion of issues relating to potential misuse of the model, it remains unclear what the long-term dynamics are of detecting outputs from these models. We conducted [in-house automated ML-based detection research](https://github.com/openai/gpt-2-output-dataset/tree/master/detector) using simple classifiers, zero-shot, and fine-tuning methods. Our fine-tuned detector model reached accuracy levels of approximately 95%. However, no one detection method is a panacea; automated ML-based detection, human detection, human-machine teaming, and metadata-based detection are all methods that can be combined for more confident classification. Developing better approaches to detection today will give us greater intuitions when thinking about future models and could help us understand ahead of time if detection methods will eventually become ineffective.

src/encoder.py
@@ -105,10 +105,10 @@ class Encoder:
         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
         return text
 
-def get_encoder(model_name):
-    with open(os.path.join('models', model_name, 'encoder.json'), 'r') as f:
+def get_encoder(model_name, models_dir):
+    with open(os.path.join(models_dir, model_name, 'encoder.json'), 'r') as f:
         encoder = json.load(f)
-    with open(os.path.join('models', model_name, 'vocab.bpe'), 'r', encoding="utf-8") as f:
+    with open(os.path.join(models_dir, model_name, 'vocab.bpe'), 'r', encoding="utf-8") as f:
         bpe_data = f.read()
     bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
     return Encoder(
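A quick sketch of the updated call, assuming the repository's `src/` directory is on `sys.path` and the 124M model has been downloaded into `./models`:

```python
import sys
sys.path.insert(0, 'src')
import encoder

# The models directory is now passed explicitly instead of the old
# hard-coded 'models' path.
enc = encoder.get_encoder('124M', 'models')
tokens = enc.encode("Hello, world!")
print(tokens)              # BPE token ids
print(enc.decode(tokens))  # round-trips to the original text
```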

src/generate_unconditional_samples.py
@@ -9,17 +9,19 @@ import tensorflow as tf
 import model, sample, encoder
 
 def sample_model(
-    model_name='117M',
+    model_name='124M',
     seed=None,
     nsamples=0,
     batch_size=1,
     length=None,
     temperature=1,
     top_k=0,
+    top_p=1,
+    models_dir='models',
 ):
     """
     Run the sample_model
-    :model_name=117M : String, which model to use
+    :model_name=124M : String, which model to use
     :seed=None : Integer seed for random number generators, fix seed to
      reproduce results
     :nsamples=0 : Number of samples to return, if 0, continues to
@@ -35,10 +37,13 @@ def sample_model(
      considered for each step (token), resulting in deterministic completions,
      while 40 means 40 words are considered at each step. 0 (default) is a
      special setting meaning no restrictions. 40 generally is a good value.
+    :models_dir : path to parent folder containing model subfolders
+     (i.e. contains the <model_name> folder)
     """
-    enc = encoder.get_encoder(model_name)
+    models_dir = os.path.expanduser(os.path.expandvars(models_dir))
+    enc = encoder.get_encoder(model_name, models_dir)
     hparams = model.default_hparams()
-    with open(os.path.join('models', model_name, 'hparams.json')) as f:
+    with open(os.path.join(models_dir, model_name, 'hparams.json')) as f:
         hparams.override_from_dict(json.load(f))
 
     if length is None:
@@ -54,11 +59,11 @@ def sample_model(
             hparams=hparams, length=length,
             start_token=enc.encoder['<|endoftext|>'],
             batch_size=batch_size,
-            temperature=temperature, top_k=top_k
+            temperature=temperature, top_k=top_k, top_p=top_p
         )[:, 1:]
 
         saver = tf.train.Saver()
-        ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name))
+        ckpt = tf.train.latest_checkpoint(os.path.join(models_dir, model_name))
         saver.restore(sess, ckpt)
 
         generated = 0

src/interactive_conditional_samples.py
@@ -9,17 +9,19 @@ import tensorflow as tf
 import model, sample, encoder
 
 def interact_model(
-    model_name='117M',
+    model_name='124M',
     seed=None,
     nsamples=1,
     batch_size=1,
     length=None,
     temperature=1,
     top_k=0,
+    top_p=1,
+    models_dir='models',
 ):
     """
     Interactively run the model
-    :model_name=117M : String, which model to use
+    :model_name=124M : String, which model to use
     :seed=None : Integer seed for random number generators, fix seed to reproduce
      results
     :nsamples=1 : Number of samples to return total
@@ -34,14 +36,17 @@ def interact_model(
     considered for each step (token), resulting in deterministic completions,
     while 40 means 40 words are considered at each step. 0 (default) is a
     special setting meaning no restrictions. 40 generally is a good value.
+    :models_dir : path to parent folder containing model subfolders
+     (i.e. contains the <model_name> folder)
    """
+    models_dir = os.path.expanduser(os.path.expandvars(models_dir))
    if batch_size is None:
        batch_size = 1
    assert nsamples % batch_size == 0
 
-    enc = encoder.get_encoder(model_name)
+    enc = encoder.get_encoder(model_name, models_dir)
    hparams = model.default_hparams()
-    with open(os.path.join('models', model_name, 'hparams.json')) as f:
+    with open(os.path.join(models_dir, model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))
 
    if length is None:
@@ -57,11 +62,11 @@ def interact_model(
            hparams=hparams, length=length,
            context=context,
            batch_size=batch_size,
-            temperature=temperature, top_k=top_k
+            temperature=temperature, top_k=top_k, top_p=top_p
        )
 
        saver = tf.train.Saver()
-        ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name))
+        ckpt = tf.train.latest_checkpoint(os.path.join(models_dir, model_name))
        saver.restore(sess, ckpt)
 
        while True:
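Both sampling scripts gain the same `top_p` and `models_dir` keyword arguments, so callers opt into nucleus sampling the same way in either one. A sketch of a direct call under the same path assumptions as above (the module name matches the upstream repository's `src/interactive_conditional_samples.py`; flag values here are illustrative):

```python
import sys
sys.path.insert(0, 'src')
import interactive_conditional_samples

# top_p=0.9 samples from the smallest token set covering 90% of the
# probability mass; models_dir may now point outside the repo.
interactive_conditional_samples.interact_model(
    model_name='124M', top_k=40, top_p=0.9, models_dir='models')
```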

src/sample.py
@@ -22,7 +22,25 @@ def top_k_logits(logits, k):
     )
 
 
-def sample_sequence(*, hparams, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0):
+def top_p_logits(logits, p):
+    """Nucleus sampling"""
+    batch, _ = logits.shape.as_list()
+    sorted_logits = tf.sort(logits, direction='DESCENDING', axis=-1)
+    cumulative_probs = tf.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1)
+    indices = tf.stack([
+        tf.range(0, batch),
+        # number of indices to include
+        tf.maximum(tf.reduce_sum(tf.cast(cumulative_probs <= p, tf.int32), axis=-1) - 1, 0),
+    ], axis=-1)
+    min_values = tf.gather_nd(sorted_logits, indices)
+    return tf.where(
+        logits < min_values,
+        tf.ones_like(logits) * -1e10,
+        logits,
+    )
+
+
+def sample_sequence(*, hparams, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, top_p=1):
     if start_token is None:
         assert context is not None, 'Specify exactly one of start_token and context!'
     else:
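To make the cutoff in `top_p_logits` concrete: sort the logits in descending order, find the last position whose cumulative probability is still within `p`, and push every logit below that threshold to -1e10. The same computation restated in plain NumPy for a single logit vector (a sketch for intuition, not code from the diff):

```python
import numpy as np

def top_p_filter(logits, p=0.9):
    # logits: 1-D array over the vocabulary
    sorted_logits = np.sort(logits)[::-1]                # descending
    probs = np.exp(sorted_logits - sorted_logits.max())  # softmax
    probs /= probs.sum()
    cumulative = np.cumsum(probs)
    # index of the smallest logit kept inside the nucleus
    keep = max(int(np.sum(cumulative <= p)) - 1, 0)
    return np.where(logits < sorted_logits[keep], -1e10, logits)

# 0.5 and 0.3 together cover p=0.8, so the two 0.1 tokens are masked out:
print(top_p_filter(np.log(np.array([0.5, 0.3, 0.1, 0.1])), p=0.8))
```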
@@ -41,36 +59,34 @@ def sample_sequence(*, hparams, length, start_token=None, batch_size=None, conte
     }
 
     with tf.name_scope('sample_sequence'):
-        # Don't feed the last context token -- leave that to the loop below
-        # TODO: Would be slightly faster if we called step on the entire context,
-        # rather than leaving the last token transformer calculation to the while loop.
-        context_output = step(hparams, context[:, :-1])
-
         def body(past, prev, output):
-            next_outputs = step(hparams, prev[:, tf.newaxis], past=past)
+            next_outputs = step(hparams, prev, past=past)
             logits = next_outputs['logits'][:, -1, :] / tf.to_float(temperature)
             logits = top_k_logits(logits, k=top_k)
+            logits = top_p_logits(logits, p=top_p)
             samples = tf.multinomial(logits, num_samples=1, output_dtype=tf.int32)
             return [
-                tf.concat([past, next_outputs['presents']], axis=-2),
-                tf.squeeze(samples, axis=[1]),
-                tf.concat([output, samples], axis=1),
+                next_outputs['presents'] if past is None else tf.concat([past, next_outputs['presents']], axis=-2),
+                samples,
+                tf.concat([output, samples], axis=1)
             ]
 
+        past, prev, output = body(None, context, context)
+
         def cond(*args):
             return True
 
         _, _, tokens = tf.while_loop(
             cond=cond, body=body,
-            maximum_iterations=length,
+            maximum_iterations=length - 1,
             loop_vars=[
-                context_output['presents'],
-                context[:, -1],
-                context,
+                past,
+                prev,
+                output
             ],
             shape_invariants=[
                 tf.TensorShape(model.past_shape(hparams=hparams, batch_size=batch_size)),
-                tf.TensorShape([batch_size]),
+                tf.TensorShape([batch_size, None]),
                 tf.TensorShape([batch_size, None]),
             ],
             back_prop=False,
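The loop restructure above is the substantive change: instead of pre-computing the context with `step(hparams, context[:, :-1])` and then feeding one token per iteration, `body(None, context, context)` now runs the whole context through the model once and samples the first new token immediately, so the `while_loop` only needs `length - 1` further iterations, each feeding just the freshly sampled token plus the accumulated `presents` key/value cache. A toy NumPy sketch of that control flow (`step` is a stand-in for the transformer, and greedy argmax stands in for `tf.multinomial`):

```python
import numpy as np

rng = np.random.default_rng(0)

def step(tokens, past=None):
    # Stand-in for the transformer: fake per-position logits plus a fake
    # 'presents' cache entry for each token processed in this call.
    batch, n = tokens.shape
    logits = rng.normal(size=(batch, n, 50))  # [batch, n, vocab]
    presents = np.zeros((batch, n, 8))        # pretend key/value state
    return logits, presents

def generate(context, length):
    past, prev, output = None, context, context
    # First pass plays the role of body(None, context, context); the
    # remaining length - 1 passes each consume a single cached token.
    for _ in range(length):
        logits, presents = step(prev, past)
        past = presents if past is None else np.concatenate([past, presents], axis=1)
        prev = logits[:, -1, :].argmax(axis=-1, keepdims=True)  # [batch, 1]
        output = np.concatenate([output, prev], axis=1)
    return output

print(generate(np.array([[5, 7, 11]]), length=4))
```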