diff --git a/baselines/deepq/deepq.py b/baselines/deepq/deepq.py index 4108210..e7ab033 100644 --- a/baselines/deepq/deepq.py +++ b/baselines/deepq/deepq.py @@ -93,7 +93,7 @@ def load_act(path): return ActWrapper.load_act(path) -@registry.register('deepq', supports_vecenvs=False, defaults=defaults) +@registry.register('deepq', supports_vecenv=False, defaults=defaults) def learn(env, network, seed=None, diff --git a/baselines/registry.py b/baselines/registry.py index 9b0c96c..cb5b892 100644 --- a/baselines/registry.py +++ b/baselines/registry.py @@ -1,7 +1,24 @@ -from baselines import logger +# Registry of algorithms that keeps track of each algorithm's supported environments +# and fine-grained defaults for different kinds of environments (atari, retro, mujoco, etc.) +# +# Example usage: +# +# from baselines import registry +# +# @registry.register('fancy_algorithm', supports_vecenv=False) +# def learn(env, network): +# return +# +# for algo_name, algo_entry in registry.registry.items(): +# if not algo_entry['supports_vecenv']: +# print(f'{algo_name} does not support vecenvs') +# # should print "fancy_algorithm does not support vecenvs" (among other ones)"from baselines import logger + + + registry = {} -def register(name, supports_vecenv=True, defaults={}, **kwargs): +def register(name, supports_vecenv=True, defaults={}): def get_fn_entrypoint(fn): import inspect return '.'.join([inspect.getmodule(fn).__name__, fn.__name__]) @@ -16,7 +33,6 @@ def register(name, supports_vecenv=True, defaults={}, **kwargs): fn = learn_fn, supports_vecenv=supports_vecenv, defaults=defaults, - **kwargs ) return learn_fn return _thunk diff --git a/baselines/trpo_mpi/trpo_mpi.py b/baselines/trpo_mpi/trpo_mpi.py index a9fc8a1..66374f2 100644 --- a/baselines/trpo_mpi/trpo_mpi.py +++ b/baselines/trpo_mpi/trpo_mpi.py @@ -84,7 +84,7 @@ def add_vtarg_and_adv(seg, gamma, lam): gaelam[t] = lastgaelam = delta + gamma * lam * nonterminal * lastgaelam seg["tdlamret"] = seg["adv"] + seg["vpred"] 
-@registry.register('trpo_mpi', supports_vecenvs=False, defaults=defaults) +@registry.register('trpo_mpi', supports_vecenv=False, defaults=defaults) def learn(*, network, env,