allow models to be in a separate folder via models_dir argument (#129)
* Add a models_dir argument to allow models to live in a separate folder
* Default models_dir to 'models', so behavior is the same as before
* Expand environment variables and the user home (~) in models_dir
This commit is contained in:
@ -105,10 +105,10 @@ class Encoder:
|
|||||||
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
|
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
|
||||||
return text
|
return text
|
||||||
|
|
||||||
def get_encoder(model_name):
|
def get_encoder(model_name, models_dir):
|
||||||
with open(os.path.join('models', model_name, 'encoder.json'), 'r') as f:
|
with open(os.path.join(models_dir, model_name, 'encoder.json'), 'r') as f:
|
||||||
encoder = json.load(f)
|
encoder = json.load(f)
|
||||||
with open(os.path.join('models', model_name, 'vocab.bpe'), 'r', encoding="utf-8") as f:
|
with open(os.path.join(models_dir, model_name, 'vocab.bpe'), 'r', encoding="utf-8") as f:
|
||||||
bpe_data = f.read()
|
bpe_data = f.read()
|
||||||
bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
|
bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
|
||||||
return Encoder(
|
return Encoder(
|
||||||
|
@ -16,6 +16,7 @@ def sample_model(
|
|||||||
length=None,
|
length=None,
|
||||||
temperature=1,
|
temperature=1,
|
||||||
top_k=0,
|
top_k=0,
|
||||||
|
models_dir='models',
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Run the sample_model
|
Run the sample_model
|
||||||
@ -35,10 +36,13 @@ def sample_model(
|
|||||||
considered for each step (token), resulting in deterministic completions,
|
considered for each step (token), resulting in deterministic completions,
|
||||||
while 40 means 40 words are considered at each step. 0 (default) is a
|
while 40 means 40 words are considered at each step. 0 (default) is a
|
||||||
special setting meaning no restrictions. 40 generally is a good value.
|
special setting meaning no restrictions. 40 generally is a good value.
|
||||||
|
:models_dir : path to parent folder containing model subfolders
|
||||||
|
(i.e. contains the <model_name> folder)
|
||||||
"""
|
"""
|
||||||
enc = encoder.get_encoder(model_name)
|
models_dir = os.path.expanduser(os.path.expandvars(models_dir))
|
||||||
|
enc = encoder.get_encoder(model_name, models_dir)
|
||||||
hparams = model.default_hparams()
|
hparams = model.default_hparams()
|
||||||
with open(os.path.join('models', model_name, 'hparams.json')) as f:
|
with open(os.path.join(models_dir, model_name, 'hparams.json')) as f:
|
||||||
hparams.override_from_dict(json.load(f))
|
hparams.override_from_dict(json.load(f))
|
||||||
|
|
||||||
if length is None:
|
if length is None:
|
||||||
@ -58,7 +62,7 @@ def sample_model(
|
|||||||
)[:, 1:]
|
)[:, 1:]
|
||||||
|
|
||||||
saver = tf.train.Saver()
|
saver = tf.train.Saver()
|
||||||
ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name))
|
ckpt = tf.train.latest_checkpoint(os.path.join(models_dir, model_name))
|
||||||
saver.restore(sess, ckpt)
|
saver.restore(sess, ckpt)
|
||||||
|
|
||||||
generated = 0
|
generated = 0
|
||||||
|
@ -16,6 +16,7 @@ def interact_model(
|
|||||||
length=None,
|
length=None,
|
||||||
temperature=1,
|
temperature=1,
|
||||||
top_k=0,
|
top_k=0,
|
||||||
|
models_dir='models',
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Interactively run the model
|
Interactively run the model
|
||||||
@ -34,14 +35,17 @@ def interact_model(
|
|||||||
considered for each step (token), resulting in deterministic completions,
|
considered for each step (token), resulting in deterministic completions,
|
||||||
while 40 means 40 words are considered at each step. 0 (default) is a
|
while 40 means 40 words are considered at each step. 0 (default) is a
|
||||||
special setting meaning no restrictions. 40 generally is a good value.
|
special setting meaning no restrictions. 40 generally is a good value.
|
||||||
|
:models_dir : path to parent folder containing model subfolders
|
||||||
|
(i.e. contains the <model_name> folder)
|
||||||
"""
|
"""
|
||||||
|
models_dir = os.path.expanduser(os.path.expandvars(models_dir))
|
||||||
if batch_size is None:
|
if batch_size is None:
|
||||||
batch_size = 1
|
batch_size = 1
|
||||||
assert nsamples % batch_size == 0
|
assert nsamples % batch_size == 0
|
||||||
|
|
||||||
enc = encoder.get_encoder(model_name)
|
enc = encoder.get_encoder(model_name, models_dir)
|
||||||
hparams = model.default_hparams()
|
hparams = model.default_hparams()
|
||||||
with open(os.path.join('models', model_name, 'hparams.json')) as f:
|
with open(os.path.join(models_dir, model_name, 'hparams.json')) as f:
|
||||||
hparams.override_from_dict(json.load(f))
|
hparams.override_from_dict(json.load(f))
|
||||||
|
|
||||||
if length is None:
|
if length is None:
|
||||||
@ -61,7 +65,7 @@ def interact_model(
|
|||||||
)
|
)
|
||||||
|
|
||||||
saver = tf.train.Saver()
|
saver = tf.train.Saver()
|
||||||
ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name))
|
ckpt = tf.train.latest_checkpoint(os.path.join(models_dir, model_name))
|
||||||
saver.restore(sess, ckpt)
|
saver.restore(sess, ckpt)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
|
Reference in New Issue
Block a user