Compare commits

...

21 Commits

Author SHA1 Message Date
Peter Zhokhov
c7a0c2781a autopep8 and import fix 2018-10-23 13:42:44 -07:00
Peter Zhokhov
06c2fd2a3c typo in registry.py 2018-10-23 11:16:20 -07:00
Peter Zhokhov
a52dcae856 added comments on registry usage, fixed typos in deepq and trpo_mpi registration 2018-10-23 11:14:48 -07:00
Peter Zhokhov
a8c2e643dc import error in run.py 2018-10-23 10:11:48 -07:00
Peter Zhokhov
5ca31a7c25 merged latest master 2018-10-23 10:07:51 -07:00
Peter Zhokhov
35dcb6fd74 merged internal 2018-10-22 19:22:46 -07:00
Peter Zhokhov
c1c7c469a1 fix syntax 2018-10-22 19:19:54 -07:00
Peter Zhokhov
b4869bd271 use algorithm registry - staging for internal benchmarks 2018-10-22 19:13:10 -07:00
Peter Zhokhov
29cfb4a69c Merge branch 'internal' of github.com:openai/baselines into peterz_learn_registration 2018-10-22 19:08:27 -07:00
Peter Zhokhov
bd7c479e04 merge master 2018-10-22 19:07:46 -07:00
Peter Zhokhov
3ddf69c4b5 defaults are handled through registry 2018-10-22 18:10:10 -07:00
Peter Zhokhov
bfdc552521 moving things around 2018-10-22 17:45:55 -07:00
Peter Zhokhov
bcb4d4f795 moved imports back to run 2018-10-22 17:15:24 -07:00
Peter Zhokhov
0c9b236475 using registry of algorithms 2018-10-22 17:01:49 -07:00
Peter Zhokhov
01884bb0eb wrap retro envs correctly for other (non-deepq) algorithms 2018-10-22 14:21:26 -07:00
Peter Zhokhov
ade2d61be7 Merge branch 'master' of github.com:openai/games into peterz_track_baselines_branch 2018-10-19 17:27:57 -07:00
Peter Zhokhov
f6ef52a9df Merge branch 'master' of github.com:openai/baselines into internal 2018-10-19 09:52:23 -07:00
Peter Zhokhov
8964d5ad45 flake8 and numpy.random.random_integers deprecation warning 2018-10-16 14:58:23 -07:00
Peter Zhokhov
8624bc629c eval_done[d]==True -> eval_done[d] 2018-10-15 18:31:55 -07:00
Peter Zhokhov
7b33af0395 B -> nenvs for consistency with other algos, small cleanups 2018-10-15 18:29:48 -07:00
Peter Zhokhov
4bca9158a1 sync internal changes. Make ddpg work with vecenvs 2018-10-15 17:40:24 -07:00
16 changed files with 130 additions and 68 deletions

View File

@@ -0,0 +1,12 @@
# explicitly import sub-packages to register algorithms
import baselines.a2c.a2c
import baselines.acer.acer
import baselines.acktr.acktr
import baselines.deepq.deepq
import baselines.ddpg.ddpg
import baselines.ppo2.ppo2
# not really sure why flake8 complains only about trpo_mpi here...
import baselines.trpo_mpi.trpo_mpi # noqa: F401

View File

