Fix TODO in sample.sample_sequence: avoid 'leaving last token calculation to while loop' (#119)

* do initial run on full context

* decrement while loop iterations

* add context to output

* remove first param

* removing first param: change shape invariant
This commit is contained in:
Albert Wu
2019-05-30 21:49:18 -07:00
committed by Jeff Wu
parent e5c5054474
commit c0859d7523

View File

@@ -41,36 +41,33 @@ def sample_sequence(*, hparams, length, start_token=None, batch_size=None, conte
} }
with tf.name_scope('sample_sequence'): with tf.name_scope('sample_sequence'):
# Don't feed the last context token -- leave that to the loop below
# TODO: Would be slightly faster if we called step on the entire context,
# rather than leaving the last token transformer calculation to the while loop.
context_output = step(hparams, context[:, :-1])
def body(past, prev, output): def body(past, prev, output):
next_outputs = step(hparams, prev[:, tf.newaxis], past=past) next_outputs = step(hparams, prev, past=past)
logits = next_outputs['logits'][:, -1, :] / tf.to_float(temperature) logits = next_outputs['logits'][:, -1, :] / tf.to_float(temperature)
logits = top_k_logits(logits, k=top_k) logits = top_k_logits(logits, k=top_k)
samples = tf.multinomial(logits, num_samples=1, output_dtype=tf.int32) samples = tf.multinomial(logits, num_samples=1, output_dtype=tf.int32)
return [ return [
tf.concat([past, next_outputs['presents']], axis=-2), next_outputs['presents'] if past is None else tf.concat([past, next_outputs['presents']], axis=-2),
tf.squeeze(samples, axis=[1]), samples,
tf.concat([output, samples], axis=1), tf.concat([output, samples], axis=1)
] ]
past, prev, output = body(None, context, context)
def cond(*args): def cond(*args):
return True return True
_, _, tokens = tf.while_loop( _, _, tokens = tf.while_loop(
cond=cond, body=body, cond=cond, body=body,
maximum_iterations=length, maximum_iterations=length - 1,
loop_vars=[ loop_vars=[
context_output['presents'], past,
context[:, -1], prev,
context, output
], ],
shape_invariants=[ shape_invariants=[
tf.TensorShape(model.past_shape(hparams=hparams, batch_size=batch_size)), tf.TensorShape(model.past_shape(hparams=hparams, batch_size=batch_size)),
tf.TensorShape([batch_size]), tf.TensorShape([batch_size, None]),
tf.TensorShape([batch_size, None]), tf.TensorShape([batch_size, None]),
], ],
back_prop=False, back_prop=False,