From ac5d52295f8a1c3856ea24fb239087cc1a3d1131 Mon Sep 17 00:00:00 2001
From: Jeff Wu
Date: Mon, 26 Aug 2019 21:20:33 -0700
Subject: [PATCH] nucleus sampling

---
 src/generate_unconditional_samples.py  |  3 ++-
 src/interactive_conditional_samples.py |  3 ++-
 src/sample.py                          | 21 ++++++++++++++++++++-
 3 files changed, 24 insertions(+), 3 deletions(-)

diff --git a/src/generate_unconditional_samples.py b/src/generate_unconditional_samples.py
index cc3f3a3..eaf9a63 100755
--- a/src/generate_unconditional_samples.py
+++ b/src/generate_unconditional_samples.py
@@ -16,6 +16,7 @@ def sample_model(
     length=None,
     temperature=1,
     top_k=0,
+    top_p=1,
     models_dir='models',
 ):
     """
@@ -58,7 +59,7 @@ def sample_model(
             hparams=hparams, length=length,
             start_token=enc.encoder['<|endoftext|>'],
             batch_size=batch_size,
-            temperature=temperature, top_k=top_k
+            temperature=temperature, top_k=top_k, top_p=top_p
         )[:, 1:]
 
         saver = tf.train.Saver()
diff --git a/src/interactive_conditional_samples.py b/src/interactive_conditional_samples.py
index 48b5cb3..8b66000 100755
--- a/src/interactive_conditional_samples.py
+++ b/src/interactive_conditional_samples.py
@@ -16,6 +16,7 @@ def interact_model(
     length=None,
     temperature=1,
     top_k=0,
+    top_p=1,
     models_dir='models',
 ):
     """
@@ -61,7 +62,7 @@ def interact_model(
             hparams=hparams, length=length,
             context=context,
             batch_size=batch_size,
-            temperature=temperature, top_k=top_k
+            temperature=temperature, top_k=top_k, top_p=top_p
         )
 
         saver = tf.train.Saver()
diff --git a/src/sample.py b/src/sample.py
index 6649531..c90ed28 100644
--- a/src/sample.py
+++ b/src/sample.py
@@ -22,7 +22,25 @@ def top_k_logits(logits, k):
     )
 
 
-def sample_sequence(*, hparams, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0):
+def top_p_logits(logits, p):
+    """Nucleus sampling"""
+    batch, _ = logits.shape.as_list()
+    sorted_logits = tf.sort(logits, direction='DESCENDING', axis=-1)
+    cumulative_probs = tf.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1)
+    indices = tf.stack([
+        tf.range(0, batch),
+        # number of indices to include
+        tf.maximum(tf.reduce_sum(tf.cast(cumulative_probs <= p, tf.int32), axis=-1) - 1, 0),
+    ], axis=-1)
+    min_values = tf.gather_nd(sorted_logits, indices)
+    return tf.where(
+        logits < min_values,
+        tf.ones_like(logits) * -1e10,
+        logits,
+    )
+
+
+def sample_sequence(*, hparams, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, top_p=1):
     if start_token is None:
         assert context is not None, 'Specify exactly one of start_token and context!'
     else:
@@ -45,6 +63,7 @@ def sample_sequence(*, hparams, length, start_token=None, batch_size=None, conte
         next_outputs = step(hparams, prev, past=past)
         logits = next_outputs['logits'][:, -1, :]  / tf.to_float(temperature)
         logits = top_k_logits(logits, k=top_k)
+        logits = top_p_logits(logits, p=top_p)
         samples = tf.multinomial(logits, num_samples=1, output_dtype=tf.int32)
         return [
             next_outputs['presents'] if past is None else tf.concat([past, next_outputs['presents']], axis=-2),