Fix TODO in sample.sample_sequence — avoid 'leaving last token calculation to while loop' (#119)
* do initial run on full context * decrement while loop iterations * add context to output * remove first param * remove first param: change shape invariant accordingly
This commit is contained in:
@ -41,36 +41,33 @@ def sample_sequence(*, hparams, length, start_token=None, batch_size=None, conte
|
|||||||
}
|
}
|
||||||
|
|
||||||
with tf.name_scope('sample_sequence'):
|
with tf.name_scope('sample_sequence'):
|
||||||
# Don't feed the last context token -- leave that to the loop below
|
|
||||||
# TODO: Would be slightly faster if we called step on the entire context,
|
|
||||||
# rather than leaving the last token transformer calculation to the while loop.
|
|
||||||
context_output = step(hparams, context[:, :-1])
|
|
||||||
|
|
||||||
def body(past, prev, output):
|
def body(past, prev, output):
|
||||||
next_outputs = step(hparams, prev[:, tf.newaxis], past=past)
|
next_outputs = step(hparams, prev, past=past)
|
||||||
logits = next_outputs['logits'][:, -1, :] / tf.to_float(temperature)
|
logits = next_outputs['logits'][:, -1, :] / tf.to_float(temperature)
|
||||||
logits = top_k_logits(logits, k=top_k)
|
logits = top_k_logits(logits, k=top_k)
|
||||||
samples = tf.multinomial(logits, num_samples=1, output_dtype=tf.int32)
|
samples = tf.multinomial(logits, num_samples=1, output_dtype=tf.int32)
|
||||||
return [
|
return [
|
||||||
tf.concat([past, next_outputs['presents']], axis=-2),
|
next_outputs['presents'] if past is None else tf.concat([past, next_outputs['presents']], axis=-2),
|
||||||
tf.squeeze(samples, axis=[1]),
|
samples,
|
||||||
tf.concat([output, samples], axis=1),
|
tf.concat([output, samples], axis=1)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
past, prev, output = body(None, context, context)
|
||||||
|
|
||||||
def cond(*args):
|
def cond(*args):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
_, _, tokens = tf.while_loop(
|
_, _, tokens = tf.while_loop(
|
||||||
cond=cond, body=body,
|
cond=cond, body=body,
|
||||||
maximum_iterations=length,
|
maximum_iterations=length - 1,
|
||||||
loop_vars=[
|
loop_vars=[
|
||||||
context_output['presents'],
|
past,
|
||||||
context[:, -1],
|
prev,
|
||||||
context,
|
output
|
||||||
],
|
],
|
||||||
shape_invariants=[
|
shape_invariants=[
|
||||||
tf.TensorShape(model.past_shape(hparams=hparams, batch_size=batch_size)),
|
tf.TensorShape(model.past_shape(hparams=hparams, batch_size=batch_size)),
|
||||||
tf.TensorShape([batch_size]),
|
tf.TensorShape([batch_size, None]),
|
||||||
tf.TensorShape([batch_size, None]),
|
tf.TensorShape([batch_size, None]),
|
||||||
],
|
],
|
||||||
back_prop=False,
|
back_prop=False,
|
||||||
|
Reference in New Issue
Block a user