allow models to be in a separate folder via models_dir argument (#129)

* models_dir argument to allow models in a separate folder

* default value for models_dir to be same as before

* allow environment variables and user home in models_dir
Author: Memo Akten
Date: 2019-05-16 19:42:58 +03:00
Committed by: Jeff Wu
Parent: dd75299dfe
Commit: e5c5054474

3 changed files with 17 additions and 9 deletions
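The third bullet is implemented with Python's standard path expansion. A minimal standalone sketch of how a models_dir value gets resolved (the example paths are illustrative, not taken from the commit):

    import os

    # Environment variables are expanded first, then a leading '~';
    # this matches the expanduser(expandvars(...)) order used in the diff below.
    for raw in ['models', '~/gpt-2/models', '$HOME/gpt-2/models']:
        resolved = os.path.expanduser(os.path.expandvars(raw))
        print(raw, '->', resolved)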


@@ -105,10 +105,10 @@ class Encoder:
         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
         return text
 
-def get_encoder(model_name):
-    with open(os.path.join('models', model_name, 'encoder.json'), 'r') as f:
+def get_encoder(model_name, models_dir):
+    with open(os.path.join(models_dir, model_name, 'encoder.json'), 'r') as f:
         encoder = json.load(f)
-    with open(os.path.join('models', model_name, 'vocab.bpe'), 'r', encoding="utf-8") as f:
+    with open(os.path.join(models_dir, model_name, 'vocab.bpe'), 'r', encoding="utf-8") as f:
         bpe_data = f.read()
     bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
     return Encoder(
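For context, a hedged sketch of how a caller would use the new signature; the model name '117M', the directory, and the encode/decode calls are assumptions based on the surrounding code, not part of this hunk:

    import os

    import encoder  # the module patched above

    # Assumed layout: <models_dir>/<model_name>/{encoder.json, vocab.bpe};
    # the path and model name are illustrative only.
    models_dir = os.path.expanduser('~/gpt-2-models')
    enc = encoder.get_encoder('117M', models_dir)
    tokens = enc.encode('Hello world')
    print(tokens, enc.decode(tokens))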


@@ -16,6 +16,7 @@ def sample_model(
     length=None,
     temperature=1,
     top_k=0,
+    models_dir='models',
 ):
     """
     Run the sample_model
@@ -35,10 +36,13 @@ def sample_model(
     considered for each step (token), resulting in deterministic completions,
     while 40 means 40 words are considered at each step. 0 (default) is a
     special setting meaning no restrictions. 40 generally is a good value.
+    :models_dir : path to parent folder containing model subfolders
+     (i.e. contains the <model_name> folder)
     """
-    enc = encoder.get_encoder(model_name)
+    models_dir = os.path.expanduser(os.path.expandvars(models_dir))
+    enc = encoder.get_encoder(model_name, models_dir)
     hparams = model.default_hparams()
-    with open(os.path.join('models', model_name, 'hparams.json')) as f:
+    with open(os.path.join(models_dir, model_name, 'hparams.json')) as f:
         hparams.override_from_dict(json.load(f))
 
     if length is None:
@@ -58,7 +62,7 @@ def sample_model(
         )[:, 1:]
 
         saver = tf.train.Saver()
-        ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name))
+        ckpt = tf.train.latest_checkpoint(os.path.join(models_dir, model_name))
         saver.restore(sess, ckpt)
 
         generated = 0
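The same resolved models_dir feeds both the hparams and checkpoint lookups changed above. A minimal standalone sketch of that loading pattern, with an illustrative path and model name that are not taken from the commit:

    import json
    import os

    import tensorflow as tf

    # Illustrative values; the commit only changes where these files are read from.
    models_dir = os.path.expanduser(os.path.expandvars('~/gpt-2/models'))
    model_name = '117M'

    with open(os.path.join(models_dir, model_name, 'hparams.json')) as f:
        hparams = json.load(f)

    # latest_checkpoint returns None if no checkpoint is found under the directory.
    ckpt = tf.train.latest_checkpoint(os.path.join(models_dir, model_name))
    print(hparams, ckpt)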


@@ -16,6 +16,7 @@ def interact_model(
     length=None,
     temperature=1,
     top_k=0,
+    models_dir='models',
 ):
     """
     Interactively run the model
@@ -34,14 +35,17 @@ def interact_model(
     considered for each step (token), resulting in deterministic completions,
     while 40 means 40 words are considered at each step. 0 (default) is a
     special setting meaning no restrictions. 40 generally is a good value.
+    :models_dir : path to parent folder containing model subfolders
+     (i.e. contains the <model_name> folder)
     """
+    models_dir = os.path.expanduser(os.path.expandvars(models_dir))
     if batch_size is None:
         batch_size = 1
     assert nsamples % batch_size == 0
 
-    enc = encoder.get_encoder(model_name)
+    enc = encoder.get_encoder(model_name, models_dir)
     hparams = model.default_hparams()
-    with open(os.path.join('models', model_name, 'hparams.json')) as f:
+    with open(os.path.join(models_dir, model_name, 'hparams.json')) as f:
         hparams.override_from_dict(json.load(f))
 
     if length is None:
@@ -61,7 +65,7 @@ def interact_model(
         )
 
         saver = tf.train.Saver()
-        ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name))
+        ckpt = tf.train.latest_checkpoint(os.path.join(models_dir, model_name))
         saver.restore(sess, ckpt)
 
         while True:
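Because models_dir defaults to 'models', existing callers keep the old behaviour unchanged. A hypothetical call passing the new argument; the module name and model name below are assumptions for illustration, not part of the commit:

    # Hypothetical usage; 'interactive_conditional_samples' and '117M' are
    # assumed names, not taken from the commit.
    from interactive_conditional_samples import interact_model

    # Old behaviour: models resolved under ./models, exactly as before.
    interact_model(model_name='117M')

    # New behaviour: models under an arbitrary folder, with env vars and '~' expanded.
    interact_model(model_name='117M', models_dir='$HOME/shared/gpt-2-models')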