From c0859d7523841035798ba8200b1a3cdf069138ab Mon Sep 17 00:00:00 2001
From: Albert Wu
Date: Thu, 30 May 2019 21:49:18 -0700
Subject: [PATCH] Fix TODO in sample.sample_sequence: avoid leaving the last
 token calculation to the while loop (#119)

* do initial run on full context

* decrement while loop iterations

* add context to output

* remove first param

* removing first param: change shape invariant
---
 src/sample.py | 25 +++++++++++--------------
 1 file changed, 11 insertions(+), 14 deletions(-)

diff --git a/src/sample.py b/src/sample.py
index c309ef0..6649531 100644
--- a/src/sample.py
+++ b/src/sample.py
@@ -41,36 +41,33 @@ def sample_sequence(*, hparams, length, start_token=None, batch_size=None, conte
         }
     with tf.name_scope('sample_sequence'):
-        # Don't feed the last context token -- leave that to the loop below
-        # TODO: Would be slightly faster if we called step on the entire context,
-        # rather than leaving the last token transformer calculation to the while loop.
-        context_output = step(hparams, context[:, :-1])
-
         def body(past, prev, output):
-            next_outputs = step(hparams, prev[:, tf.newaxis], past=past)
+            next_outputs = step(hparams, prev, past=past)
             logits = next_outputs['logits'][:, -1, :] / tf.to_float(temperature)
             logits = top_k_logits(logits, k=top_k)
             samples = tf.multinomial(logits, num_samples=1, output_dtype=tf.int32)
             return [
-                tf.concat([past, next_outputs['presents']], axis=-2),
-                tf.squeeze(samples, axis=[1]),
-                tf.concat([output, samples], axis=1),
+                next_outputs['presents'] if past is None else tf.concat([past, next_outputs['presents']], axis=-2),
+                samples,
+                tf.concat([output, samples], axis=1)
             ]
 
+        past, prev, output = body(None, context, context)
+
         def cond(*args):
             return True
 
         _, _, tokens = tf.while_loop(
             cond=cond,
             body=body,
-            maximum_iterations=length,
+            maximum_iterations=length - 1,
             loop_vars=[
-                context_output['presents'],
-                context[:, -1],
-                context,
+                past,
+                prev,
+                output
             ],
             shape_invariants=[
                 tf.TensorShape(model.past_shape(hparams=hparams, batch_size=batch_size)),
-                tf.TensorShape([batch_size]),
+                tf.TensorShape([batch_size, None]),
                 tf.TensorShape([batch_size, None]),
             ],
             back_prop=False,
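
Notes on the resulting control flow: the initial `body(None, context, context)` call runs the transformer over the entire context in one step, which both primes the attention cache (`presents`) and samples the first new token, so the while loop only needs `length - 1` further iterations to emit `length` tokens in total. Because `prev` now enters the loop as the full `[batch, seq]` context rather than a single token per row, its shape invariant loosens from `[batch_size]` to `[batch_size, None]`, and since `output` starts from `context`, the returned `tokens` include the prompt. (The `past is None` test resolves at graph-construction time: it is True only for the explicit initial call, since inside `tf.while_loop` `past` is always a tensor.)

The sketch below is a minimal NumPy mock of that pattern, not the repository's code: `step` here is a hypothetical stand-in for the transformer forward pass, greedy argmax replaces `tf.multinomial`, and a plain Python loop replaces `tf.while_loop`.

    import numpy as np

    VOCAB = 50257  # GPT-2's vocabulary size, used here only to shape the toy logits

    def step(tokens, past=None):
        # Hypothetical stand-in for the transformer forward pass: returns
        # per-position logits and one cache ('presents') entry per input
        # position. A real model would also attend over `past`.
        batch, seq = tokens.shape
        rng = np.random.default_rng(int(tokens.sum()))   # deterministic toy scores
        logits = rng.standard_normal((batch, seq, VOCAB))
        presents = np.zeros((batch, seq, 8))             # toy cache: [batch, seq, depth]
        return {'logits': logits, 'presents': presents}

    def body(past, prev, output):
        out = step(prev, past=past)
        logits = out['logits'][:, -1, :]                 # score only the last position
        samples = np.argmax(logits, axis=-1)[:, None]    # greedy pick; the patch samples instead
        # The first call (past is None) keeps the fresh cache; later calls append to it.
        past = out['presents'] if past is None else np.concatenate([past, out['presents']], axis=1)
        return past, samples, np.concatenate([output, samples], axis=1)

    def sample_sequence(context, length):
        # One call over the full context primes the cache and yields the
        # first sampled token, so only length - 1 loop steps remain.
        past, prev, output = body(None, context, context)
        for _ in range(length - 1):
            past, prev, output = body(past, prev, output)
        return output

    print(sample_sequence(np.array([[15496, 995]]), length=5).shape)  # (1, 7): 2 context + 5 sampled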