nucleus sampling
This commit is contained in:
@ -16,6 +16,7 @@ def sample_model(
|
|||||||
length=None,
|
length=None,
|
||||||
temperature=1,
|
temperature=1,
|
||||||
top_k=0,
|
top_k=0,
|
||||||
|
top_p=1,
|
||||||
models_dir='models',
|
models_dir='models',
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -58,7 +59,7 @@ def sample_model(
|
|||||||
hparams=hparams, length=length,
|
hparams=hparams, length=length,
|
||||||
start_token=enc.encoder['<|endoftext|>'],
|
start_token=enc.encoder['<|endoftext|>'],
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
temperature=temperature, top_k=top_k
|
temperature=temperature, top_k=top_k, top_p=top_p
|
||||||
)[:, 1:]
|
)[:, 1:]
|
||||||
|
|
||||||
saver = tf.train.Saver()
|
saver = tf.train.Saver()
|
||||||
|
@ -16,6 +16,7 @@ def interact_model(
|
|||||||
length=None,
|
length=None,
|
||||||
temperature=1,
|
temperature=1,
|
||||||
top_k=0,
|
top_k=0,
|
||||||
|
top_p=1,
|
||||||
models_dir='models',
|
models_dir='models',
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -61,7 +62,7 @@ def interact_model(
|
|||||||
hparams=hparams, length=length,
|
hparams=hparams, length=length,
|
||||||
context=context,
|
context=context,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
temperature=temperature, top_k=top_k
|
temperature=temperature, top_k=top_k, top_p=top_p
|
||||||
)
|
)
|
||||||
|
|
||||||
saver = tf.train.Saver()
|
saver = tf.train.Saver()
|
||||||
|
@ -22,7 +22,25 @@ def top_k_logits(logits, k):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def sample_sequence(*, hparams, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0):
|
def top_p_logits(logits, p):
    """Nucleus (top-p) filtering of a logits batch.

    Keeps, per row, the smallest set of highest-probability logits whose
    cumulative softmax mass reaches ``p`` and pushes every other logit down
    to -1e10 so it is effectively never sampled.

    Args:
        logits: 2-D float tensor of shape [batch, vocab] with a static
            batch dimension (read via ``shape.as_list()``).
        p: scalar cumulative-probability threshold in (0, 1].

    Returns:
        Tensor of the same shape as ``logits`` with out-of-nucleus entries
        replaced by -1e10.
    """
    batch, _ = logits.shape.as_list()
    # Sort each row descending so the cumulative sum walks the nucleus first.
    desc = tf.sort(logits, direction='DESCENDING', axis=-1)
    cum_probs = tf.cumsum(tf.nn.softmax(desc, axis=-1), axis=-1)
    # Per row: how many sorted entries lie within the threshold, then the
    # index of the last one kept (clamped to 0 so at least one survives).
    kept = tf.reduce_sum(tf.cast(cum_probs <= p, tf.int32), axis=-1)
    last_kept = tf.maximum(kept - 1, 0)
    # Pair each row index with its cutoff column and read the threshold logit.
    cutoff_idx = tf.stack([tf.range(0, batch), last_kept], axis=-1)
    min_values = tf.gather_nd(desc, cutoff_idx)
    # Anything strictly below the row's cutoff logit is masked out.
    neg_fill = tf.ones_like(logits) * -1e10
    return tf.where(logits < min_values, neg_fill, logits)
|
|
||||||
|
|
||||||
|
def sample_sequence(*, hparams, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, top_p=1):
|
||||||
if start_token is None:
|
if start_token is None:
|
||||||
assert context is not None, 'Specify exactly one of start_token and context!'
|
assert context is not None, 'Specify exactly one of start_token and context!'
|
||||||
else:
|
else:
|
||||||
@ -45,6 +63,7 @@ def sample_sequence(*, hparams, length, start_token=None, batch_size=None, conte
|
|||||||
next_outputs = step(hparams, prev, past=past)
|
next_outputs = step(hparams, prev, past=past)
|
||||||
logits = next_outputs['logits'][:, -1, :] / tf.to_float(temperature)
|
logits = next_outputs['logits'][:, -1, :] / tf.to_float(temperature)
|
||||||
logits = top_k_logits(logits, k=top_k)
|
logits = top_k_logits(logits, k=top_k)
|
||||||
|
logits = top_p_logits(logits, p=top_p)
|
||||||
samples = tf.multinomial(logits, num_samples=1, output_dtype=tf.int32)
|
samples = tf.multinomial(logits, num_samples=1, output_dtype=tf.int32)
|
||||||
return [
|
return [
|
||||||
next_outputs['presents'] if past is None else tf.concat([past, next_outputs['presents']], axis=-2),
|
next_outputs['presents'] if past is None else tf.concat([past, next_outputs['presents']], axis=-2),
|
||||||
|
Reference in New Issue
Block a user