allow models to be in a separate folder via models_dir argument (#129)

* add a models_dir argument so models can live in a separate folder

* keep the default value of models_dir ('models') so existing behaviour is unchanged

* allow environment variables and the user home directory (~) in models_dir
Authored by Memo Akten on 2019-05-16 19:42:58 +03:00
Committed by Jeff Wu
Parent: dd75299dfe
Commit: e5c5054474
3 changed files with 17 additions and 9 deletions
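The key mechanism in the change: every hard-coded 'models' path is replaced by a models_dir parameter, and models_dir is passed through os.path.expandvars and os.path.expanduser before use, so both environment variables and a leading ~ resolve. A minimal sketch of that expansion step (the example paths and the GPT2_MODELS variable are hypothetical):

    import os

    # expandvars resolves $NAME / ${NAME}; expanduser resolves a leading ~.
    for raw in ('models', '~/gpt-2-models', '$GPT2_MODELS'):
        resolved = os.path.expanduser(os.path.expandvars(raw))
        print(raw, '->', resolved)

    # With GPT2_MODELS=/data/gpt2 exported, this prints roughly:
    #   models          -> models
    #   ~/gpt-2-models  -> /home/<user>/gpt-2-models
    #   $GPT2_MODELS    -> /data/gpt2

Note the order used by the commit: expandvars runs first, then expanduser.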


@@ -105,10 +105,10 @@ class Encoder:
         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
         return text
 
-def get_encoder(model_name):
-    with open(os.path.join('models', model_name, 'encoder.json'), 'r') as f:
+def get_encoder(model_name, models_dir):
+    with open(os.path.join(models_dir, model_name, 'encoder.json'), 'r') as f:
         encoder = json.load(f)
-    with open(os.path.join('models', model_name, 'vocab.bpe'), 'r', encoding="utf-8") as f:
+    with open(os.path.join(models_dir, model_name, 'vocab.bpe'), 'r', encoding="utf-8") as f:
         bpe_data = f.read()
     bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
     return Encoder(
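With the new signature, callers of get_encoder must supply the (already expanded) models_dir along with the model name. A short usage sketch, assuming the patched encoder module is importable; the path and model name here are illustrative:

    import os
    import encoder  # the module patched above

    models_dir = os.path.expanduser(os.path.expandvars('~/gpt-2-models'))  # illustrative path
    enc = encoder.get_encoder('117M', models_dir)  # loads <models_dir>/117M/encoder.json and vocab.bpe

    tokens = enc.encode("Hello world")
    print(tokens)
    print(enc.decode(tokens))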


@@ -16,6 +16,7 @@ def sample_model(
     length=None,
     temperature=1,
     top_k=0,
+    models_dir='models',
 ):
     """
     Run the sample_model
@@ -35,10 +36,13 @@ def sample_model(
     considered for each step (token), resulting in deterministic completions,
     while 40 means 40 words are considered at each step. 0 (default) is a
     special setting meaning no restrictions. 40 generally is a good value.
+    :models_dir : path to parent folder containing model subfolders
+     (i.e. contains the <model_name> folder)
     """
-    enc = encoder.get_encoder(model_name)
+    models_dir = os.path.expanduser(os.path.expandvars(models_dir))
+    enc = encoder.get_encoder(model_name, models_dir)
     hparams = model.default_hparams()
-    with open(os.path.join('models', model_name, 'hparams.json')) as f:
+    with open(os.path.join(models_dir, model_name, 'hparams.json')) as f:
         hparams.override_from_dict(json.load(f))
 
     if length is None:
@@ -58,7 +62,7 @@ def sample_model(
         )[:, 1:]
 
         saver = tf.train.Saver()
-        ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name))
+        ckpt = tf.train.latest_checkpoint(os.path.join(models_dir, model_name))
         saver.restore(sess, ckpt)
 
         generated = 0
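sample_model now takes models_dir as a keyword argument, so a shared model store can be selected without editing the script. A hedged sketch of both call styles, assuming sample_model is importable from the script above and that the script is exposed on the command line via python-fire (so keyword arguments become flags); all paths and values are illustrative:

    # Programmatic call (sample_model imported from the script above):
    sample_model(
        model_name='117M',            # illustrative model folder name
        models_dir='$GPT2_MODELS',    # expanded inside sample_model via expandvars/expanduser
        length=64,
        top_k=40,
    )

    # Rough command-line equivalent, if the script is run via python-fire:
    #   python3 <script>.py --models_dir=/data/gpt2-models --top_k=40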


@@ -16,6 +16,7 @@ def interact_model(
     length=None,
     temperature=1,
     top_k=0,
+    models_dir='models',
 ):
     """
     Interactively run the model
@@ -34,14 +35,17 @@ def interact_model(
     considered for each step (token), resulting in deterministic completions,
     while 40 means 40 words are considered at each step. 0 (default) is a
     special setting meaning no restrictions. 40 generally is a good value.
+    :models_dir : path to parent folder containing model subfolders
+     (i.e. contains the <model_name> folder)
     """
+    models_dir = os.path.expanduser(os.path.expandvars(models_dir))
     if batch_size is None:
         batch_size = 1
     assert nsamples % batch_size == 0
 
-    enc = encoder.get_encoder(model_name)
+    enc = encoder.get_encoder(model_name, models_dir)
     hparams = model.default_hparams()
-    with open(os.path.join('models', model_name, 'hparams.json')) as f:
+    with open(os.path.join(models_dir, model_name, 'hparams.json')) as f:
         hparams.override_from_dict(json.load(f))
 
     if length is None:
@@ -61,7 +65,7 @@ def interact_model(
         )
 
         saver = tf.train.Saver()
-        ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name))
+        ckpt = tf.train.latest_checkpoint(os.path.join(models_dir, model_name))
         saver.restore(sess, ckpt)
 
         while True:
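The checkpoint lookup changes in the same way: tf.train.latest_checkpoint is now given os.path.join(models_dir, model_name), so the expected layout is one subfolder per model inside the parent directory. A sketch of the assumed on-disk layout and the resulting path (directory names are illustrative):

    import os

    # Illustrative layout:
    #   /data/gpt2-models/
    #       117M/
    #           checkpoint
    #           encoder.json
    #           hparams.json
    #           model.ckpt.*  (TensorFlow checkpoint files)
    #           vocab.bpe
    models_dir = '/data/gpt2-models'   # illustrative
    model_name = '117M'                # illustrative

    ckpt_dir = os.path.join(models_dir, model_name)
    print(ckpt_dir)  # /data/gpt2-models/117M
    # tf.train.latest_checkpoint(ckpt_dir) consults ckpt_dir/checkpoint to find the newest model.ckpt* files.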