added comments on registry usage, fixed typos in deepq and trpo_mpi registration
This commit is contained in:
@@ -93,7 +93,7 @@ def load_act(path):
|
|||||||
return ActWrapper.load_act(path)
|
return ActWrapper.load_act(path)
|
||||||
|
|
||||||
|
|
||||||
@registry.register('deepq', supports_vecenvs=False, defaults=defaults)
|
@registry.register('deepq', supports_vecenv=False, defaults=defaults)
|
||||||
def learn(env,
|
def learn(env,
|
||||||
network,
|
network,
|
||||||
seed=None,
|
seed=None,
|
||||||
|
@@ -1,7 +1,24 @@
|
|||||||
from baselines import logger
|
# Registry of algorithms that keeps track of algorithms supported environments and
|
||||||
|
# and fine-grained defaults for different kinds of environments (atari, retro, mujoco etc)
|
||||||
|
#
|
||||||
|
# Example usage:
|
||||||
|
#
|
||||||
|
# from baselines import registry
|
||||||
|
#
|
||||||
|
# @registry.register('fancy_algorithm', supports_vecenv=False)
|
||||||
|
# def learn(env, network):
|
||||||
|
# return
|
||||||
|
#
|
||||||
|
# for algo_name, algo_entry in registry.registry.items():
|
||||||
|
# if not algo_entry['supports_vecenv']:
|
||||||
|
# print(f'{algo_name} does not support vecenvs')
|
||||||
|
# # should print "fancy_algorithm does not support vecenvs" (among other ones)"from baselines import logger
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
registry = {}
|
registry = {}
|
||||||
|
|
||||||
def register(name, supports_vecenv=True, defaults={}, **kwargs):
|
def register(name, supports_vecenv=True, defaults={}):
|
||||||
def get_fn_entrypoint(fn):
|
def get_fn_entrypoint(fn):
|
||||||
import inspect
|
import inspect
|
||||||
return '.'.join([inspect.getmodule(fn).__name__, fn.__name__])
|
return '.'.join([inspect.getmodule(fn).__name__, fn.__name__])
|
||||||
@@ -16,7 +33,6 @@ def register(name, supports_vecenv=True, defaults={}, **kwargs):
|
|||||||
fn = learn_fn,
|
fn = learn_fn,
|
||||||
supports_vecenv=supports_vecenv,
|
supports_vecenv=supports_vecenv,
|
||||||
defaults=defaults,
|
defaults=defaults,
|
||||||
**kwargs
|
|
||||||
)
|
)
|
||||||
return learn_fn
|
return learn_fn
|
||||||
return _thunk
|
return _thunk
|
||||||
|
@@ -84,7 +84,7 @@ def add_vtarg_and_adv(seg, gamma, lam):
|
|||||||
gaelam[t] = lastgaelam = delta + gamma * lam * nonterminal * lastgaelam
|
gaelam[t] = lastgaelam = delta + gamma * lam * nonterminal * lastgaelam
|
||||||
seg["tdlamret"] = seg["adv"] + seg["vpred"]
|
seg["tdlamret"] = seg["adv"] + seg["vpred"]
|
||||||
|
|
||||||
@registry.register('trpo_mpi', supports_vecenvs=False, defaults=defaults)
|
@registry.register('trpo_mpi', supports_vecenv=False, defaults=defaults)
|
||||||
def learn(*,
|
def learn(*,
|
||||||
network,
|
network,
|
||||||
env,
|
env,
|
||||||
|
Reference in New Issue
Block a user