run.py can run algos from both baselines and rl_algs

This commit is contained in:
Peter Zhokhov
2018-07-30 16:09:48 -07:00
parent efc6bffce3
commit e662dd6409

View File

@@ -144,18 +144,27 @@ def get_default_network(env_type):
raise ValueError('Unknown env_type {}'.format(env_type))
def get_alg_module(alg, submodule=None):
submodule = submodule or alg
try:
# first try to import the alg module from baselines
alg_module = import_module('.'.join(['rl_' + 'algs', alg, submodule]))
except ImportError:
# then from baselines
alg_module = import_module('.'.join(['baselines', alg, submodule]))
return alg_module
def get_learn_function(alg):
alg_module = import_module('.'.join(['baselines', alg, alg]))
return alg_module.learn
return get_alg_module(alg).learn
def get_learn_function_defaults(alg, env_type):
try:
alg_defaults = import_module('.'.join(['baselines', alg, 'defaults']))
alg_defaults = get_alg_module(alg, 'defaults')
kwargs = getattr(alg_defaults, env_type)()
except (ImportError, AttributeError):
kwargs = {}
return kwargs
def parse(v):