run.py can run algos from both baselines and rl_algs
This commit is contained in:
@@ -144,18 +144,27 @@ def get_default_network(env_type):
|
||||
|
||||
raise ValueError('Unknown env_type {}'.format(env_type))
|
||||
|
||||
def get_alg_module(alg, submodule=None):
|
||||
submodule = submodule or alg
|
||||
try:
|
||||
# first try to import the alg module from baselines
|
||||
alg_module = import_module('.'.join(['rl_' + 'algs', alg, submodule]))
|
||||
except ImportError:
|
||||
# then from baselines
|
||||
alg_module = import_module('.'.join(['baselines', alg, submodule]))
|
||||
|
||||
return alg_module
|
||||
|
||||
|
||||
def get_learn_function(alg):
|
||||
alg_module = import_module('.'.join(['baselines', alg, alg]))
|
||||
return alg_module.learn
|
||||
return get_alg_module(alg).learn
|
||||
|
||||
def get_learn_function_defaults(alg, env_type):
|
||||
try:
|
||||
alg_defaults = import_module('.'.join(['baselines', alg, 'defaults']))
|
||||
alg_defaults = get_alg_module(alg, 'defaults')
|
||||
kwargs = getattr(alg_defaults, env_type)()
|
||||
except (ImportError, AttributeError):
|
||||
kwargs = {}
|
||||
|
||||
return kwargs
|
||||
|
||||
def parse(v):
|
||||
|
Reference in New Issue
Block a user