added comments on registry usage, fixed typos in deepq and trpo_mpi registration

This commit is contained in:
Peter Zhokhov
2018-10-23 11:14:48 -07:00
parent a8c2e643dc
commit a52dcae856
3 changed files with 21 additions and 5 deletions

View File

@@ -93,7 +93,7 @@ def load_act(path):
return ActWrapper.load_act(path)
@registry.register('deepq', supports_vecenvs=False, defaults=defaults)
@registry.register('deepq', supports_vecenv=False, defaults=defaults)
def learn(env,
network,
seed=None,

View File

@@ -1,7 +1,24 @@
from baselines import logger
# Registry of algorithms that keeps track of algorithms supported environments and
# and fine-grained defaults for different kinds of environments (atari, retro, mujoco etc)
#
# Example usage:
#
# from baselines import registry
#
# @registry.register('fancy_algorithm', supports_vecenv=False)
# def learn(env, network):
# return
#
# for algo_name, algo_entry in registry.registry.items():
# if not algo_entry['supports_vecenv']:
# print(f'{algo_name} does not support vecenvs')
# # should print "fancy_algorithm does not support vecenvs" (among other ones)"from baselines import logger
registry = {}
def register(name, supports_vecenv=True, defaults={}, **kwargs):
def register(name, supports_vecenv=True, defaults={}):
def get_fn_entrypoint(fn):
import inspect
return '.'.join([inspect.getmodule(fn).__name__, fn.__name__])
@@ -16,7 +33,6 @@ def register(name, supports_vecenv=True, defaults={}, **kwargs):
fn = learn_fn,
supports_vecenv=supports_vecenv,
defaults=defaults,
**kwargs
)
return learn_fn
return _thunk

View File

@@ -84,7 +84,7 @@ def add_vtarg_and_adv(seg, gamma, lam):
gaelam[t] = lastgaelam = delta + gamma * lam * nonterminal * lastgaelam
seg["tdlamret"] = seg["adv"] + seg["vpred"]
@registry.register('trpo_mpi', supports_vecenvs=False, defaults=defaults)
@registry.register('trpo_mpi', supports_vecenv=False, defaults=defaults)
def learn(*,
network,
env,