First commit

This commit is contained in:
Jeff Wu
2019-02-10 20:22:00 -08:00
committed by Jeff Wu
commit c2dae27c10
7 changed files with 463 additions and 0 deletions

79
src/sample.py Normal file
View File

@ -0,0 +1,79 @@
import tensorflow as tf
import model
def top_k_logits(logits, k):
if k == 0:
# no truncation
return logits
def _top_k():
values, _ = tf.nn.top_k(logits, k=k)
min_values = values[:, -1, tf.newaxis]
return tf.where(
logits < min_values,
tf.ones_like(logits, dtype=logits.dtype) * -1e10,
logits,
)
return tf.cond(
tf.equal(k, 0),
lambda: logits,
lambda: _top_k(),
)
def sample_sequence(*, hparams, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0):
if start_token is None:
assert context is not None, 'Specify exactly one of start_token and context!'
else:
assert context is None, 'Specify exactly one of start_token and context!'
context = tf.fill([batch_size, 1], start_token)
def step(hparams, tokens, past=None):
lm_output = model.model(hparams=hparams, X=tokens, past=past, reuse=tf.AUTO_REUSE)
logits = lm_output['logits'][:, :, :hparams.n_vocab]
presents = lm_output['present']
presents.set_shape(model.past_shape(hparams=hparams, batch_size=batch_size))
return {
'logits': logits,
'presents': presents,
}
with tf.name_scope('sample_sequence'):
# Don't feed the last context token -- leave that to the loop below
# TODO: Would be slightly faster if we called step on the entire context,
# rather than leaving the last token transformer calculation to the while loop.
context_output = step(hparams, context[:, :-1])
def body(past, prev, output):
next_outputs = step(hparams, prev[:, tf.newaxis], past=past)
logits = next_outputs['logits'][:, -1, :] / tf.to_float(temperature)
logits = top_k_logits(logits, k=top_k)
samples = tf.multinomial(logits, num_samples=1, output_dtype=tf.int32)
return [
tf.concat([past, next_outputs['presents']], axis=-2),
tf.squeeze(samples, axis=[1]),
tf.concat([output, samples], axis=1),
]
def cond(*args):
return True
_, _, tokens = tf.while_loop(
cond=cond, body=body,
maximum_iterations=length,
loop_vars=[
context_output['presents'],
context[:, -1],
context,
],
shape_invariants=[
tf.TensorShape(model.past_shape(hparams=hparams, batch_size=batch_size)),
tf.TensorShape([batch_size]),
tf.TensorShape([batch_size, None]),
],
back_prop=False,
)
return tokens