added comments on registry usage, fixed typos in deepq and trpo_mpi registration
This commit is contained in:
@@ -93,7 +93,7 @@ def load_act(path):
|
||||
return ActWrapper.load_act(path)
|
||||
|
||||
|
||||
@registry.register('deepq', supports_vecenvs=False, defaults=defaults)
|
||||
@registry.register('deepq', supports_vecenv=False, defaults=defaults)
|
||||
def learn(env,
|
||||
network,
|
||||
seed=None,
|
||||
|
@@ -1,7 +1,24 @@
|
||||
from baselines import logger
|
||||
# Registry of algorithms that keeps track of algorithms supported environments and
|
||||
# and fine-grained defaults for different kinds of environments (atari, retro, mujoco etc)
|
||||
#
|
||||
# Example usage:
|
||||
#
|
||||
# from baselines import registry
|
||||
#
|
||||
# @registry.register('fancy_algorithm', supports_vecenv=False)
|
||||
# def learn(env, network):
|
||||
# return
|
||||
#
|
||||
# for algo_name, algo_entry in registry.registry.items():
|
||||
# if not algo_entry['supports_vecenv']:
|
||||
# print(f'{algo_name} does not support vecenvs')
|
||||
# # should print "fancy_algorithm does not support vecenvs" (among other ones)"from baselines import logger
|
||||
|
||||
|
||||
|
||||
registry = {}
|
||||
|
||||
def register(name, supports_vecenv=True, defaults={}, **kwargs):
|
||||
def register(name, supports_vecenv=True, defaults={}):
|
||||
def get_fn_entrypoint(fn):
|
||||
import inspect
|
||||
return '.'.join([inspect.getmodule(fn).__name__, fn.__name__])
|
||||
@@ -16,7 +33,6 @@ def register(name, supports_vecenv=True, defaults={}, **kwargs):
|
||||
fn = learn_fn,
|
||||
supports_vecenv=supports_vecenv,
|
||||
defaults=defaults,
|
||||
**kwargs
|
||||
)
|
||||
return learn_fn
|
||||
return _thunk
|
||||
|
@@ -84,7 +84,7 @@ def add_vtarg_and_adv(seg, gamma, lam):
|
||||
gaelam[t] = lastgaelam = delta + gamma * lam * nonterminal * lastgaelam
|
||||
seg["tdlamret"] = seg["adv"] + seg["vpred"]
|
||||
|
||||
@registry.register('trpo_mpi', supports_vecenvs=False, defaults=defaults)
|
||||
@registry.register('trpo_mpi', supports_vecenv=False, defaults=defaults)
|
||||
def learn(*,
|
||||
network,
|
||||
env,
|
||||
|
Reference in New Issue
Block a user