import tensorflow as tf

import model


def top_k_logits(logits, k):
    """Mask all but the k highest logits per row, for top-k sampling."""
    if k == 0:
        # no truncation
        return logits

    def _top_k():
        values, _ = tf.nn.top_k(logits, k=k)
        min_values = values[:, -1, tf.newaxis]
        # Anything below the k-th largest logit is pushed to -1e10, which is
        # effectively probability zero after the softmax.
        return tf.where(
            logits < min_values,
            tf.ones_like(logits, dtype=logits.dtype) * -1e10,
            logits,
        )
    # k may also be a graph tensor, so the k == 0 case is re-checked in-graph.
    return tf.cond(
        tf.equal(k, 0),
        lambda: logits,
        lambda: _top_k(),
    )
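
# A minimal usage sketch (added here, not part of the original module): how
# top_k_logits truncates a toy batch of logits under TF 1.x. The shapes and
# the choice k=2 are illustrative assumptions, not anything the module
# prescribes.
def _top_k_logits_example():
    logits = tf.constant([[1.0, 3.0, 2.0, 0.5],
                          [0.1, 0.2, 4.0, 3.0]])
    truncated = top_k_logits(logits, k=2)
    with tf.Session() as sess:
        # Each row keeps its two largest logits; the rest become -1e10.
        return sess.run(truncated)
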
def top_p_logits(logits, p):
    """Nucleus sampling: keep the smallest set of logits whose probabilities sum to at least p."""
    batch, _ = logits.shape.as_list()
    sorted_logits = tf.sort(logits, direction='DESCENDING', axis=-1)
    cumulative_probs = tf.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1)
    indices = tf.stack([
        tf.range(0, batch),
        # number of indices to include
        tf.maximum(tf.reduce_sum(tf.cast(cumulative_probs <= p, tf.int32), axis=-1) - 1, 0),
    ], axis=-1)
    # The sorted logit at the nucleus boundary becomes the per-row cutoff.
    min_values = tf.gather_nd(sorted_logits, indices)
    return tf.where(
        logits < min_values,
        tf.ones_like(logits) * -1e10,
        logits,
    )
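
# A hedged sketch (added, not original): nucleus truncation on the same kind
# of toy batch. With p=0.8, each row keeps only the smallest set of logits
# whose softmax probabilities reach roughly 0.8; the values below are
# assumptions chosen for illustration.
def _top_p_logits_example():
    logits = tf.constant([[1.0, 3.0, 2.0, 0.5],
                          [0.1, 0.2, 4.0, 3.0]])
    truncated = top_p_logits(logits, p=0.8)
    with tf.Session() as sess:
        # Logits outside the nucleus are pushed to -1e10.
        return sess.run(truncated)
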
def sample_sequence(*, hparams, length, start_token=None, batch_size=None,
                    context=None, temperature=1, top_k=0, top_p=1):
    """Sample `length` tokens from the model, seeded by `start_token` or `context`."""
    if start_token is None:
        assert context is not None, 'Specify exactly one of start_token and context!'
    else:
        assert context is None, 'Specify exactly one of start_token and context!'
        context = tf.fill([batch_size, 1], start_token)

    def step(hparams, tokens, past=None):
        lm_output = model.model(hparams=hparams, X=tokens, past=past, reuse=tf.AUTO_REUSE)

        logits = lm_output['logits'][:, :, :hparams.n_vocab]
        presents = lm_output['present']
        presents.set_shape(model.past_shape(hparams=hparams, batch_size=batch_size))
        return {
            'logits': logits,
            'presents': presents,
        }

    with tf.name_scope('sample_sequence'):
        def body(past, prev, output):
            next_outputs = step(hparams, prev, past=past)
            logits = next_outputs['logits'][:, -1, :] / tf.to_float(temperature)
            # Apply top-k truncation, then nucleus truncation, before sampling.
            logits = top_k_logits(logits, k=top_k)
            logits = top_p_logits(logits, p=top_p)
            samples = tf.multinomial(logits, num_samples=1, output_dtype=tf.int32)
            return [
                # Grow the key/value cache with this step's attention state.
                next_outputs['presents'] if past is None else tf.concat([past, next_outputs['presents']], axis=-2),
                samples,
                tf.concat([output, samples], axis=1),
            ]

        # Run the first step outside the loop so `past` has a concrete value.
        past, prev, output = body(None, context, context)

        def cond(*args):
            return True

        _, _, tokens = tf.while_loop(
            cond=cond, body=body,
            maximum_iterations=length - 1,
            loop_vars=[
                past,
                prev,
                output,
            ],
            shape_invariants=[
                tf.TensorShape(model.past_shape(hparams=hparams, batch_size=batch_size)),
                tf.TensorShape([batch_size, None]),
                tf.TensorShape([batch_size, None]),
            ],
            back_prop=False,
        )

        return tokens
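
# A hedged end-to-end sketch (added, not part of the original file): wiring
# sample_sequence into a TF 1.x session. The `hparams` object and the start
# token id are assumptions passed in by the caller; in real use the variables
# would be restored from a downloaded model checkpoint rather than randomly
# initialized as below.
def _sample_sequence_example(hparams, start_token_id=0):
    output = sample_sequence(
        hparams=hparams, length=20,
        start_token=start_token_id, batch_size=1,
        temperature=0.7, top_k=40,
    )
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())  # or restore a checkpoint
        return sess.run(output)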