allow models to be in a separate folder via models_dir argument (#129)
* models_dir argument to allow models in a separate folder * default value for models_dir to be same as before * allow environment variables and user home in models_dir
This commit is contained in:
@ -105,10 +105,10 @@ class Encoder:
|
||||
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
|
||||
return text
|
||||
|
||||
def get_encoder(model_name):
|
||||
with open(os.path.join('models', model_name, 'encoder.json'), 'r') as f:
|
||||
def get_encoder(model_name, models_dir):
|
||||
with open(os.path.join(models_dir, model_name, 'encoder.json'), 'r') as f:
|
||||
encoder = json.load(f)
|
||||
with open(os.path.join('models', model_name, 'vocab.bpe'), 'r', encoding="utf-8") as f:
|
||||
with open(os.path.join(models_dir, model_name, 'vocab.bpe'), 'r', encoding="utf-8") as f:
|
||||
bpe_data = f.read()
|
||||
bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
|
||||
return Encoder(
|
||||
|
@ -16,6 +16,7 @@ def sample_model(
|
||||
length=None,
|
||||
temperature=1,
|
||||
top_k=0,
|
||||
models_dir='models',
|
||||
):
|
||||
"""
|
||||
Run the sample_model
|
||||
@ -35,10 +36,13 @@ def sample_model(
|
||||
considered for each step (token), resulting in deterministic completions,
|
||||
while 40 means 40 words are considered at each step. 0 (default) is a
|
||||
special setting meaning no restrictions. 40 generally is a good value.
|
||||
:models_dir : path to parent folder containing model subfolders
|
||||
(i.e. contains the <model_name> folder)
|
||||
"""
|
||||
enc = encoder.get_encoder(model_name)
|
||||
models_dir = os.path.expanduser(os.path.expandvars(models_dir))
|
||||
enc = encoder.get_encoder(model_name, models_dir)
|
||||
hparams = model.default_hparams()
|
||||
with open(os.path.join('models', model_name, 'hparams.json')) as f:
|
||||
with open(os.path.join(models_dir, model_name, 'hparams.json')) as f:
|
||||
hparams.override_from_dict(json.load(f))
|
||||
|
||||
if length is None:
|
||||
@ -58,7 +62,7 @@ def sample_model(
|
||||
)[:, 1:]
|
||||
|
||||
saver = tf.train.Saver()
|
||||
ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name))
|
||||
ckpt = tf.train.latest_checkpoint(os.path.join(models_dir, model_name))
|
||||
saver.restore(sess, ckpt)
|
||||
|
||||
generated = 0
|
||||
|
@ -16,6 +16,7 @@ def interact_model(
|
||||
length=None,
|
||||
temperature=1,
|
||||
top_k=0,
|
||||
models_dir='models',
|
||||
):
|
||||
"""
|
||||
Interactively run the model
|
||||
@ -34,14 +35,17 @@ def interact_model(
|
||||
considered for each step (token), resulting in deterministic completions,
|
||||
while 40 means 40 words are considered at each step. 0 (default) is a
|
||||
special setting meaning no restrictions. 40 generally is a good value.
|
||||
:models_dir : path to parent folder containing model subfolders
|
||||
(i.e. contains the <model_name> folder)
|
||||
"""
|
||||
models_dir = os.path.expanduser(os.path.expandvars(models_dir))
|
||||
if batch_size is None:
|
||||
batch_size = 1
|
||||
assert nsamples % batch_size == 0
|
||||
|
||||
enc = encoder.get_encoder(model_name)
|
||||
enc = encoder.get_encoder(model_name, models_dir)
|
||||
hparams = model.default_hparams()
|
||||
with open(os.path.join('models', model_name, 'hparams.json')) as f:
|
||||
with open(os.path.join(models_dir, model_name, 'hparams.json')) as f:
|
||||
hparams.override_from_dict(json.load(f))
|
||||
|
||||
if length is None:
|
||||
@ -61,7 +65,7 @@ def interact_model(
|
||||
)
|
||||
|
||||
saver = tf.train.Saver()
|
||||
ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name))
|
||||
ckpt = tf.train.latest_checkpoint(os.path.join(models_dir, model_name))
|
||||
saver.restore(sess, ckpt)
|
||||
|
||||
while True:
|
||||
|
Reference in New Issue
Block a user