diff --git a/baselines/__init__.py b/baselines/__init__.py index e69de29..f4a1f1f 100644 --- a/baselines/__init__.py +++ b/baselines/__init__.py @@ -0,0 +1,12 @@ +# explicitly import sub-packages to register algorithms + +import baselines.a2c.a2c +import baselines.acer.acer +import baselines.acktr.acktr +import baselines.deepq.deepq +import baselines.ddpg.ddpg +import baselines.ppo2.ppo2 + +# not really sure why flake8 complains only about trpo_mpi here... +import baselines.trpo_mpi.trpo_mpi # noqa: F401 + diff --git a/baselines/a2c/a2c.py b/baselines/a2c/a2c.py index b0fccfb..46591c5 100644 --- a/baselines/a2c/a2c.py +++ b/baselines/a2c/a2c.py @@ -2,13 +2,12 @@ import time import functools import tensorflow as tf -from baselines import logger +from baselines import logger, registry from baselines.common import set_global_seeds, explained_variance from baselines.common import tf_util from baselines.common.policies import build_policy - from baselines.a2c.utils import Scheduler, find_trainable_variables from baselines.a2c.runner import Runner @@ -114,6 +113,7 @@ class Model(object): tf.global_variables_initializer().run(session=sess) +@registry.register('a2c') def learn( network, env, diff --git a/baselines/acer/acer.py b/baselines/acer/acer.py index 4e2e00f..7d0be08 100644 --- a/baselines/acer/acer.py +++ b/baselines/acer/acer.py @@ -2,11 +2,12 @@ import time import functools import numpy as np import tensorflow as tf -from baselines import logger +from baselines import logger, registry from baselines.common import set_global_seeds from baselines.common.policies import build_policy from baselines.common.tf_util import get_session, save_variables +from baselines.common.vec_env.vec_frame_stack import VecFrameStack from baselines.a2c.utils import batch_to_seq, seq_to_batch from baselines.a2c.utils import cat_entropy_softmax @@ -15,6 +16,7 @@ from baselines.a2c.utils import EpisodeStats from baselines.a2c.utils import get_by_index, check_shape, avg_norm, gradient_add, q_explained_variance from baselines.acer.buffer import Buffer from baselines.acer.runner import Runner +from baselines.acer.defaults import defaults # remove last step def strip(var, nenvs, nsteps, flat = False): @@ -55,8 +57,7 @@ def q_retrace(R, D, q_i, v, rho_i, nenvs, nsteps, gamma): # return tf.minimum(1 + eps_clip, tf.maximum(1 - eps_clip, ratio)) class Model(object): - def __init__(self, policy, ob_space, ac_space, nenvs, nsteps, nstack, num_procs, - ent_coef, q_coef, gamma, max_grad_norm, lr, + def __init__(self, policy, ob_space, ac_space, nenvs, nsteps, ent_coef, q_coef, gamma, max_grad_norm, lr, rprop_alpha, rprop_epsilon, total_timesteps, lrschedule, c, trust_region, alpha, delta): @@ -71,8 +72,8 @@ class Model(object): LR = tf.placeholder(tf.float32, []) eps = 1e-6 - step_ob_placeholder = tf.placeholder(dtype=ob_space.dtype, shape=(nenvs,) + ob_space.shape[:-1] + (ob_space.shape[-1] * nstack,)) - train_ob_placeholder = tf.placeholder(dtype=ob_space.dtype, shape=(nenvs*(nsteps+1),) + ob_space.shape[:-1] + (ob_space.shape[-1] * nstack,)) + step_ob_placeholder = tf.placeholder(dtype=ob_space.dtype, shape=(nenvs,) + ob_space.shape) + train_ob_placeholder = tf.placeholder(dtype=ob_space.dtype, shape=(nenvs*(nsteps+1),) + ob_space.shape) with tf.variable_scope('acer_model', reuse=tf.AUTO_REUSE): step_model = policy(observ_placeholder=step_ob_placeholder, sess=sess) @@ -247,6 +248,7 @@ class Acer(): # get obs, actions, rewards, mus, dones from buffer. 
obs, actions, rewards, mus, dones, masks = buffer.get() + # reshape stuff correctly obs = obs.reshape(runner.batch_ob_shape) actions = actions.reshape([runner.nbatch]) @@ -269,8 +271,8 @@ class Acer(): logger.record_tabular(name, float(val)) logger.dump_tabular() - -def learn(network, env, seed=None, nsteps=20, nstack=4, total_timesteps=int(80e6), q_coef=0.5, ent_coef=0.01, +@registry.register('acer', defaults=defaults) +def learn(network, env, seed=None, nsteps=20, total_timesteps=int(80e6), q_coef=0.5, ent_coef=0.01, max_grad_norm=10, lr=7e-4, lrschedule='linear', rprop_epsilon=1e-5, rprop_alpha=0.99, gamma=0.99, log_interval=100, buffer_size=50000, replay_ratio=4, replay_start=10000, c=10.0, trust_region=True, alpha=0.99, delta=1, load_path=None, **network_kwargs): @@ -342,21 +344,24 @@ def learn(network, env, seed=None, nsteps=20, nstack=4, total_timesteps=int(80e6 print("Running Acer Simple") print(locals()) set_global_seeds(seed) - policy = build_policy(env, network, estimate_q=True, **network_kwargs) + if not isinstance(env, VecFrameStack): + env = VecFrameStack(env, 1) + policy = build_policy(env, network, estimate_q=True, **network_kwargs) nenvs = env.num_envs ob_space = env.observation_space ac_space = env.action_space - num_procs = len(env.remotes) if hasattr(env, 'remotes') else 1# HACK - model = Model(policy=policy, ob_space=ob_space, ac_space=ac_space, nenvs=nenvs, nsteps=nsteps, nstack=nstack, - num_procs=num_procs, ent_coef=ent_coef, q_coef=q_coef, gamma=gamma, + + nstack = env.nstack + model = Model(policy=policy, ob_space=ob_space, ac_space=ac_space, nenvs=nenvs, nsteps=nsteps, + ent_coef=ent_coef, q_coef=q_coef, gamma=gamma, max_grad_norm=max_grad_norm, lr=lr, rprop_alpha=rprop_alpha, rprop_epsilon=rprop_epsilon, total_timesteps=total_timesteps, lrschedule=lrschedule, c=c, trust_region=trust_region, alpha=alpha, delta=delta) - runner = Runner(env=env, model=model, nsteps=nsteps, nstack=nstack) + runner = Runner(env=env, model=model, nsteps=nsteps) if replay_ratio > 0: - buffer = Buffer(env=env, nsteps=nsteps, nstack=nstack, size=buffer_size) + buffer = Buffer(env=env, nsteps=nsteps, size=buffer_size) else: buffer = None nbatch = nenvs*nsteps diff --git a/baselines/acer/buffer.py b/baselines/acer/buffer.py index 2dcfa10..000592c 100644 --- a/baselines/acer/buffer.py +++ b/baselines/acer/buffer.py @@ -2,11 +2,16 @@ import numpy as np class Buffer(object): # gets obs, actions, rewards, mu's, (states, masks), dones - def __init__(self, env, nsteps, nstack, size=50000): + def __init__(self, env, nsteps, size=50000): self.nenv = env.num_envs self.nsteps = nsteps - self.nh, self.nw, self.nc = env.observation_space.shape - self.nstack = nstack + # self.nh, self.nw, self.nc = env.observation_space.shape + self.obs_shape = env.observation_space.shape + self.obs_dtype = env.observation_space.dtype + self.ac_dtype = env.action_space.dtype + self.nc = self.obs_shape[-1] + self.nstack = env.nstack + self.nc //= self.nstack self.nbatch = self.nenv * self.nsteps self.size = size // (self.nsteps) # Each loc contains nenv * nsteps frames, thus total buffer is nenv * size frames @@ -33,22 +38,11 @@ class Buffer(object): # Generate stacked frames def decode(self, enc_obs, dones): # enc_obs has shape [nenvs, nsteps + nstack, nh, nw, nc] - # dones has shape [nenvs, nsteps, nh, nw, nc] + # dones has shape [nenvs, nsteps] # returns stacked obs of shape [nenv, (nsteps + 1), nh, nw, nstack*nc] - nstack, nenv, nsteps, nh, nw, nc = self.nstack, self.nenv, self.nsteps, self.nh, self.nw, self.nc - y 
= np.empty([nsteps + nstack - 1, nenv, 1, 1, 1], dtype=np.float32) - obs = np.zeros([nstack, nsteps + nstack, nenv, nh, nw, nc], dtype=np.uint8) - x = np.reshape(enc_obs, [nenv, nsteps + nstack, nh, nw, nc]).swapaxes(1, - 0) # [nsteps + nstack, nenv, nh, nw, nc] - y[3:] = np.reshape(1.0 - dones, [nenv, nsteps, 1, 1, 1]).swapaxes(1, 0) # keep - y[:3] = 1.0 - # y = np.reshape(1 - dones, [nenvs, nsteps, 1, 1, 1]) - for i in range(nstack): - obs[-(i + 1), i:] = x - # obs[:,i:,:,:,-(i+1),:] = x - x = x[:-1] * y - y = y[1:] - return np.reshape(obs[:, 3:].transpose((2, 1, 3, 4, 0, 5)), [nenv, (nsteps + 1), nh, nw, nstack * nc]) + + return _stack_obs(enc_obs, dones, + nsteps=self.nsteps) def put(self, enc_obs, actions, rewards, mus, dones, masks): # enc_obs [nenv, (nsteps + nstack), nh, nw, nc] @@ -56,8 +50,8 @@ class Buffer(object): # mus [nenv, nsteps, nact] if self.enc_obs is None: - self.enc_obs = np.empty([self.size] + list(enc_obs.shape), dtype=np.uint8) - self.actions = np.empty([self.size] + list(actions.shape), dtype=np.int32) + self.enc_obs = np.empty([self.size] + list(enc_obs.shape), dtype=self.obs_dtype) + self.actions = np.empty([self.size] + list(actions.shape), dtype=self.ac_dtype) self.rewards = np.empty([self.size] + list(rewards.shape), dtype=np.float32) self.mus = np.empty([self.size] + list(mus.shape), dtype=np.float32) self.dones = np.empty([self.size] + list(dones.shape), dtype=np.bool) @@ -101,3 +95,62 @@ class Buffer(object): mus = take(self.mus) masks = take(self.masks) return obs, actions, rewards, mus, dones, masks + + + +def _stack_obs_ref(enc_obs, dones, nsteps): + nenv = enc_obs.shape[0] + nstack = enc_obs.shape[1] - nsteps + nh, nw, nc = enc_obs.shape[2:] + obs_dtype = enc_obs.dtype + obs_shape = (nh, nw, nc*nstack) + + mask = np.empty([nsteps + nstack - 1, nenv, 1, 1, 1], dtype=np.float32) + obs = np.zeros([nstack, nsteps + nstack, nenv, nh, nw, nc], dtype=obs_dtype) + x = np.reshape(enc_obs, [nenv, nsteps + nstack, nh, nw, nc]).swapaxes(1, 0) # [nsteps + nstack, nenv, nh, nw, nc] + + mask[nstack-1:] = np.reshape(1.0 - dones, [nenv, nsteps, 1, 1, 1]).swapaxes(1, 0) # keep + mask[:nstack-1] = 1.0 + + # y = np.reshape(1 - dones, [nenvs, nsteps, 1, 1, 1]) + for i in range(nstack): + obs[-(i + 1), i:] = x + # obs[:,i:,:,:,-(i+1),:] = x + x = x[:-1] * mask + mask = mask[1:] + + return np.reshape(obs[:, (nstack-1):].transpose((2, 1, 3, 4, 0, 5)), (nenv, (nsteps + 1)) + obs_shape) + +def _stack_obs(enc_obs, dones, nsteps): + nenv = enc_obs.shape[0] + nstack = enc_obs.shape[1] - nsteps + nc = enc_obs.shape[-1] + + obs_ = np.zeros((nenv, nsteps + 1) + enc_obs.shape[2:-1] + (enc_obs.shape[-1] * nstack, ), dtype=enc_obs.dtype) + mask = np.ones((nenv, nsteps+1), dtype=enc_obs.dtype) + mask[:, 1:] = 1.0 - dones + mask = mask.reshape(mask.shape + tuple(np.ones(len(enc_obs.shape)-2, dtype=np.uint8))) + + for i in range(nstack-1, -1, -1): + obs_[..., i * nc : (i + 1) * nc] = enc_obs[:, i : i + nsteps + 1, :] + if i < nstack-1: + obs_[..., i * nc : (i + 1) * nc] *= mask + mask[:, 1:, ...] *= mask[:, :-1, ...] 
+ + return obs_ + +def test_stack_obs(): + nstack = 7 + nenv = 1 + nsteps = 5 + + obs_shape = (2, 3, nstack) + + enc_obs_shape = (nenv, nsteps + nstack) + obs_shape[:-1] + (1,) + enc_obs = np.random.random(enc_obs_shape) + dones = np.random.randint(low=0, high=2, size=(nenv, nsteps)) + + stacked_obs_ref = _stack_obs_ref(enc_obs, dones, nsteps=nsteps) + stacked_obs_test = _stack_obs(enc_obs, dones, nsteps=nsteps) + + np.testing.assert_allclose(stacked_obs_ref, stacked_obs_test) diff --git a/baselines/acer/defaults.py b/baselines/acer/defaults.py index 0334bae..e54220b 100644 --- a/baselines/acer/defaults.py +++ b/baselines/acer/defaults.py @@ -1,4 +1,3 @@ -def atari(): - return dict( - lrschedule='constant' - ) +defaults = { + 'atari': dict(lrschedule='constant') +} diff --git a/baselines/acer/runner.py b/baselines/acer/runner.py index 6bc1b4c..afd19ce 100644 --- a/baselines/acer/runner.py +++ b/baselines/acer/runner.py @@ -1,30 +1,31 @@ import numpy as np from baselines.common.runners import AbstractEnvRunner +from baselines.common.vec_env.vec_frame_stack import VecFrameStack +from gym import spaces + class Runner(AbstractEnvRunner): - def __init__(self, env, model, nsteps, nstack): + def __init__(self, env, model, nsteps): super().__init__(env=env, model=model, nsteps=nsteps) - self.nstack = nstack - nh, nw, nc = env.observation_space.shape - self.nc = nc # nc = 1 for atari, but just in case + assert isinstance(env.action_space, spaces.Discrete), 'This ACER implementation works only with discrete action spaces!' + assert isinstance(env, VecFrameStack) + self.nact = env.action_space.n nenv = self.nenv self.nbatch = nenv * nsteps - self.batch_ob_shape = (nenv*(nsteps+1), nh, nw, nc*nstack) - self.obs = np.zeros((nenv, nh, nw, nc * nstack), dtype=np.uint8) - obs = env.reset() - self.update_obs(obs) + self.batch_ob_shape = (nenv*(nsteps+1),) + env.observation_space.shape + + self.obs = env.reset() + self.obs_dtype = env.observation_space.dtype + self.ac_dtype = env.action_space.dtype + self.nstack = self.env.nstack + self.nc = self.batch_ob_shape[-1] // self.nstack - def update_obs(self, obs, dones=None): - #self.obs = obs - if dones is not None: - self.obs *= (1 - dones.astype(np.uint8))[:, None, None, None] - self.obs = np.roll(self.obs, shift=-self.nc, axis=3) - self.obs[:, :, :, -self.nc:] = obs[:, :, :, :] def run(self): - enc_obs = np.split(self.obs, self.nstack, axis=3) # so now list of obs steps + # enc_obs = np.split(self.obs, self.nstack, axis=3) # so now list of obs steps + enc_obs = np.split(self.env.stackedobs, self.env.nstack, axis=-1) mb_obs, mb_actions, mb_mus, mb_dones, mb_rewards = [], [], [], [], [] for _ in range(self.nsteps): actions, mus, states = self.model._step(self.obs, S=self.states, M=self.dones) @@ -36,15 +37,15 @@ class Runner(AbstractEnvRunner): # states information for statefull models like LSTM self.states = states self.dones = dones - self.update_obs(obs, dones) + self.obs = obs mb_rewards.append(rewards) - enc_obs.append(obs) + enc_obs.append(obs[..., -self.nc:]) mb_obs.append(np.copy(self.obs)) mb_dones.append(self.dones) - enc_obs = np.asarray(enc_obs, dtype=np.uint8).swapaxes(1, 0) - mb_obs = np.asarray(mb_obs, dtype=np.uint8).swapaxes(1, 0) - mb_actions = np.asarray(mb_actions, dtype=np.int32).swapaxes(1, 0) + enc_obs = np.asarray(enc_obs, dtype=self.obs_dtype).swapaxes(1, 0) + mb_obs = np.asarray(mb_obs, dtype=self.obs_dtype).swapaxes(1, 0) + mb_actions = np.asarray(mb_actions, dtype=self.ac_dtype).swapaxes(1, 0) mb_rewards = np.asarray(mb_rewards, 
dtype=np.float32).swapaxes(1, 0) mb_mus = np.asarray(mb_mus, dtype=np.float32).swapaxes(1, 0) diff --git a/baselines/acktr/acktr.py b/baselines/acktr/acktr.py index dcbe612..c5ac025 100644 --- a/baselines/acktr/acktr.py +++ b/baselines/acktr/acktr.py @@ -2,7 +2,7 @@ import os.path as osp import time import functools import tensorflow as tf -from baselines import logger +from baselines import logger, registry from baselines.common import set_global_seeds, explained_variance from baselines.common.policies import build_policy @@ -11,6 +11,7 @@ from baselines.common.tf_util import get_session, save_variables, load_variables from baselines.a2c.runner import Runner from baselines.a2c.utils import Scheduler, find_trainable_variables from baselines.acktr import kfac +from baselines.acktr.defaults import defaults class Model(object): @@ -90,6 +91,7 @@ class Model(object): self.initial_state = step_model.initial_state tf.global_variables_initializer().run(session=sess) +@registry.register('acktr', defaults=defaults) def learn(network, env, seed, total_timesteps=int(40e6), gamma=0.99, log_interval=1, nprocs=32, nsteps=20, ent_coef=0.01, vf_coef=0.5, vf_fisher_coef=1.0, lr=0.25, max_grad_norm=0.5, kfac_clip=0.001, save_interval=None, lrschedule='linear', load_path=None, is_async=True, **network_kwargs): diff --git a/baselines/acktr/defaults.py b/baselines/acktr/defaults.py index f1e3ab6..5d06d4b 100644 --- a/baselines/acktr/defaults.py +++ b/baselines/acktr/defaults.py @@ -1,5 +1,6 @@ -def mujoco(): - return dict( +defaults = { + 'mujoco' : dict( nsteps=2500, value_network='copy' ) +} diff --git a/baselines/common/cmd_util.py b/baselines/common/cmd_util.py index d69589c..316b893 100644 --- a/baselines/common/cmd_util.py +++ b/baselines/common/cmd_util.py @@ -16,30 +16,64 @@ from baselines.common import set_global_seeds from baselines.common.atari_wrappers import make_atari, wrap_deepmind from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv from baselines.common.vec_env.dummy_vec_env import DummyVecEnv -from baselines.common.retro_wrappers import RewardScaler +from baselines.common.vec_env.vec_frame_stack import VecFrameStack +from baselines.common import retro_wrappers -def make_vec_env(env_id, env_type, num_env, seed, wrapper_kwargs=None, start_index=0, reward_scale=1.0): +def make_vec_env(env_id, env_type, num_env, seed, wrapper_kwargs=None, start_index=0, reward_scale=1.0, gamestate=None, frame_stack_size=1): """ - Create a wrapped, monitored SubprocVecEnv for Atari and MuJoCo. + Create a wrapped, monitored SubprocVecEnv """ if wrapper_kwargs is None: wrapper_kwargs = {} mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0 - def make_env(rank): # pylint: disable=C0111 - def _thunk(): - env = make_atari(env_id) if env_type == 'atari' else gym.make(env_id) - env.seed(seed + 10000*mpi_rank + rank if seed is not None else None) - env = Monitor(env, - logger.get_dir() and os.path.join(logger.get_dir(), str(mpi_rank) + '.' 
+ str(rank)), - allow_early_resets=True) + seed = seed + 10000 * mpi_rank if seed is not None else None + def make_thunk(rank): + return lambda: make_env( + env_id=env_id, + env_type=env_type, + subrank = rank, + seed=seed, + reward_scale=reward_scale, + gamestate=gamestate + ) - if env_type == 'atari': return wrap_deepmind(env, **wrapper_kwargs) - elif reward_scale != 1: return RewardScaler(env, reward_scale) - else: return env - return _thunk set_global_seeds(seed) - if num_env > 1: return SubprocVecEnv([make_env(i + start_index) for i in range(num_env)]) - else: return DummyVecEnv([make_env(start_index)]) + if num_env > 1: + venv = SubprocVecEnv([make_thunk(i + start_index) for i in range(num_env)]) + else: + venv = DummyVecEnv([make_thunk(start_index)]) + + if frame_stack_size > 1: + venv = VecFrameStack(venv, frame_stack_size) + + return venv + + + +def make_env(env_id, env_type, subrank=0, seed=None, reward_scale=1.0, gamestate=None, wrapper_kwargs={}): + mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0 + if env_type == 'atari': + env = make_atari(env_id) + elif env_type == 'retro': + import retro + gamestate = gamestate or retro.State.DEFAULT + env = retro_wrappers.make_retro(game=env_id, max_episode_steps=10000, use_restricted_actions=retro.Actions.DISCRETE, state=gamestate) + else: + env = gym.make(env_id) + + env.seed(seed + subrank if seed is not None else None) + env = Monitor(env, + logger.get_dir() and os.path.join(logger.get_dir(), str(mpi_rank) + '.' + str(subrank)), + allow_early_resets=True) + + if env_type == 'atari': + return wrap_deepmind(env, **wrapper_kwargs) + elif reward_scale != 1: + return retro_wrappers.RewardScaler(env, reward_scale) + else: + return env + + def make_mujoco_env(env_id, seed, reward_scale=1.0): """ diff --git a/baselines/common/filters.py b/baselines/common/filters.py deleted file mode 100644 index 5ce019c..0000000 --- a/baselines/common/filters.py +++ /dev/null @@ -1,98 +0,0 @@ -from .running_stat import RunningStat -from collections import deque -import numpy as np - -class Filter(object): - def __call__(self, x, update=True): - raise NotImplementedError - def reset(self): - pass - -class IdentityFilter(Filter): - def __call__(self, x, update=True): - return x - -class CompositionFilter(Filter): - def __init__(self, fs): - self.fs = fs - def __call__(self, x, update=True): - for f in self.fs: - x = f(x) - return x - def output_shape(self, input_space): - out = input_space.shape - for f in self.fs: - out = f.output_shape(out) - return out - -class ZFilter(Filter): - """ - y = (x-mean)/std - using running estimates of mean,std - """ - - def __init__(self, shape, demean=True, destd=True, clip=10.0): - self.demean = demean - self.destd = destd - self.clip = clip - - self.rs = RunningStat(shape) - - def __call__(self, x, update=True): - if update: self.rs.push(x) - if self.demean: - x = x - self.rs.mean - if self.destd: - x = x / (self.rs.std+1e-8) - if self.clip: - x = np.clip(x, -self.clip, self.clip) - return x - def output_shape(self, input_space): - return input_space.shape - -class AddClock(Filter): - def __init__(self): - self.count = 0 - def reset(self): - self.count = 0 - def __call__(self, x, update=True): - return np.append(x, self.count/100.0) - def output_shape(self, input_space): - return (input_space.shape[0]+1,) - -class FlattenFilter(Filter): - def __call__(self, x, update=True): - return x.ravel() - def output_shape(self, input_space): - return (int(np.prod(input_space.shape)),) - -class Ind2OneHotFilter(Filter): - def 
__init__(self, n): - self.n = n - def __call__(self, x, update=True): - out = np.zeros(self.n) - out[x] = 1 - return out - def output_shape(self, input_space): - return (input_space.n,) - -class DivFilter(Filter): - def __init__(self, divisor): - self.divisor = divisor - def __call__(self, x, update=True): - return x / self.divisor - def output_shape(self, input_space): - return input_space.shape - -class StackFilter(Filter): - def __init__(self, length): - self.stack = deque(maxlen=length) - def reset(self): - self.stack.clear() - def __call__(self, x, update=True): - self.stack.append(x) - while len(self.stack) < self.stack.maxlen: - self.stack.append(x) - return np.concatenate(self.stack, axis=-1) - def output_shape(self, input_space): - return input_space.shape[:-1] + (input_space.shape[-1] * self.stack.maxlen,) diff --git a/baselines/common/running_stat.py b/baselines/common/running_stat.py deleted file mode 100644 index b9aa86c..0000000 --- a/baselines/common/running_stat.py +++ /dev/null @@ -1,46 +0,0 @@ -import numpy as np - -# http://www.johndcook.com/blog/standard_deviation/ -class RunningStat(object): - def __init__(self, shape): - self._n = 0 - self._M = np.zeros(shape) - self._S = np.zeros(shape) - def push(self, x): - x = np.asarray(x) - assert x.shape == self._M.shape - self._n += 1 - if self._n == 1: - self._M[...] = x - else: - oldM = self._M.copy() - self._M[...] = oldM + (x - oldM)/self._n - self._S[...] = self._S + (x - oldM)*(x - self._M) - @property - def n(self): - return self._n - @property - def mean(self): - return self._M - @property - def var(self): - return self._S/(self._n - 1) if self._n > 1 else np.square(self._M) - @property - def std(self): - return np.sqrt(self.var) - @property - def shape(self): - return self._M.shape - -def test_running_stat(): - for shp in ((), (3,), (3,4)): - li = [] - rs = RunningStat(shp) - for _ in range(5): - val = np.random.randn(*shp) - rs.push(val) - li.append(val) - m = np.mean(li, axis=0) - assert np.allclose(rs.mean, m) - v = np.square(m) if (len(li) == 1) else np.var(li, ddof=1, axis=0) - assert np.allclose(rs.var, v) diff --git a/baselines/common/tests/test_cartpole.py b/baselines/common/tests/test_cartpole.py index 06d65e4..475ad1d 100644 --- a/baselines/common/tests/test_cartpole.py +++ b/baselines/common/tests/test_cartpole.py @@ -13,6 +13,7 @@ common_kwargs = dict( learn_kwargs = { 'a2c' : dict(nsteps=32, value_network='copy', lr=0.05), + 'acer': dict(value_network='copy'), 'acktr': dict(nsteps=32, value_network='copy', is_async=False), 'deepq': dict(total_timesteps=20000), 'ppo2': dict(value_network='copy'), @@ -40,4 +41,4 @@ def test_cartpole(alg): reward_per_episode_test(env_fn, learn_fn, 100) if __name__ == '__main__': - test_cartpole('deepq') + test_cartpole('acer') diff --git a/baselines/common/tests/test_identity.py b/baselines/common/tests/test_identity.py index 744ed83..0b3c46e 100644 --- a/baselines/common/tests/test_identity.py +++ b/baselines/common/tests/test_identity.py @@ -20,8 +20,8 @@ learn_kwargs = { } -algos_disc = ['a2c', 'deepq', 'ppo2', 'trpo_mpi'] -algos_cont = ['a2c', 'ddpg', 'ppo2', 'trpo_mpi'] +algos_disc = ['a2c', 'acktr', 'deepq', 'ppo2', 'trpo_mpi'] +algos_cont = ['a2c', 'acktr', 'ddpg', 'ppo2', 'trpo_mpi'] @pytest.mark.slow @pytest.mark.parametrize("alg", algos_disc) diff --git a/baselines/common/tests/test_mnist.py b/baselines/common/tests/test_mnist.py index 536164f..eea094d 100644 --- a/baselines/common/tests/test_mnist.py +++ b/baselines/common/tests/test_mnist.py @@ -17,8 +17,7 @@ 
common_kwargs = { learn_args = { 'a2c': dict(total_timesteps=50000), - # TODO need to resolve inference (step) API differences for acer; also slow - # 'acer': dict(seed=0, total_timesteps=1000), + 'acer': dict(total_timesteps=20000), 'deepq': dict(total_timesteps=5000), 'acktr': dict(total_timesteps=30000), 'ppo2': dict(total_timesteps=50000, lr=1e-3, nsteps=128, ent_coef=0.0), @@ -47,4 +46,4 @@ def test_mnist(alg): simple_test(env_fn, learn_fn, 0.6) if __name__ == '__main__': - test_mnist('deepq') + test_mnist('acer') diff --git a/baselines/common/tests/test_serialization.py b/baselines/common/tests/test_serialization.py index f46b578..fac4929 100644 --- a/baselines/common/tests/test_serialization.py +++ b/baselines/common/tests/test_serialization.py @@ -17,6 +17,7 @@ learn_kwargs = { 'deepq': {}, 'a2c': {}, 'acktr': {}, + 'acer': {}, 'ppo2': {'nminibatches': 1, 'nsteps': 10}, 'trpo_mpi': {}, } @@ -37,7 +38,7 @@ def test_serialization(learn_fn, network_fn): ''' - if network_fn.endswith('lstm') and learn_fn in ['acktr', 'trpo_mpi', 'deepq']: + if network_fn.endswith('lstm') and learn_fn in ['acer', 'acktr', 'trpo_mpi', 'deepq']: # TODO make acktr work with recurrent policies # and test # github issue: https://github.com/openai/baselines/issues/660 diff --git a/baselines/ddpg/ddpg.py b/baselines/ddpg/ddpg.py index 181f923..db64fb4 100755 --- a/baselines/ddpg/ddpg.py +++ b/baselines/ddpg/ddpg.py @@ -10,11 +10,11 @@ from baselines.ddpg.noise import AdaptiveParamNoiseSpec, NormalActionNoise, Orns import baselines.common.tf_util as U -from baselines import logger +from baselines import logger, registry import numpy as np from mpi4py import MPI - +@registry.register('ddpg') def learn(network, env, seed=None, total_timesteps=None, diff --git a/baselines/deepq/deepq.py b/baselines/deepq/deepq.py index c6004b2..4108210 100644 --- a/baselines/deepq/deepq.py +++ b/baselines/deepq/deepq.py @@ -8,7 +8,7 @@ import numpy as np import baselines.common.tf_util as U from baselines.common.tf_util import load_variables, save_variables -from baselines import logger +from baselines import logger, registry from baselines.common.schedules import LinearSchedule from baselines.common import set_global_seeds @@ -18,6 +18,7 @@ from baselines.deepq.utils import ObservationInput from baselines.common.tf_util import get_session from baselines.deepq.models import build_q_func +from baselines.deepq.defaults import defaults class ActWrapper(object): @@ -92,6 +93,7 @@ def load_act(path): return ActWrapper.load_act(path) +@registry.register('deepq', supports_vecenvs=False, defaults=defaults) def learn(env, network, seed=None, @@ -124,16 +126,12 @@ def learn(env, ------- env: gym.Env environment to train on - q_func: (tf.Variable, int, str, bool) -> tf.Variable - the model that takes the following inputs: - observation_in: object - the output of observation placeholder - num_actions: int - number of actions - scope: str - reuse: bool - should be passed to outer variable scope - and returns a tensor of shape (batch_size, num_actions) with values of every action. + network: string or a function + neural network to use as a q function approximator. If string, has to be one of the names of registered models in baselines.common.models + (mlp, cnn, conv_only). If a function, should take an observation tensor and return a latent variable tensor, which + will be mapped to the Q function heads (see build_q_func in baselines.deepq.models for details on that) + seed: int or None + prng seed. 
The runs with the same seed "should" give the same results. If None, no seeding is used.
    lr: float
        learning rate for adam optimizer
    total_timesteps: int
diff --git a/baselines/deepq/defaults.py b/baselines/deepq/defaults.py
index d41fb18..06abdf4 100644
--- a/baselines/deepq/defaults.py
+++ b/baselines/deepq/defaults.py
@@ -16,6 +16,8 @@ def atari():
        dueling=True
    )
 
-def retro():
-    return atari()
+defaults = {
+    'atari': atari(),
+    'retro': atari()
+}
diff --git a/baselines/her/README.md b/baselines/her/README.md
index 6bd02b4..9934c69 100644
--- a/baselines/her/README.md
+++ b/baselines/her/README.md
@@ -30,3 +30,51 @@ python -m baselines.her.experiment.train --num_cpu 19
This will require a machine with sufficient amount of physical CPU cores. In our experiments,
we used [Azure's D15v2 instances](https://docs.microsoft.com/en-us/azure/virtual-machines/linux/sizes),
which have 20 physical cores. We only scheduled the experiment on 19 of those to leave some head-room on the system.
+
+
+## Hindsight Experience Replay with Demonstrations
+Using pre-recorded demonstrations to overcome the exploration problem in HER-based reinforcement learning.
+For details, please read the [paper](https://arxiv.org/pdf/1709.10089.pdf).
+
+### Getting started
+The first step is to generate the demonstration dataset. This can be done in two ways: either use a VR system with physical trackers to manipulate the arm, or, more simply, write a script that carries out the task. Some tasks are complex enough that a hardcoded script is difficult to write (e.g. Fetch Push), but our focus here is on providing an algorithm that lets the agent learn from demonstrations, not on the demonstration-generation paradigm itself, so data collection is left to the reader's choice.
+
+We provide a script for the Fetch Pick and Place task; to generate demonstrations for it, execute:
+```bash
+python experiment/data_generation/fetch_data_generation.py
+```
+This outputs the ```data_fetch_random_100.npz``` file, which is our demonstration data file.
+
+#### Configuration
+The provided configuration is for training an agent with HER without demonstrations; for the HER algorithm to learn through demonstrations, change a few parameters by setting:
+
+* bc_loss: 1 - whether or not to use the behavior cloning loss as an auxiliary loss
+* q_filter: 1 - whether or not a Q-value filter should be used on the actor outputs
+* num_demo: 100 - number of expert demo episodes
+* demo_batch_size: 128 - number of samples to be used from the demonstrations buffer, per mpi thread
+* prm_loss_weight: 0.001 - weight corresponding to the primary loss
+* aux_loss_weight: 0.0078 - weight corresponding to the auxiliary loss, also called the cloning loss
+
+Apart from these changes, the reported results also use the following configuration changes:
+
+* n_cycles: 20 - per epoch
+* batch_size: 1024 - per mpi thread, total batch size
+* random_eps: 0.1 - percentage of time a random action is taken
+* noise_eps: 0.1 - std of Gaussian noise added to not-completely-random actions
+
+Now, to train an agent with pre-recorded demonstrations:
+```bash
+python -m baselines.her.experiment.train --env=FetchPickAndPlace-v0 --n_epochs=1000 --demo_file=/Path/to/demo_file.npz --num_cpu=1
+```
+
+This will train a DDPG+HER agent on the `FetchPickAndPlace` environment by using previously generated demonstration data.
+To inspect what the agent has learned, use the play script as described above.
+
+### Results
+Training with demonstrations helps overcome the exploration problem and achieves faster and better convergence. The following graphs contrast training with and without demonstration data; we report the mean Q value vs. epoch and the success rate vs. epoch:
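The `baselines.registry` module referenced by the new `@registry.register(...)` decorators is not part of the hunks above. Below is a minimal sketch of the interface those call sites imply (`register('a2c')`, `register('acer', defaults=defaults)`, `register('deepq', supports_vecenvs=False, defaults=defaults)`), assuming a dict-backed table; the table name and metadata fields are guesses, not the actual `baselines/registry.py`.

```python
# Minimal sketch, assuming a dict-backed registry; NOT the real baselines/registry.py.
algorithms = {}  # algorithm name -> metadata about its learn() entry point

def register(name, supports_vecenvs=True, defaults=None):
    """Record a learn() function under `name`.

    `defaults` maps an env_type key (e.g. 'atari', 'mujoco') to keyword
    overrides, mirroring the per-algorithm defaults.py dicts in this diff;
    `supports_vecenvs` marks whether the algorithm accepts vectorized envs.
    """
    def decorator(learn_fn):
        algorithms[name] = dict(
            learn_fn=learn_fn,
            supports_vecenvs=supports_vecenvs,
            defaults=defaults or {},
        )
        return learn_fn  # the decorated learn() is returned unchanged
    return decorator
```

This is also why `baselines/__init__.py` now imports every algorithm sub-package: the decorators only run at import time, so importing the packages is what populates the table, and a launcher can then presumably resolve an algorithm by name instead of importing it explicitly.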
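Much of the ACER change above removes the algorithm's own frame stacking (the deleted `nstack` arguments and `update_obs` method) and relies on `VecFrameStack` instead, which `make_vec_env` can now apply through its new `frame_stack_size` argument. A short usage sketch under those new signatures; the environment id and hyperparameter values are illustrative, not taken from the diff:

```python
from baselines.common.cmd_util import make_vec_env
from baselines.acer import acer

# Frame stacking now happens at the vec-env level, so ACER no longer takes an
# nstack argument; it reads env.nstack from the VecFrameStack wrapper instead.
venv = make_vec_env(
    env_id='BreakoutNoFrameskip-v4',  # illustrative Atari id
    env_type='atari',
    num_env=4,
    seed=0,
    frame_stack_size=4,               # wraps the SubprocVecEnv in VecFrameStack
)

acer.learn(network='cnn', env=venv, total_timesteps=100000)
```

If `frame_stack_size` is left at 1, the new `acer.learn` wraps the environment in `VecFrameStack(env, 1)` itself, so the `isinstance(env, VecFrameStack)` assertion in the runner still holds.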