@@ -2,13 +2,12 @@ import time
import functools
import tensorflow as tf
from baselines import logger
from baselines import logger, registry
from baselines.common import set_global_seeds, explained_variance
from baselines.common import tf_util
from baselines.common.policies import build_policy
from baselines.a2c.utils import Scheduler, find_trainable_variables
from baselines.a2c.runner import Runner
@@ -114,6 +113,7 @@ class Model(object):
tf.global_variables_initializer().run(session=sess)
@registry.register('a2c')
def learn(
network,
env,

View File

@@ -2,7 +2,7 @@ import time
import functools
import numpy as np
import tensorflow as tf
from baselines import logger
from baselines import logger, registry
from baselines.common import set_global_seeds
from baselines.common.policies import build_policy
@@ -16,6 +16,7 @@ from baselines.a2c.utils import EpisodeStats
from baselines.a2c.utils import get_by_index, check_shape, avg_norm, gradient_add, q_explained_variance
from baselines.acer.buffer import Buffer
from baselines.acer.runner import Runner
from baselines.acer.defaults import defaults
# remove last step
def strip(var, nenvs, nsteps, flat = False):
@@ -270,7 +271,7 @@ class Acer():
logger.record_tabular(name, float(val))
logger.dump_tabular()
@registry.register('acer', defaults=defaults)
def learn(network, env, seed=None, nsteps=20, total_timesteps=int(80e6), q_coef=0.5, ent_coef=0.01,
max_grad_norm=10, lr=7e-4, lrschedule='linear', rprop_epsilon=1e-5, rprop_alpha=0.99, gamma=0.99,
log_interval=100, buffer_size=50000, replay_ratio=4, replay_start=10000, c=10.0,

View File

@@ -1,4 +1,3 @@
def atari():
return dict(
lrschedule='constant'
)
defaults = {
'atari': dict(lrschedule='constant')
}

View File

@@ -2,7 +2,7 @@ import os.path as osp
import time
import functools
import tensorflow as tf
from baselines import logger
from baselines import logger, registry
from baselines.common import set_global_seeds, explained_variance
from baselines.common.policies import build_policy
@@ -11,6 +11,7 @@ from baselines.common.tf_util import get_session, save_variables, load_variables
from baselines.a2c.runner import Runner
from baselines.a2c.utils import Scheduler, find_trainable_variables
from baselines.acktr import kfac
from baselines.acktr.defaults import defaults
class Model(object):
@@ -90,6 +91,7 @@ class Model(object):
self.initial_state = step_model.initial_state
tf.global_variables_initializer().run(session=sess)
@registry.register('acktr', defaults=defaults)
def learn(network, env, seed, total_timesteps=int(40e6), gamma=0.99, log_interval=1, nprocs=32, nsteps=20,
ent_coef=0.01, vf_coef=0.5, vf_fisher_coef=1.0, lr=0.25, max_grad_norm=0.5,
kfac_clip=0.001, save_interval=None, lrschedule='linear', load_path=None, is_async=True, **network_kwargs):

View File

@@ -1,5 +1,6 @@
def mujoco():
return dict(
defaults = {
'mujoco' : dict(
nsteps=2500,
value_network='copy'
)
}

View File

@@ -16,11 +16,13 @@ from baselines.common import set_global_seeds
from baselines.common.atari_wrappers import make_atari, wrap_deepmind
from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
from baselines.common.vec_env.vec_frame_stack import VecFrameStack
from baselines.common import retro_wrappers
def make_vec_env(env_id, env_type, num_env, seed, wrapper_kwargs=None, start_index=0, reward_scale=1.0, gamestate=None):
def make_vec_env(env_id, env_type, num_env, seed, wrapper_kwargs=None, start_index=0, reward_scale=1.0, gamestate=None, frame_stack_size=1):
"""
Create a wrapped, monitored SubprocVecEnv for Atari and MuJoCo.
Create a wrapped, monitored SubprocVecEnv
"""
if wrapper_kwargs is None: wrapper_kwargs = {}
mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
@@ -38,12 +40,18 @@ def make_vec_env(env_id, env_type, num_env, seed, wrapper_kwargs=None, start_ind
set_global_seeds(seed)
if num_env > 1:
return SubprocVecEnv([make_thunk(i + start_index) for i in range(num_env)])
venv = SubprocVecEnv([make_thunk(i + start_index) for i in range(num_env)])
else:
return DummyVecEnv([make_thunk(start_index)])
venv = DummyVecEnv([make_thunk(start_index)])
if frame_stack_size > 1:
venv = VecFrameStack(venv, frame_stack_size)
return venv
def make_env(env_id, env_type, subrank=0, seed=None, reward_scale=1.0, gamestate=None, wrapper_kwargs=None):
def make_env(env_id, env_type, subrank=0, seed=None, reward_scale=1.0, gamestate=None, wrapper_kwargs={}):
mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
if env_type == 'atari':
env = make_atari(env_id)

View File

@@ -10,11 +10,11 @@ from baselines.ddpg.noise import AdaptiveParamNoiseSpec, NormalActionNoise, Orns
import baselines.common.tf_util as U
from baselines import logger
from baselines import logger, registry
import numpy as np
from mpi4py import MPI
@registry.register('ddpg')
def learn(network, env,
seed=None,
total_timesteps=None,

View File

@@ -8,7 +8,7 @@ import numpy as np
import baselines.common.tf_util as U
from baselines.common.tf_util import load_variables, save_variables
from baselines import logger
from baselines import logger, registry
from baselines.common.schedules import LinearSchedule
from baselines.common import set_global_seeds
@@ -18,6 +18,7 @@ from baselines.deepq.utils import ObservationInput
from baselines.common.tf_util import get_session
from baselines.deepq.models import build_q_func
from baselines.deepq.defaults import defaults
class ActWrapper(object):
@@ -92,6 +93,7 @@ def load_act(path):
return ActWrapper.load_act(path)
@registry.register('deepq', supports_vecenv=False, defaults=defaults)
def learn(env,
network,
seed=None,

View File

@@ -16,6 +16,8 @@ def atari():
dueling=True
)
def retro():
return atari()
defaults = {
'atari': atari(),
'retro': atari()
}

View File

@@ -1,5 +1,5 @@
def mujoco():
return dict(
defaults = {
'mujoco': dict(
nsteps=2048,
nminibatches=32,
lam=0.95,
@@ -10,13 +10,13 @@ def mujoco():
lr=lambda f: 3e-4 * f,
cliprange=0.2,
value_network='copy'
)
),
def atari():
return dict(
'atari': dict(
nsteps=128, nminibatches=4,
lam=0.95, gamma=0.99, noptepochs=4, log_interval=1,
ent_coef=.01,
lr=lambda f : f * 2.5e-4,
cliprange=lambda f : f * 0.1,
)
}

View File

@@ -4,7 +4,7 @@ import functools
import numpy as np
import os.path as osp
import tensorflow as tf
from baselines import logger
from baselines import logger, registry
from collections import deque
from baselines.common import explained_variance, set_global_seeds
from baselines.common.policies import build_policy
@@ -15,6 +15,7 @@ from baselines.common.mpi_adam_optimizer import MpiAdamOptimizer
from mpi4py import MPI
from baselines.common.tf_util import initialize
from baselines.common.mpi_util import sync_from_root
from baselines.ppo2.defaults import defaults
class Model(object):
"""
@@ -218,6 +219,7 @@ def constfn(val):
return val
return f
@registry.register('ppo2', defaults=defaults)
def learn(*, network, env, total_timesteps, eval_env = None, seed=None, nsteps=2048, ent_coef=0.0, lr=3e-4,
vf_coef=0.5, max_grad_norm=0.5, gamma=0.99, lam=0.95,
log_interval=10, nminibatches=4, noptepochs=4, cliprange=0.2,

39
baselines/registry.py Normal file
View File

@@ -0,0 +1,39 @@
# Registry of algorithms that keeps track of each algorithm's supported environments
# and fine-grained defaults for different kinds of environments (atari, retro, mujoco etc)
#
# Example usage:
#
# from baselines import registry
#
# @registry.register('fancy_algorithm', supports_vecenv=False)
# def learn(env, network):
# return
#
# for algo_name, algo_entry in registry.registry.items():
# if not algo_entry['supports_vecenv']:
# print(f'{algo_name} does not support vecenvs')
# # should print "fancy_algorithm does not support vecenvs" (among other ones)
from baselines import logger
# Maps algorithm name -> dict(fn=<learn function>, supports_vecenv=<bool>, defaults=<dict>).
registry = {}


def register(name, supports_vecenv=True, defaults=None):
    """
    Build a decorator that records a learn function in the module-level ``registry``.

    Parameters
    ----------
    name: str
        Key under which the decorated function is stored in ``registry``.
    supports_vecenv: bool
        Stored as-is in the registry entry under 'supports_vecenv'.
    defaults: dict or None
        Stored in the registry entry under 'defaults'. When omitted, every
        registration gets its own fresh empty dict, so mutating one entry's
        defaults cannot leak into another entry.

    Returns
    -------
    A decorator that stores the function it receives in ``registry[name]`` and
    returns that function unchanged. Registering an already-used name logs a
    warning and overwrites the previous entry.
    """
    # A literal `{}` default would be created once and shared by every
    # registration that does not pass `defaults` explicitly.
    if defaults is None:
        defaults = {}

    def get_fn_entrypoint(fn):
        import inspect
        # inspect.getmodule may return None (e.g. for dynamically created callables);
        # this helper only feeds a warning message, so it must not raise in that case.
        module = inspect.getmodule(fn)
        module_name = module.__name__ if module is not None else '<unknown>'
        return '.'.join([module_name, fn.__name__])

    def _thunk(learn_fn):
        old_entry = registry.get(name)
        if old_entry is not None:
            logger.warn('Re-registering learn function {} (old entrypoint {}, new entrypoint {}) '.format(
                name, get_fn_entrypoint(old_entry['fn']), get_fn_entrypoint(learn_fn)))

        registry[name] = dict(
            fn=learn_fn,
            supports_vecenv=supports_vecenv,
            defaults=defaults,
        )
        return learn_fn

    return _thunk

View File

@@ -3,16 +3,12 @@ import multiprocessing
import os.path as osp
import gym
from collections import defaultdict
import tensorflow as tf
import numpy as np
from baselines.common.vec_env.vec_frame_stack import VecFrameStack
from baselines.common.cmd_util import common_arg_parser, parse_unknown_args, make_vec_env, make_env
from baselines.common.tf_util import get_session
from baselines import logger
from importlib import import_module
from baselines.common.vec_env.vec_normalize import VecNormalize
from baselines.common.cmd_util import common_arg_parser, parse_unknown_args, make_vec_env, make_env
from baselines import logger
from baselines.registry import registry
try:
from mpi4py import MPI
@@ -89,28 +85,20 @@ def build_env(args):
seed = args.seed
env_type, env_id = get_env_type(args.env)
assert alg in registry, 'Unknown algorithm {}'.format(alg)
if env_type in {'atari', 'retro'}:
if alg == 'deepq':
env = make_env(env_id, env_type, seed=seed, wrapper_kwargs={'frame_stack': True})
elif alg == 'trpo_mpi':
env = make_env(env_id, env_type, seed=seed)
else:
frame_stack_size = 4
env = make_vec_env(env_id, env_type, nenv, seed, gamestate=args.gamestate, reward_scale=args.reward_scale)
env = VecFrameStack(env, frame_stack_size)
frame_stack_size = 4
else:
config = tf.ConfigProto(allow_soft_placement=True,
intra_op_parallelism_threads=1,
inter_op_parallelism_threads=1)
config.gpu_options.allow_growth = True
get_session(config=config)
frame_stack_size = 1
env = make_vec_env(env_id, env_type, args.num_env or 1, seed, reward_scale=args.reward_scale)
if registry[alg]['supports_vecenv']:
env = make_vec_env(env_id, env_type, nenv, seed, gamestate=args.gamestate, reward_scale=args.reward_scale, frame_stack_size=frame_stack_size)
else:
env = make_env(env_id, env_type, seed=seed, wrapper_kwargs={'frame_stack': frame_stack_size > 1})
if env_type == 'mujoco':
env = VecNormalize(env)
if env_type == 'mujoco' and registry[alg]['supports_vecenv']:
env = VecNormalize(env)
return env
@@ -137,29 +125,26 @@ def get_default_network(env_type):
return 'mlp'
def get_alg_module(alg, submodule=None):
submodule = submodule or alg
try:
# first try to import the alg module from baselines
alg_module = import_module('.'.join(['baselines', alg, submodule]))
except ImportError:
# then from rl_algs
alg_module = import_module('.'.join(['rl_' + 'algs', alg, submodule]))
import inspect
entry = registry.get(alg)
assert entry is not None, 'Unregistered algorithm {}'.format(alg)
module = inspect.getmodule(entry['fn']).__name__
if submodule is not None:
module = '.'.join([module, submodule])
return module
return alg_module
def get_learn_function(alg):
return get_alg_module(alg).learn
entry = registry.get(alg)
assert entry is not None, 'Unregistered algorithm {}'.format(alg)
return entry['fn']
def get_learn_function_defaults(alg, env_type):
try:
alg_defaults = get_alg_module(alg, 'defaults')
kwargs = getattr(alg_defaults, env_type)()
except (ImportError, AttributeError):
kwargs = {}
return kwargs
entry = registry.get(alg)
assert entry is not None, 'Unregistered algorithm {}'.format(alg)
return entry['defaults'].get(env_type, {})
def parse_cmdline_kwargs(args):
@@ -193,6 +178,7 @@ def main():
rank = MPI.COMM_WORLD.Get_rank()
model, env = train(args, extra_args)
env.close()
if args.save_path is not None and rank == 0:

View File

@@ -28,3 +28,8 @@ def mujoco():
vf_stepsize=1e-3,
normalize_observations=True,
)
defaults = {
'atari': atari(),
'mujoco': mujoco(),
}

View File

@@ -1,5 +1,5 @@
from baselines.common import explained_variance, zipsame, dataset
from baselines import logger
from baselines import logger, registry
import baselines.common.tf_util as U
import tensorflow as tf, numpy as np
import time
@@ -13,6 +13,8 @@ from baselines.common.input import observation_placeholder
from baselines.common.policies import build_policy
from contextlib import contextmanager
from baselines.trpo_mpi.defaults import defaults
def traj_segment_generator(pi, env, horizon, stochastic):
# Initialize state variables
t = 0
@@ -82,6 +84,7 @@ def add_vtarg_and_adv(seg, gamma, lam):
gaelam[t] = lastgaelam = delta + gamma * lam * nonterminal * lastgaelam
seg["tdlamret"] = seg["adv"] + seg["vpred"]
@registry.register('trpo_mpi', supports_vecenv=False, defaults=defaults)
def learn(*,
network,
env,