Compare commits

...

21 Commits

SHA1        Author         Date                        Message
c7a0c2781a  Peter Zhokhov  2018-10-23 13:42:44 -07:00  autopep8 and import fix
06c2fd2a3c  Peter Zhokhov  2018-10-23 11:16:20 -07:00  typo in registry.py
a52dcae856  Peter Zhokhov  2018-10-23 11:14:48 -07:00  added comments on registry usage, fixed typos in deepq and trpo_mpi registration
a8c2e643dc  Peter Zhokhov  2018-10-23 10:11:48 -07:00  import error in run.py
5ca31a7c25  Peter Zhokhov  2018-10-23 10:07:51 -07:00  merged latest master
35dcb6fd74  Peter Zhokhov  2018-10-22 19:22:46 -07:00  merged internal
c1c7c469a1  Peter Zhokhov  2018-10-22 19:19:54 -07:00  fix syntax
b4869bd271  Peter Zhokhov  2018-10-22 19:13:10 -07:00  use algorithm registry - staging for internal benchmarks
29cfb4a69c  Peter Zhokhov  2018-10-22 19:08:27 -07:00  Merge branch 'internal' of github.com:openai/baselines into peterz_learn_registration
bd7c479e04  Peter Zhokhov  2018-10-22 19:07:46 -07:00  merge master
3ddf69c4b5  Peter Zhokhov  2018-10-22 18:10:10 -07:00  defaults are handled through registry
bfdc552521  Peter Zhokhov  2018-10-22 17:45:55 -07:00  moving things around
bcb4d4f795  Peter Zhokhov  2018-10-22 17:15:24 -07:00  moved imports back to run
0c9b236475  Peter Zhokhov  2018-10-22 17:01:49 -07:00  using registry of algorithms
01884bb0eb  Peter Zhokhov  2018-10-22 14:21:26 -07:00  wrap retro envs correctly for other (non-deepq) algorithms
ade2d61be7  Peter Zhokhov  2018-10-19 17:27:57 -07:00  Merge branch 'master' of github.com:openai/games into peterz_track_baselines_branch
f6ef52a9df  Peter Zhokhov  2018-10-19 09:52:23 -07:00  Merge branch 'master' of github.com:openai/baselines into internal
8964d5ad45  Peter Zhokhov  2018-10-16 14:58:23 -07:00  flake8 and numpy.random.random_integers deprecation warning
8624bc629c  Peter Zhokhov  2018-10-15 18:31:55 -07:00  eval_done[d]==True -> eval_done[d]
7b33af0395  Peter Zhokhov  2018-10-15 18:29:48 -07:00  B -> nenvs for consistency with other algos, small cleanups
4bca9158a1  Peter Zhokhov  2018-10-15 17:40:24 -07:00  sync internal changes. Make ddpg work with vecenvs
16 changed files with 130 additions and 68 deletions

View File

@@ -0,0 +1,12 @@
# explicitly import sub-packages to register algorithms
import baselines.a2c.a2c
import baselines.acer.acer
import baselines.acktr.acktr
import baselines.deepq.deepq
import baselines.ddpg.ddpg
import baselines.ppo2.ppo2
# not really sure why flake8 complains only about trpo_mpi here...
import baselines.trpo_mpi.trpo_mpi # noqa: F401
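
These imports exist purely for their side effects: importing each sub-package executes the @registry.register decorators added in this branch, so the registry is populated before run.py looks anything up. A standalone illustration of the same idea (not baselines code, hypothetical names):

registry = {}

def register(name):
    def _thunk(fn):
        registry[name] = fn   # runs when the defining module is imported
        return fn
    return _thunk

@register('fancy_algorithm')
def learn():
    pass

assert 'fancy_algorithm' in registry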

View File

@@ -2,13 +2,12 @@ import time
 import functools
 import tensorflow as tf
-from baselines import logger
+from baselines import logger, registry
 from baselines.common import set_global_seeds, explained_variance
 from baselines.common import tf_util
 from baselines.common.policies import build_policy
 from baselines.a2c.utils import Scheduler, find_trainable_variables
 from baselines.a2c.runner import Runner
@@ -114,6 +113,7 @@ class Model(object):
         tf.global_variables_initializer().run(session=sess)

+@registry.register('a2c')
 def learn(
     network,
     env,

View File

@@ -2,7 +2,7 @@ import time
 import functools
 import numpy as np
 import tensorflow as tf
-from baselines import logger
+from baselines import logger, registry
 from baselines.common import set_global_seeds
 from baselines.common.policies import build_policy
@@ -16,6 +16,7 @@ from baselines.a2c.utils import EpisodeStats
 from baselines.a2c.utils import get_by_index, check_shape, avg_norm, gradient_add, q_explained_variance
 from baselines.acer.buffer import Buffer
 from baselines.acer.runner import Runner
+from baselines.acer.defaults import defaults

 # remove last step
 def strip(var, nenvs, nsteps, flat = False):
@@ -270,7 +271,7 @@ class Acer():
             logger.record_tabular(name, float(val))
         logger.dump_tabular()

+@registry.register('acer', defaults=defaults)
 def learn(network, env, seed=None, nsteps=20, total_timesteps=int(80e6), q_coef=0.5, ent_coef=0.01,
           max_grad_norm=10, lr=7e-4, lrschedule='linear', rprop_epsilon=1e-5, rprop_alpha=0.99, gamma=0.99,
           log_interval=100, buffer_size=50000, replay_ratio=4, replay_start=10000, c=10.0,

View File

@@ -1,4 +1,3 @@
-def atari():
-    return dict(
-        lrschedule='constant'
-    )
+defaults = {
+    'atari': dict(lrschedule='constant')
+}

View File

@@ -2,7 +2,7 @@ import os.path as osp
 import time
 import functools
 import tensorflow as tf
-from baselines import logger
+from baselines import logger, registry
 from baselines.common import set_global_seeds, explained_variance
 from baselines.common.policies import build_policy
@@ -11,6 +11,7 @@ from baselines.common.tf_util import get_session, save_variables, load_variables
 from baselines.a2c.runner import Runner
 from baselines.a2c.utils import Scheduler, find_trainable_variables
 from baselines.acktr import kfac
+from baselines.acktr.defaults import defaults

 class Model(object):
@@ -90,6 +91,7 @@ class Model(object):
         self.initial_state = step_model.initial_state
         tf.global_variables_initializer().run(session=sess)

+@registry.register('acktr', defaults=defaults)
 def learn(network, env, seed, total_timesteps=int(40e6), gamma=0.99, log_interval=1, nprocs=32, nsteps=20,
           ent_coef=0.01, vf_coef=0.5, vf_fisher_coef=1.0, lr=0.25, max_grad_norm=0.5,
           kfac_clip=0.001, save_interval=None, lrschedule='linear', load_path=None, is_async=True, **network_kwargs):

View File

@@ -1,5 +1,6 @@
-def mujoco():
-    return dict(
+defaults = {
+    'mujoco' : dict(
         nsteps=2500,
         value_network='copy'
     )
+}

View File

@@ -16,11 +16,13 @@ from baselines.common import set_global_seeds
 from baselines.common.atari_wrappers import make_atari, wrap_deepmind
 from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
 from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
+from baselines.common.vec_env.vec_frame_stack import VecFrameStack
 from baselines.common import retro_wrappers

-def make_vec_env(env_id, env_type, num_env, seed, wrapper_kwargs=None, start_index=0, reward_scale=1.0, gamestate=None):
+def make_vec_env(env_id, env_type, num_env, seed, wrapper_kwargs=None, start_index=0, reward_scale=1.0, gamestate=None, frame_stack_size=1):
     """
-    Create a wrapped, monitored SubprocVecEnv for Atari and MuJoCo.
+    Create a wrapped, monitored SubprocVecEnv
     """
     if wrapper_kwargs is None: wrapper_kwargs = {}
     mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
@@ -38,12 +40,18 @@ def make_vec_env(env_id, env_type, num_env, seed, wrapper_kwargs=None, start_ind
     set_global_seeds(seed)
     if num_env > 1:
-        return SubprocVecEnv([make_thunk(i + start_index) for i in range(num_env)])
+        venv = SubprocVecEnv([make_thunk(i + start_index) for i in range(num_env)])
     else:
-        return DummyVecEnv([make_thunk(start_index)])
+        venv = DummyVecEnv([make_thunk(start_index)])
+    if frame_stack_size > 1:
+        venv = VecFrameStack(venv, frame_stack_size)
+    return venv

-def make_env(env_id, env_type, subrank=0, seed=None, reward_scale=1.0, gamestate=None, wrapper_kwargs=None):
+def make_env(env_id, env_type, subrank=0, seed=None, reward_scale=1.0, gamestate=None, wrapper_kwargs={}):
     mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
     if env_type == 'atari':
         env = make_atari(env_id)

View File

@@ -10,11 +10,11 @@ from baselines.ddpg.noise import AdaptiveParamNoiseSpec, NormalActionNoise, Orns
 import baselines.common.tf_util as U
-from baselines import logger
+from baselines import logger, registry
 import numpy as np

 from mpi4py import MPI

+@registry.register('ddpg')
 def learn(network, env,
           seed=None,
           total_timesteps=None,

View File

@@ -8,7 +8,7 @@ import numpy as np
 import baselines.common.tf_util as U
 from baselines.common.tf_util import load_variables, save_variables
-from baselines import logger
+from baselines import logger, registry
 from baselines.common.schedules import LinearSchedule
 from baselines.common import set_global_seeds
@@ -18,6 +18,7 @@ from baselines.deepq.utils import ObservationInput
 from baselines.common.tf_util import get_session
 from baselines.deepq.models import build_q_func
+from baselines.deepq.defaults import defaults

 class ActWrapper(object):
@@ -92,6 +93,7 @@ def load_act(path):
     return ActWrapper.load_act(path)

+@registry.register('deepq', supports_vecenv=False, defaults=defaults)
 def learn(env,
           network,
           seed=None,

View File

@@ -16,6 +16,8 @@ def atari():
         dueling=True
     )

-def retro():
-    return atari()
+defaults = {
+    'atari': atari(),
+    'retro': atari()
+}

View File

@@ -1,5 +1,5 @@
-def mujoco():
-    return dict(
+defaults = {
+    'mujoco': dict(
         nsteps=2048,
         nminibatches=32,
         lam=0.95,
@@ -10,13 +10,13 @@ def mujoco():
         lr=lambda f: 3e-4 * f,
         cliprange=0.2,
         value_network='copy'
-    )
-def atari():
-    return dict(
+    ),
+    'atari': dict(
         nsteps=128, nminibatches=4,
         lam=0.95, gamma=0.99, noptepochs=4, log_interval=1,
         ent_coef=.01,
         lr=lambda f : f * 2.5e-4,
         cliprange=lambda f : f * 0.1,
     )
+}

View File

@@ -4,7 +4,7 @@ import functools
 import numpy as np
 import os.path as osp
 import tensorflow as tf
-from baselines import logger
+from baselines import logger, registry
 from collections import deque
 from baselines.common import explained_variance, set_global_seeds
 from baselines.common.policies import build_policy
@@ -15,6 +15,7 @@ from baselines.common.mpi_adam_optimizer import MpiAdamOptimizer
 from mpi4py import MPI
 from baselines.common.tf_util import initialize
 from baselines.common.mpi_util import sync_from_root
+from baselines.ppo2.defaults import defaults

 class Model(object):
     """
@@ -218,6 +219,7 @@ def constfn(val):
         return val
     return f

+@registry.register('ppo2', defaults=defaults)
 def learn(*, network, env, total_timesteps, eval_env = None, seed=None, nsteps=2048, ent_coef=0.0, lr=3e-4,
           vf_coef=0.5, max_grad_norm=0.5, gamma=0.99, lam=0.95,
           log_interval=10, nminibatches=4, noptepochs=4, cliprange=0.2,

baselines/registry.py (new file, +39 lines)
View File

@@ -0,0 +1,39 @@
# Registry of algorithms that keeps track of each algorithm's supported environments
# and fine-grained defaults for different kinds of environments (atari, retro, mujoco etc)
#
# Example usage:
#
# from baselines import registry
#
# @registry.register('fancy_algorithm', supports_vecenv=False)
# def learn(env, network):
# return
#
# for algo_name, algo_entry in registry.registry.items():
# if not algo_entry['supports_vecenv']:
# print(f'{algo_name} does not support vecenvs')
# # should print "fancy_algorithm does not support vecenvs" (among other ones)
from baselines import logger
registry = {}
def register(name, supports_vecenv=True, defaults={}):
def get_fn_entrypoint(fn):
import inspect
return '.'.join([inspect.getmodule(fn).__name__, fn.__name__])
def _thunk(learn_fn):
old_entry = registry.get(name)
if old_entry is not None:
logger.warn('Re-registering learn function {} (old entrypoint {}, new entrypoint {}) '.format(
name, get_fn_entrypoint(old_entry['fn']), get_fn_entrypoint(learn_fn)))
registry[name] = dict(
fn = learn_fn,
supports_vecenv=supports_vecenv,
defaults=defaults,
)
return learn_fn
return _thunk
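
A minimal usage sketch of this registry (hypothetical algorithm name, defaults, and values, not part of this diff), following the module comment above: register a learn function with per-environment-type defaults, then read the entry back the way run.py does.

from baselines import registry

@registry.register('my_algo', supports_vecenv=False, defaults={'atari': dict(lr=1e-4)})
def learn(env, network, lr=7e-4):
    pass

entry = registry.registry['my_algo']        # dict with 'fn', 'supports_vecenv', 'defaults'
assert entry['fn'] is learn
assert entry['supports_vecenv'] is False
print(entry['defaults'].get('atari', {}))   # -> {'lr': 0.0001}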

View File

@@ -3,16 +3,12 @@ import multiprocessing
 import os.path as osp
 import gym
 from collections import defaultdict
-import tensorflow as tf
 import numpy as np
-from baselines.common.vec_env.vec_frame_stack import VecFrameStack
-from baselines.common.cmd_util import common_arg_parser, parse_unknown_args, make_vec_env, make_env
-from baselines.common.tf_util import get_session
-from baselines import logger
-from importlib import import_module
 from baselines.common.vec_env.vec_normalize import VecNormalize
+from baselines.common.cmd_util import common_arg_parser, parse_unknown_args, make_vec_env, make_env
+from baselines import logger
+from baselines.registry import registry

 try:
     from mpi4py import MPI
@@ -89,28 +85,20 @@ def build_env(args):
     seed = args.seed

     env_type, env_id = get_env_type(args.env)

+    assert alg in registry, 'Unknown algorithm {}'.format(alg)
+
     if env_type in {'atari', 'retro'}:
-        if alg == 'deepq':
-            env = make_env(env_id, env_type, seed=seed, wrapper_kwargs={'frame_stack': True})
-        elif alg == 'trpo_mpi':
-            env = make_env(env_id, env_type, seed=seed)
-        else:
-            frame_stack_size = 4
-            env = make_vec_env(env_id, env_type, nenv, seed, gamestate=args.gamestate, reward_scale=args.reward_scale)
-            env = VecFrameStack(env, frame_stack_size)
+        frame_stack_size = 4
     else:
-        config = tf.ConfigProto(allow_soft_placement=True,
-                                intra_op_parallelism_threads=1,
-                                inter_op_parallelism_threads=1)
-        config.gpu_options.allow_growth = True
-        get_session(config=config)
-        env = make_vec_env(env_id, env_type, args.num_env or 1, seed, reward_scale=args.reward_scale)
+        frame_stack_size = 1

-    if env_type == 'mujoco':
+    if registry[alg]['supports_vecenv']:
+        env = make_vec_env(env_id, env_type, nenv, seed, gamestate=args.gamestate, reward_scale=args.reward_scale, frame_stack_size=frame_stack_size)
+    else:
+        env = make_env(env_id, env_type, seed=seed, wrapper_kwargs={'frame_stack': frame_stack_size > 1})
+
+    if env_type == 'mujoco' and registry[alg]['supports_vecenv']:
         env = VecNormalize(env)

     return env
@@ -137,29 +125,26 @@ def get_default_network(env_type):
     return 'mlp'

 def get_alg_module(alg, submodule=None):
-    submodule = submodule or alg
-    try:
-        # first try to import the alg module from baselines
-        alg_module = import_module('.'.join(['baselines', alg, submodule]))
-    except ImportError:
-        # then from rl_algs
-        alg_module = import_module('.'.join(['rl_' + 'algs', alg, submodule]))
-    return alg_module
+    import inspect
+    entry = registry.get(alg)
+    assert entry is not None, 'Unregistered algorithm {}'.format(alg)
+    module = inspect.getmodule(entry['fn']).__name__
+    if submodule is not None:
+        module = '.'.join([module, submodule])
+    return module

 def get_learn_function(alg):
-    return get_alg_module(alg).learn
+    entry = registry.get(alg)
+    assert entry is not None, 'Unregistered algorithm {}'.format(alg)
+    return entry['fn']

 def get_learn_function_defaults(alg, env_type):
-    try:
-        alg_defaults = get_alg_module(alg, 'defaults')
-        kwargs = getattr(alg_defaults, env_type)()
-    except (ImportError, AttributeError):
-        kwargs = {}
-    return kwargs
+    entry = registry.get(alg)
+    assert entry is not None, 'Unregistered algorithm {}'.format(alg)
+    return entry['defaults'].get(env_type, {})
@@ -193,6 +178,7 @@ def main():
         rank = MPI.COMM_WORLD.Get_rank()

     model, env = train(args, extra_args)
     env.close()

     if args.save_path is not None and rank == 0:

View File

@@ -28,3 +28,8 @@ def mujoco():
         vf_stepsize=1e-3,
         normalize_observations=True,
     )
+
+defaults = {
+    'atari': atari(),
+    'mujoco': mujoco(),
+}

View File

@@ -1,5 +1,5 @@
 from baselines.common import explained_variance, zipsame, dataset
-from baselines import logger
+from baselines import logger, registry
 import baselines.common.tf_util as U
 import tensorflow as tf, numpy as np
 import time
@@ -13,6 +13,8 @@ from baselines.common.input import observation_placeholder
 from baselines.common.policies import build_policy
 from contextlib import contextmanager
+from baselines.trpo_mpi.defaults import defaults

 def traj_segment_generator(pi, env, horizon, stochastic):
     # Initialize state variables
     t = 0
@@ -82,6 +84,7 @@ def add_vtarg_and_adv(seg, gamma, lam):
         gaelam[t] = lastgaelam = delta + gamma * lam * nonterminal * lastgaelam
     seg["tdlamret"] = seg["adv"] + seg["vpred"]

+@registry.register('trpo_mpi', supports_vecenv=False, defaults=defaults)
 def learn(*,
         network,
         env,
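
With these registrations in place, run.py no longer resolves algorithms by importing modules by name via importlib; everything goes through the registry. A rough sketch of the equivalent lookup (illustrative values; assumes the sub-package imports shown in the first file of this diff have already run, e.g. by importing baselines):

from baselines.registry import registry

alg, env_type = 'ppo2', 'atari'                      # example inputs, as passed on the command line
entry = registry[alg]                                # KeyError here means the algorithm was never registered
learn = entry['fn']                                  # the decorated learn() function
learn_kwargs = entry['defaults'].get(env_type, {})   # e.g. nsteps=128, nminibatches=4, ... for atari
vectorized = entry['supports_vecenv']                # decides between make_vec_env and make_env in build_env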