First commit
src/sample.py (new file, 79 lines)
@@ -0,0 +1,79 @@
import tensorflow as tf

import model


def top_k_logits(logits, k):
    if k == 0:
        # no truncation
        return logits

    def _top_k():
        values, _ = tf.nn.top_k(logits, k=k)
        min_values = values[:, -1, tf.newaxis]
        return tf.where(
            logits < min_values,
            tf.ones_like(logits, dtype=logits.dtype) * -1e10,
            logits,
        )
    return tf.cond(
        tf.equal(k, 0),
        lambda: logits,
        lambda: _top_k(),
    )
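
# Note: top_k_logits keeps only the k largest logits per position and pushes
# the rest down to -1e10, so softmax/multinomial sampling effectively never
# picks them; k == 0 means no truncation. The Python-level `if k == 0` check
# handles a static int, while the tf.cond handles a k that is only known at
# graph execution time.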


def sample_sequence(*, hparams, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0):
    if start_token is None:
        assert context is not None, 'Specify exactly one of start_token and context!'
    else:
        assert context is None, 'Specify exactly one of start_token and context!'
        context = tf.fill([batch_size, 1], start_token)
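
    # One transformer step: run model.model over `tokens`, reusing any cached
    # attention key/values passed in `past`, and return both the logits and
    # the new cache entries (`presents`) for later steps.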
    def step(hparams, tokens, past=None):
        lm_output = model.model(hparams=hparams, X=tokens, past=past, reuse=tf.AUTO_REUSE)

        logits = lm_output['logits'][:, :, :hparams.n_vocab]
        presents = lm_output['present']
        presents.set_shape(model.past_shape(hparams=hparams, batch_size=batch_size))
        return {
            'logits': logits,
            'presents': presents,
        }

    with tf.name_scope('sample_sequence'):
        # Don't feed the last context token -- leave that to the loop below
        # TODO: Would be slightly faster if we called step on the entire context,
        # rather than leaving the last token transformer calculation to the while loop.
        context_output = step(hparams, context[:, :-1])

        def body(past, prev, output):
            next_outputs = step(hparams, prev[:, tf.newaxis], past=past)
            logits = next_outputs['logits'][:, -1, :] / tf.to_float(temperature)
            logits = top_k_logits(logits, k=top_k)
            samples = tf.multinomial(logits, num_samples=1, output_dtype=tf.int32)
            return [
                tf.concat([past, next_outputs['presents']], axis=-2),
                tf.squeeze(samples, axis=[1]),
                tf.concat([output, samples], axis=1),
            ]
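
        # Loop state threaded through tf.while_loop below: `past` (the
        # attention key/value cache) grows one position along the sequence
        # axis (-2) per step, `prev` is the token just sampled, and `output`
        # accumulates the full generated sequence.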

        def cond(*args):
            return True
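
        # The loop is bounded by maximum_iterations rather than by cond, so
        # cond can simply return True. The None dims in shape_invariants let
        # the cache and output grow each iteration, and back_prop=False skips
        # gradient bookkeeping since sampling needs none.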
        _, _, tokens = tf.while_loop(
            cond=cond, body=body,
            maximum_iterations=length,
            loop_vars=[
                context_output['presents'],
                context[:, -1],
                context,
            ],
            shape_invariants=[
                tf.TensorShape(model.past_shape(hparams=hparams, batch_size=batch_size)),
                tf.TensorShape([batch_size]),
                tf.TensorShape([batch_size, None]),
            ],
            back_prop=False,
        )

        return tokens
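
A minimal driver sketch, assuming TensorFlow 1.x, that the companion model.py exposes default_hparams(), and a placeholder start-token id (a real run would restore trained weights and use the tokenizer's end-of-text token):

import tensorflow as tf
import model, sample

hparams = model.default_hparams()   # assumed helper; fields such as n_vocab must match a real checkpoint
output = sample.sample_sequence(
    hparams=hparams, length=32,
    start_token=0,                  # placeholder token id
    batch_size=1, temperature=1.0, top_k=40,
)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())  # stand-in for restoring trained weights
    tokens = sess.run(output)                    # int32 array of shape [1, 1 + 32]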