diff --git a/baselines/a2c/a2c.py b/baselines/a2c/a2c.py index 186c522..c28fe65 100644 --- a/baselines/a2c/a2c.py +++ b/baselines/a2c/a2c.py @@ -1,16 +1,12 @@ -import os import os.path as osp -import gym import time import joblib -import logging import numpy as np import tensorflow as tf from baselines import logger from baselines.common import set_global_seeds, explained_variance -from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv -from baselines.common.atari_wrappers import wrap_deepmind +from baselines.common.runners import AbstractEnvRunner from baselines.common import tf_util from baselines.a2c.utils import discount_with_dones @@ -24,7 +20,6 @@ class Model(object): alpha=0.99, epsilon=1e-5, total_timesteps=int(80e6), lrschedule='linear'): sess = tf_util.make_session() - nact = ac_space.n nbatch = nenvs*nsteps A = tf.placeholder(tf.int32, [nbatch]) @@ -75,7 +70,7 @@ class Model(object): restores = [] for p, loaded_p in zip(params, loaded_params): restores.append(p.assign(loaded_p)) - ps = sess.run(restores) + sess.run(restores) self.train = train self.train_model = train_model @@ -87,21 +82,11 @@ class Model(object): self.load = load tf.global_variables_initializer().run(session=sess) -class Runner(object): +class Runner(AbstractEnvRunner): def __init__(self, env, model, nsteps=5, gamma=0.99): - self.env = env - self.model = model - nh, nw, nc = env.observation_space.shape - nenv = env.num_envs - self.batch_ob_shape = (nenv*nsteps, nh, nw, nc) - self.obs = np.zeros((nenv, nh, nw, nc), dtype=np.uint8) - self.nc = nc - obs = env.reset() + super().__init__(env=env, model=model, nsteps=nsteps) self.gamma = gamma - self.nsteps = nsteps - self.states = model.initial_state - self.dones = [False for _ in range(nenv)] def run(self): mb_obs, mb_rewards, mb_actions, mb_values, mb_dones = [],[],[],[],[] diff --git a/baselines/a2c/policies.py b/baselines/a2c/policies.py index 9b2a627..172a3ee 100644 --- a/baselines/a2c/policies.py +++ b/baselines/a2c/policies.py @@ -3,15 +3,16 @@ import tensorflow as tf from baselines.a2c.utils import conv, fc, conv_to_fc, batch_to_seq, seq_to_batch, lstm, lnlstm from baselines.common.distributions import make_pdtype -def nature_cnn(unscaled_images): +def nature_cnn(unscaled_images, **conv_kwargs): """ CNN from Nature paper. """ scaled_images = tf.cast(unscaled_images, tf.float32) / 255. activ = tf.nn.relu - h = activ(conv(scaled_images, 'c1', nf=32, rf=8, stride=4, init_scale=np.sqrt(2))) - h2 = activ(conv(h, 'c2', nf=64, rf=4, stride=2, init_scale=np.sqrt(2))) - h3 = activ(conv(h2, 'c3', nf=64, rf=3, stride=1, init_scale=np.sqrt(2))) + h = activ(conv(scaled_images, 'c1', nf=32, rf=8, stride=4, init_scale=np.sqrt(2), + **conv_kwargs)) + h2 = activ(conv(h, 'c2', nf=64, rf=4, stride=2, init_scale=np.sqrt(2), **conv_kwargs)) + h3 = activ(conv(h2, 'c3', nf=64, rf=3, stride=1, init_scale=np.sqrt(2), **conv_kwargs)) h3 = conv_to_fc(h3) return activ(fc(h3, 'fc1', nh=512, init_scale=np.sqrt(2))) @@ -20,21 +21,18 @@ class LnLstmPolicy(object): nenv = nbatch // nsteps nh, nw, nc = ob_space.shape ob_shape = (nbatch, nh, nw, nc) - nact = ac_space.n X = tf.placeholder(tf.uint8, ob_shape) #obs M = tf.placeholder(tf.float32, [nbatch]) #mask (done t-1) S = tf.placeholder(tf.float32, [nenv, nlstm*2]) #states + self.pdtype = make_pdtype(ac_space) with tf.variable_scope("model", reuse=reuse): h = nature_cnn(X) xs = batch_to_seq(h, nenv, nsteps) ms = batch_to_seq(M, nenv, nsteps) h5, snew = lnlstm(xs, ms, S, 'lstm1', nh=nlstm) h5 = seq_to_batch(h5) - pi = fc(h5, 'pi', nact) vf = fc(h5, 'v', 1) - - self.pdtype = make_pdtype(ac_space) - self.pd = self.pdtype.pdfromflat(pi) + self.pd, self.pi = self.pdtype.pdfromlatent(h5) v0 = vf[:, 0] a0 = self.pd.sample() @@ -50,7 +48,6 @@ class LnLstmPolicy(object): self.X = X self.M = M self.S = S - self.pi = pi self.vf = vf self.step = step self.value = value @@ -62,7 +59,7 @@ class LstmPolicy(object): nh, nw, nc = ob_space.shape ob_shape = (nbatch, nh, nw, nc) - nact = ac_space.n + self.pdtype = make_pdtype(ac_space) X = tf.placeholder(tf.uint8, ob_shape) #obs M = tf.placeholder(tf.float32, [nbatch]) #mask (done t-1) S = tf.placeholder(tf.float32, [nenv, nlstm*2]) #states @@ -72,11 +69,8 @@ class LstmPolicy(object): ms = batch_to_seq(M, nenv, nsteps) h5, snew = lstm(xs, ms, S, 'lstm1', nh=nlstm) h5 = seq_to_batch(h5) - pi = fc(h5, 'pi', nact) vf = fc(h5, 'v', 1) - - self.pdtype = make_pdtype(ac_space) - self.pd = self.pdtype.pdfromflat(pi) + self.pd, self.pi = self.pdtype.pdfromlatent(h5) v0 = vf[:, 0] a0 = self.pd.sample() @@ -92,25 +86,21 @@ class LstmPolicy(object): self.X = X self.M = M self.S = S - self.pi = pi self.vf = vf self.step = step self.value = value class CnnPolicy(object): - def __init__(self, sess, ob_space, ac_space, nbatch, nsteps, reuse=False): #pylint: disable=W0613 + def __init__(self, sess, ob_space, ac_space, nbatch, nsteps, reuse=False, **conv_kwargs): #pylint: disable=W0613 nh, nw, nc = ob_space.shape ob_shape = (nbatch, nh, nw, nc) - nact = ac_space.n + self.pdtype = make_pdtype(ac_space) X = tf.placeholder(tf.uint8, ob_shape) #obs with tf.variable_scope("model", reuse=reuse): - h = nature_cnn(X) - pi = fc(h, 'pi', nact, init_scale=0.01) + h = nature_cnn(X, **conv_kwargs) vf = fc(h, 'v', 1)[:,0] - - self.pdtype = make_pdtype(ac_space) - self.pd = self.pdtype.pdfromflat(pi) + self.pd, self.pi = self.pdtype.pdfromlatent(h, init_scale=0.01) a0 = self.pd.sample() neglogp0 = self.pd.neglogp(a0) @@ -124,7 +114,6 @@ class CnnPolicy(object): return sess.run(vf, {X:ob}) self.X = X - self.pi = pi self.vf = vf self.step = step self.value = value @@ -132,23 +121,19 @@ class CnnPolicy(object): class MlpPolicy(object): def __init__(self, sess, ob_space, ac_space, nbatch, nsteps, reuse=False): #pylint: disable=W0613 ob_shape = (nbatch,) + ob_space.shape - actdim = ac_space.shape[0] + self.pdtype = make_pdtype(ac_space) X = tf.placeholder(tf.float32, ob_shape, name='Ob') #obs with tf.variable_scope("model", reuse=reuse): activ = tf.tanh - h1 = activ(fc(X, 'pi_fc1', nh=64, init_scale=np.sqrt(2))) - h2 = activ(fc(h1, 'pi_fc2', nh=64, init_scale=np.sqrt(2))) - pi = fc(h2, 'pi', actdim, init_scale=0.01) - h1 = activ(fc(X, 'vf_fc1', nh=64, init_scale=np.sqrt(2))) - h2 = activ(fc(h1, 'vf_fc2', nh=64, init_scale=np.sqrt(2))) - vf = fc(h2, 'vf', 1)[:,0] - logstd = tf.get_variable(name="logstd", shape=[1, actdim], - initializer=tf.zeros_initializer()) + flatten = tf.layers.flatten + pi_h1 = activ(fc(flatten(X), 'pi_fc1', nh=64, init_scale=np.sqrt(2))) + pi_h2 = activ(fc(pi_h1, 'pi_fc2', nh=64, init_scale=np.sqrt(2))) + vf_h1 = activ(fc(flatten(X), 'vf_fc1', nh=64, init_scale=np.sqrt(2))) + vf_h2 = activ(fc(vf_h1, 'vf_fc2', nh=64, init_scale=np.sqrt(2))) + vf = fc(vf_h2, 'vf', 1)[:,0] - pdparam = tf.concat([pi, pi * 0.0 + logstd], axis=1) + self.pd, self.pi = self.pdtype.pdfromlatent(pi_h2, init_scale=0.01) - self.pdtype = make_pdtype(ac_space) - self.pd = self.pdtype.pdfromflat(pdparam) a0 = self.pd.sample() neglogp0 = self.pd.neglogp(a0) @@ -162,7 +147,6 @@ class MlpPolicy(object): return sess.run(vf, {X:ob}) self.X = X - self.pi = pi self.vf = vf self.step = step self.value = value diff --git a/baselines/a2c/utils.py b/baselines/a2c/utils.py index 0964af8..a7610eb 100644 --- a/baselines/a2c/utils.py +++ b/baselines/a2c/utils.py @@ -39,7 +39,7 @@ def ortho_init(scale=1.0): return (scale * q[:shape[0], :shape[1]]).astype(np.float32) return _ortho_init -def conv(x, scope, *, nf, rf, stride, pad='VALID', init_scale=1.0, data_format='NHWC'): +def conv(x, scope, *, nf, rf, stride, pad='VALID', init_scale=1.0, data_format='NHWC', one_dim_bias=False): if data_format == 'NHWC': channel_ax = 3 strides = [1, stride, stride, 1] @@ -50,12 +50,14 @@ def conv(x, scope, *, nf, rf, stride, pad='VALID', init_scale=1.0, data_format=' bshape = [1, nf, 1, 1] else: raise NotImplementedError + bias_var_shape = [nf] if one_dim_bias else [1, nf, 1, 1] nin = x.get_shape()[channel_ax].value wshape = [rf, rf, nin, nf] with tf.variable_scope(scope): w = tf.get_variable("w", wshape, initializer=ortho_init(init_scale)) - b = tf.get_variable("b", [1, nf, 1, 1], initializer=tf.constant_initializer(0.0)) - if data_format == 'NHWC': b = tf.reshape(b, bshape) + b = tf.get_variable("b", bias_var_shape, initializer=tf.constant_initializer(0.0)) + if not one_dim_bias and data_format == 'NHWC': + b = tf.reshape(b, bshape) return b + tf.nn.conv2d(x, w, strides=strides, padding=pad, data_format=data_format) def fc(x, scope, nh, *, init_scale=1.0, init_bias=0.0): diff --git a/baselines/acer/acer_simple.py b/baselines/acer/acer_simple.py index 5690294..bed486a 100644 --- a/baselines/acer/acer_simple.py +++ b/baselines/acer/acer_simple.py @@ -5,6 +5,7 @@ import tensorflow as tf from baselines import logger from baselines.common import set_global_seeds +from baselines.common.runners import AbstractEnvRunner from baselines.a2c.utils import batch_to_seq, seq_to_batch from baselines.a2c.utils import Scheduler, make_path, find_trainable_variables @@ -13,6 +14,8 @@ from baselines.a2c.utils import EpisodeStats from baselines.a2c.utils import get_by_index, check_shape, avg_norm, gradient_add, q_explained_variance from baselines.acer.buffer import Buffer +import os.path as osp + # remove last step def strip(var, nenvs, nsteps, flat = False): vars = batch_to_seq(var, nenvs, nsteps + 1, flat) @@ -209,11 +212,10 @@ class Model(object): self.initial_state = step_model.initial_state tf.global_variables_initializer().run(session=sess) -class Runner(object): +class Runner(AbstractEnvRunner): def __init__(self, env, model, nsteps, nstack): - self.env = env + super().__init__(env=env, model=model, nsteps=nsteps) self.nstack = nstack - self.model = model nh, nw, nc = env.observation_space.shape self.nc = nc # nc = 1 for atari, but just in case self.nenv = nenv = env.num_envs @@ -223,9 +225,6 @@ class Runner(object): self.obs = np.zeros((nenv, nh, nw, nc * nstack), dtype=np.uint8) obs = env.reset() self.update_obs(obs) - self.nsteps = nsteps - self.states = model.initial_state - self.dones = [False for _ in range(nenv)] def update_obs(self, obs, dones=None): if dones is not None: diff --git a/baselines/acktr/run_atari.py b/baselines/acktr/run_atari.py index 7569f2e..6e398ce 100644 --- a/baselines/acktr/run_atari.py +++ b/baselines/acktr/run_atari.py @@ -1,5 +1,7 @@ #!/usr/bin/env python3 +from functools import partial + from baselines import logger from baselines.acktr.acktr_disc import learn from baselines.common.cmd_util import make_atari_env, atari_arg_parser @@ -8,7 +10,7 @@ from baselines.ppo2.policies import CnnPolicy def train(env_id, num_timesteps, seed, num_cpu): env = VecFrameStack(make_atari_env(env_id, num_cpu, seed), 4) - policy_fn = CnnPolicy + policy_fn = partial(CnnPolicy, one_dim_bias=True) learn(policy_fn, env, seed, total_timesteps=int(num_timesteps * 1.1), nprocs=num_cpu) env.close() diff --git a/baselines/common/__init__.py b/baselines/common/__init__.py index 4ab0604..0834b36 100644 --- a/baselines/common/__init__.py +++ b/baselines/common/__init__.py @@ -1,3 +1,4 @@ +# flake8: noqa F403 from baselines.common.console_util import * from baselines.common.dataset import Dataset from baselines.common.math_util import * diff --git a/baselines/common/atari_wrappers.py b/baselines/common/atari_wrappers.py index 0901378..2aefad7 100644 --- a/baselines/common/atari_wrappers.py +++ b/baselines/common/atari_wrappers.py @@ -98,9 +98,6 @@ class MaxAndSkipEnv(gym.Wrapper): self._obs_buffer = np.zeros((2,)+env.observation_space.shape, dtype=np.uint8) self._skip = skip - def reset(self): - return self.env.reset() - def step(self, action): """Repeat action, sum reward, and max over last observations.""" total_reward = 0.0 diff --git a/baselines/common/cmd_util.py b/baselines/common/cmd_util.py index 8a3304b..b88d529 100644 --- a/baselines/common/cmd_util.py +++ b/baselines/common/cmd_util.py @@ -10,7 +10,6 @@ from baselines.bench import Monitor from baselines.common import set_global_seeds from baselines.common.atari_wrappers import make_atari, wrap_deepmind from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv -from mpi4py import MPI def make_atari_env(env_id, num_env, seed, wrapper_kwargs=None, start_index=0): """ diff --git a/baselines/common/distributions.py b/baselines/common/distributions.py index 6f5b522..8a57c37 100644 --- a/baselines/common/distributions.py +++ b/baselines/common/distributions.py @@ -1,6 +1,7 @@ import tensorflow as tf import numpy as np import baselines.common.tf_util as U +from baselines.a2c.utils import fc from tensorflow.python.ops import math_ops class Pd(object): @@ -31,6 +32,8 @@ class PdType(object): raise NotImplementedError def pdfromflat(self, flat): return self.pdclass()(flat) + def pdfromlatent(self, latent_vector): + raise NotImplementedError def param_shape(self): raise NotImplementedError def sample_shape(self): @@ -48,6 +51,10 @@ class CategoricalPdType(PdType): self.ncat = ncat def pdclass(self): return CategoricalPd + def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0): + pdparam = fc(latent_vector, 'pi', self.ncat, init_scale=init_scale, init_bias=init_bias) + return self.pdfromflat(pdparam), pdparam + def param_shape(self): return [self.ncat] def sample_shape(self): @@ -75,6 +82,13 @@ class DiagGaussianPdType(PdType): self.size = size def pdclass(self): return DiagGaussianPd + + def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0): + mean = fc(latent_vector, 'pi', self.size, init_scale=init_scale, init_bias=init_bias) + logstd = tf.get_variable(name='logstd', shape=[1, self.size], initializer=tf.zeros_initializer()) + pdparam = tf.concat([mean, mean * 0.0 + logstd], axis=1) + return self.pdfromflat(pdparam), mean + def param_shape(self): return [2*self.size] def sample_shape(self): diff --git a/baselines/common/mpi_moments.py b/baselines/common/mpi_moments.py index d13cc2f..7fcc6cd 100644 --- a/baselines/common/mpi_moments.py +++ b/baselines/common/mpi_moments.py @@ -2,6 +2,7 @@ from mpi4py import MPI import numpy as np from baselines.common import zipsame + def mpi_mean(x, axis=0, comm=None, keepdims=False): x = np.asarray(x) assert x.ndim > 0 diff --git a/baselines/common/segment_tree.py b/baselines/common/segment_tree.py index a5a7dfc..cb386ec 100644 --- a/baselines/common/segment_tree.py +++ b/baselines/common/segment_tree.py @@ -12,10 +12,9 @@ class SegmentTree(object): a) setting item's value is slightly slower. It is O(lg capacity) instead of O(1). - b) user has access to an efficient `reduce` - operation which reduces `operation` over - a contiguous subsequence of items in the - array. + b) user has access to an efficient ( O(log segment size) ) + `reduce` operation which reduces `operation` over + a contiguous subsequence of items in the array. Paramters --------- @@ -23,8 +22,8 @@ class SegmentTree(object): Total size of the array - must be a power of two. operation: lambda obj, obj -> obj and operation for combining elements (eg. sum, max) - must for a mathematical group together with the set of - possible values for array elements. + must form a mathematical group together with the set of + possible values for array elements (i.e. be associative) neutral_element: obj neutral element for the operation above. eg. float('-inf') for max and 0 for sum. diff --git a/baselines/common/tests/test_tf_util.py b/baselines/common/tests/test_tf_util.py index 8a92fa1..daad9d0 100644 --- a/baselines/common/tests/test_tf_util.py +++ b/baselines/common/tests/test_tf_util.py @@ -33,7 +33,6 @@ def test_multikwargs(): initialize() assert lin(2) == 6 assert lin(2, 2) == 10 - expt_caught = False if __name__ == '__main__': diff --git a/baselines/common/tf_util.py b/baselines/common/tf_util.py index 3396bb8..9b2822b 100644 --- a/baselines/common/tf_util.py +++ b/baselines/common/tf_util.py @@ -48,7 +48,7 @@ def huber_loss(x, delta=1.0): # Global session # ================================================================ -def make_session(num_cpu=None, make_default=False): +def make_session(num_cpu=None, make_default=False, graph=None): """Returns a session that will use CPU's only""" if num_cpu is None: num_cpu = int(os.getenv('RCALL_NUM_CPU', multiprocessing.cpu_count())) @@ -57,9 +57,9 @@ def make_session(num_cpu=None, make_default=False): intra_op_parallelism_threads=num_cpu) tf_config.gpu_options.allocator_type = 'BFC' if make_default: - return tf.InteractiveSession(config=tf_config) + return tf.InteractiveSession(config=tf_config, graph=graph) else: - return tf.Session(config=tf_config) + return tf.Session(config=tf_config, graph=graph) def single_threaded_session(): """Returns a session which will only use a single CPU""" @@ -84,10 +84,10 @@ def initialize(): # Model components # ================================================================ -def normc_initializer(std=1.0): +def normc_initializer(std=1.0, axis=0): def _initializer(shape, dtype=None, partition_info=None): # pylint: disable=W0613 out = np.random.randn(*shape).astype(np.float32) - out *= std / np.sqrt(np.square(out).sum(axis=0, keepdims=True)) + out *= std / np.sqrt(np.square(out).sum(axis=axis, keepdims=True)) return tf.constant(out) return _initializer @@ -273,8 +273,9 @@ def display_var_info(vars): for v in vars: name = v.name if "/Adam" in name or "beta1_power" in name or "beta2_power" in name: continue - count_params += np.prod(v.shape.as_list()) - if "/b:" in name: continue # Wx+b, bias is not interesting to look at => count params, but not print - logger.info(" %s%s%s" % (name, " "*(55-len(name)), str(v.shape))) - logger.info("Total model parameters: %0.1f million" % (count_params*1e-6)) + v_params = np.prod(v.shape.as_list()) + count_params += v_params + if "/b:" in name or "/biases" in name: continue # Wx+b, bias is not interesting to look at => count params, but not print + logger.info(" %s%s %i params %s" % (name, " "*(55-len(name)), v_params, str(v.shape))) + logger.info("Total model parameters: %0.2f million" % (count_params*1e-6)) diff --git a/baselines/common/vec_env/__init__.py b/baselines/common/vec_env/__init__.py index 2af2d37..146ca87 100644 --- a/baselines/common/vec_env/__init__.py +++ b/baselines/common/vec_env/__init__.py @@ -80,6 +80,13 @@ class VecEnv(ABC): def render(self): logger.warn('Render not defined for %s'%self) + @property + def unwrapped(self): + if isinstance(self, VecEnvWrapper): + return self.venv.unwrapped + else: + return self + class VecEnvWrapper(VecEnv): def __init__(self, venv, observation_space=None, action_space=None): self.venv = venv diff --git a/baselines/common/vec_env/dummy_vec_env.py b/baselines/common/vec_env/dummy_vec_env.py index edabf25..f966d29 100644 --- a/baselines/common/vec_env/dummy_vec_env.py +++ b/baselines/common/vec_env/dummy_vec_env.py @@ -1,5 +1,6 @@ import numpy as np -import gym +from gym import spaces +from collections import OrderedDict from . import VecEnv class DummyVecEnv(VecEnv): @@ -7,9 +8,22 @@ class DummyVecEnv(VecEnv): self.envs = [fn() for fn in env_fns] env = self.envs[0] VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space) - - obs_spaces = self.observation_space.spaces if isinstance(self.observation_space, gym.spaces.Tuple) else (self.observation_space,) - self.buf_obs = [np.zeros((self.num_envs,) + tuple(s.shape), s.dtype) for s in obs_spaces] + shapes, dtypes = {}, {} + self.keys = [] + obs_space = env.observation_space + if isinstance(obs_space, spaces.Dict): + assert isinstance(obs_space.spaces, OrderedDict) + for key, box in obs_space.spaces.items(): + assert isinstance(box, spaces.Box) + shapes[key] = box.shape + dtypes[key] = box.dtype + self.keys.append(key) + else: + box = obs_space + assert isinstance(box, spaces.Box) + self.keys = [None] + shapes, dtypes = { None: box.shape }, { None: box.dtype } + self.buf_obs = { k: np.zeros((self.num_envs,) + tuple(shapes[k]), dtype=dtypes[k]) for k in self.keys } self.buf_dones = np.zeros((self.num_envs,), dtype=np.bool) self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32) self.buf_infos = [{} for _ in range(self.num_envs)] @@ -19,33 +33,32 @@ class DummyVecEnv(VecEnv): self.actions = actions def step_wait(self): - for i in range(self.num_envs): - obs_tuple, self.buf_rews[i], self.buf_dones[i], self.buf_infos[i] = self.envs[i].step(self.actions[i]) - if self.buf_dones[i]: - obs_tuple = self.envs[i].reset() - if isinstance(obs_tuple, (tuple, list)): - for t,x in enumerate(obs_tuple): - self.buf_obs[t][i] = x - else: - self.buf_obs[0][i] = obs_tuple + for e in range(self.num_envs): + obs, self.buf_rews[e], self.buf_dones[e], self.buf_infos[e] = self.envs[e].step(self.actions[e]) + if self.buf_dones[e]: + obs = self.envs[e].reset() + self._save_obs(e, obs) return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones), self.buf_infos.copy()) def reset(self): - for i in range(self.num_envs): - obs_tuple = self.envs[i].reset() - if isinstance(obs_tuple, (tuple, list)): - for t,x in enumerate(obs_tuple): - self.buf_obs[t][i] = x - else: - self.buf_obs[0][i] = obs_tuple + for e in range(self.num_envs): + obs = self.envs[e].reset() + self._save_obs(e, obs) return self._obs_from_buf() def close(self): return + def _save_obs(self, e, obs): + for k in self.keys: + if k is None: + self.buf_obs[k][e] = obs + else: + self.buf_obs[k][e] = obs[k] + def _obs_from_buf(self): - if len(self.buf_obs) == 1: - return np.copy(self.buf_obs[0]) + if self.keys==[None]: + return self.buf_obs[None] else: - return tuple(np.copy(x) for x in self.buf_obs) + return self.buf_obs diff --git a/baselines/deepq/experiments/run_atari.py b/baselines/deepq/experiments/run_atari.py index 7816a23..9836268 100644 --- a/baselines/deepq/experiments/run_atari.py +++ b/baselines/deepq/experiments/run_atari.py @@ -5,11 +5,13 @@ import argparse from baselines import logger from baselines.common.atari_wrappers import make_atari + def main(): parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument('--env', help='environment ID', default='BreakoutNoFrameskip-v4') parser.add_argument('--seed', help='RNG seed', type=int, default=0) parser.add_argument('--prioritized', type=int, default=1) + parser.add_argument('--prioritized-replay-alpha', type=float, default=0.6) parser.add_argument('--dueling', type=int, default=1) parser.add_argument('--num-timesteps', type=int, default=int(10e6)) args = parser.parse_args() @@ -23,7 +25,8 @@ def main(): hiddens=[256], dueling=bool(args.dueling), ) - act = deepq.learn( + + deepq.learn( env, q_func=model, lr=1e-4, @@ -35,9 +38,10 @@ def main(): learning_starts=10000, target_network_update_freq=1000, gamma=0.99, - prioritized_replay=bool(args.prioritized) + prioritized_replay=bool(args.prioritized), + prioritized_replay_alpha=args.prioritized_replay_alpha ) - # act.save("pong_model.pkl") XXX + env.close() diff --git a/baselines/deepq/replay_buffer.py b/baselines/deepq/replay_buffer.py index f3dd6fb..7988113 100644 --- a/baselines/deepq/replay_buffer.py +++ b/baselines/deepq/replay_buffer.py @@ -86,7 +86,7 @@ class PrioritizedReplayBuffer(ReplayBuffer): ReplayBuffer.__init__ """ super(PrioritizedReplayBuffer, self).__init__(size) - assert alpha > 0 + assert alpha >= 0 self._alpha = alpha it_capacity = 1 diff --git a/baselines/logger.py b/baselines/logger.py index 0d24e6f..888db76 100644 --- a/baselines/logger.py +++ b/baselines/logger.py @@ -8,7 +8,8 @@ import datetime import tempfile from collections import defaultdict -LOG_OUTPUT_FORMATS = ['stdout', 'log', 'csv'] +LOG_OUTPUT_FORMATS = ['stdout', 'log', 'csv'] +LOG_OUTPUT_FORMATS_MPI = ['log'] # Also valid: json, tensorboard DEBUG = 10 @@ -355,10 +356,21 @@ def configure(dir=None, format_strs=None): assert isinstance(dir, str) os.makedirs(dir, exist_ok=True) + log_suffix = '' + from mpi4py import MPI + rank = MPI.COMM_WORLD.Get_rank() + if rank > 0: + log_suffix = "-rank%03i" % rank + if format_strs is None: - strs = os.getenv('OPENAI_LOG_FORMAT') - format_strs = strs.split(',') if strs else LOG_OUTPUT_FORMATS - output_formats = [make_output_format(f, dir) for f in format_strs] + strs, strs_mpi = os.getenv('OPENAI_LOG_FORMAT'), os.getenv('OPENAI_LOG_FORMAT_MPI') + format_strs = strs_mpi if rank>0 else strs + if format_strs is not None: + format_strs = format_strs.split(',') + else: + format_strs = LOG_OUTPUT_FORMATS_MPI if rank>0 else LOG_OUTPUT_FORMATS + + output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs] Logger.CURRENT = Logger(dir=dir, output_formats=output_formats) log('Logging to %s'%dir) diff --git a/baselines/ppo2/policies.py b/baselines/ppo2/policies.py index 9b2a627..172a3ee 100644 --- a/baselines/ppo2/policies.py +++ b/baselines/ppo2/policies.py @@ -3,15 +3,16 @@ import tensorflow as tf from baselines.a2c.utils import conv, fc, conv_to_fc, batch_to_seq, seq_to_batch, lstm, lnlstm from baselines.common.distributions import make_pdtype -def nature_cnn(unscaled_images): +def nature_cnn(unscaled_images, **conv_kwargs): """ CNN from Nature paper. """ scaled_images = tf.cast(unscaled_images, tf.float32) / 255. activ = tf.nn.relu - h = activ(conv(scaled_images, 'c1', nf=32, rf=8, stride=4, init_scale=np.sqrt(2))) - h2 = activ(conv(h, 'c2', nf=64, rf=4, stride=2, init_scale=np.sqrt(2))) - h3 = activ(conv(h2, 'c3', nf=64, rf=3, stride=1, init_scale=np.sqrt(2))) + h = activ(conv(scaled_images, 'c1', nf=32, rf=8, stride=4, init_scale=np.sqrt(2), + **conv_kwargs)) + h2 = activ(conv(h, 'c2', nf=64, rf=4, stride=2, init_scale=np.sqrt(2), **conv_kwargs)) + h3 = activ(conv(h2, 'c3', nf=64, rf=3, stride=1, init_scale=np.sqrt(2), **conv_kwargs)) h3 = conv_to_fc(h3) return activ(fc(h3, 'fc1', nh=512, init_scale=np.sqrt(2))) @@ -20,21 +21,18 @@ class LnLstmPolicy(object): nenv = nbatch // nsteps nh, nw, nc = ob_space.shape ob_shape = (nbatch, nh, nw, nc) - nact = ac_space.n X = tf.placeholder(tf.uint8, ob_shape) #obs M = tf.placeholder(tf.float32, [nbatch]) #mask (done t-1) S = tf.placeholder(tf.float32, [nenv, nlstm*2]) #states + self.pdtype = make_pdtype(ac_space) with tf.variable_scope("model", reuse=reuse): h = nature_cnn(X) xs = batch_to_seq(h, nenv, nsteps) ms = batch_to_seq(M, nenv, nsteps) h5, snew = lnlstm(xs, ms, S, 'lstm1', nh=nlstm) h5 = seq_to_batch(h5) - pi = fc(h5, 'pi', nact) vf = fc(h5, 'v', 1) - - self.pdtype = make_pdtype(ac_space) - self.pd = self.pdtype.pdfromflat(pi) + self.pd, self.pi = self.pdtype.pdfromlatent(h5) v0 = vf[:, 0] a0 = self.pd.sample() @@ -50,7 +48,6 @@ class LnLstmPolicy(object): self.X = X self.M = M self.S = S - self.pi = pi self.vf = vf self.step = step self.value = value @@ -62,7 +59,7 @@ class LstmPolicy(object): nh, nw, nc = ob_space.shape ob_shape = (nbatch, nh, nw, nc) - nact = ac_space.n + self.pdtype = make_pdtype(ac_space) X = tf.placeholder(tf.uint8, ob_shape) #obs M = tf.placeholder(tf.float32, [nbatch]) #mask (done t-1) S = tf.placeholder(tf.float32, [nenv, nlstm*2]) #states @@ -72,11 +69,8 @@ class LstmPolicy(object): ms = batch_to_seq(M, nenv, nsteps) h5, snew = lstm(xs, ms, S, 'lstm1', nh=nlstm) h5 = seq_to_batch(h5) - pi = fc(h5, 'pi', nact) vf = fc(h5, 'v', 1) - - self.pdtype = make_pdtype(ac_space) - self.pd = self.pdtype.pdfromflat(pi) + self.pd, self.pi = self.pdtype.pdfromlatent(h5) v0 = vf[:, 0] a0 = self.pd.sample() @@ -92,25 +86,21 @@ class LstmPolicy(object): self.X = X self.M = M self.S = S - self.pi = pi self.vf = vf self.step = step self.value = value class CnnPolicy(object): - def __init__(self, sess, ob_space, ac_space, nbatch, nsteps, reuse=False): #pylint: disable=W0613 + def __init__(self, sess, ob_space, ac_space, nbatch, nsteps, reuse=False, **conv_kwargs): #pylint: disable=W0613 nh, nw, nc = ob_space.shape ob_shape = (nbatch, nh, nw, nc) - nact = ac_space.n + self.pdtype = make_pdtype(ac_space) X = tf.placeholder(tf.uint8, ob_shape) #obs with tf.variable_scope("model", reuse=reuse): - h = nature_cnn(X) - pi = fc(h, 'pi', nact, init_scale=0.01) + h = nature_cnn(X, **conv_kwargs) vf = fc(h, 'v', 1)[:,0] - - self.pdtype = make_pdtype(ac_space) - self.pd = self.pdtype.pdfromflat(pi) + self.pd, self.pi = self.pdtype.pdfromlatent(h, init_scale=0.01) a0 = self.pd.sample() neglogp0 = self.pd.neglogp(a0) @@ -124,7 +114,6 @@ class CnnPolicy(object): return sess.run(vf, {X:ob}) self.X = X - self.pi = pi self.vf = vf self.step = step self.value = value @@ -132,23 +121,19 @@ class CnnPolicy(object): class MlpPolicy(object): def __init__(self, sess, ob_space, ac_space, nbatch, nsteps, reuse=False): #pylint: disable=W0613 ob_shape = (nbatch,) + ob_space.shape - actdim = ac_space.shape[0] + self.pdtype = make_pdtype(ac_space) X = tf.placeholder(tf.float32, ob_shape, name='Ob') #obs with tf.variable_scope("model", reuse=reuse): activ = tf.tanh - h1 = activ(fc(X, 'pi_fc1', nh=64, init_scale=np.sqrt(2))) - h2 = activ(fc(h1, 'pi_fc2', nh=64, init_scale=np.sqrt(2))) - pi = fc(h2, 'pi', actdim, init_scale=0.01) - h1 = activ(fc(X, 'vf_fc1', nh=64, init_scale=np.sqrt(2))) - h2 = activ(fc(h1, 'vf_fc2', nh=64, init_scale=np.sqrt(2))) - vf = fc(h2, 'vf', 1)[:,0] - logstd = tf.get_variable(name="logstd", shape=[1, actdim], - initializer=tf.zeros_initializer()) + flatten = tf.layers.flatten + pi_h1 = activ(fc(flatten(X), 'pi_fc1', nh=64, init_scale=np.sqrt(2))) + pi_h2 = activ(fc(pi_h1, 'pi_fc2', nh=64, init_scale=np.sqrt(2))) + vf_h1 = activ(fc(flatten(X), 'vf_fc1', nh=64, init_scale=np.sqrt(2))) + vf_h2 = activ(fc(vf_h1, 'vf_fc2', nh=64, init_scale=np.sqrt(2))) + vf = fc(vf_h2, 'vf', 1)[:,0] - pdparam = tf.concat([pi, pi * 0.0 + logstd], axis=1) + self.pd, self.pi = self.pdtype.pdfromlatent(pi_h2, init_scale=0.01) - self.pdtype = make_pdtype(ac_space) - self.pd = self.pdtype.pdfromflat(pdparam) a0 = self.pd.sample() neglogp0 = self.pd.neglogp(a0) @@ -162,7 +147,6 @@ class MlpPolicy(object): return sess.run(vf, {X:ob}) self.X = X - self.pi = pi self.vf = vf self.step = step self.value = value diff --git a/baselines/ppo2/ppo2.py b/baselines/ppo2/ppo2.py index f1bbc79..6acdae0 100644 --- a/baselines/ppo2/ppo2.py +++ b/baselines/ppo2/ppo2.py @@ -7,6 +7,7 @@ import tensorflow as tf from baselines import logger from collections import deque from baselines.common import explained_variance +from baselines.common.runners import AbstractEnvRunner class Model(object): def __init__(self, *, policy, ob_space, ac_space, nbatch_act, nbatch_train, @@ -84,19 +85,12 @@ class Model(object): self.load = load tf.global_variables_initializer().run(session=sess) #pylint: disable=E1101 -class Runner(object): +class Runner(AbstractEnvRunner): def __init__(self, *, env, model, nsteps, gamma, lam): - self.env = env - self.model = model - nenv = env.num_envs - self.obs = np.zeros((nenv,) + env.observation_space.shape, dtype=model.train_model.X.dtype.name) - self.obs[:] = env.reset() - self.gamma = gamma + super().__init__(env=env, model=model, nsteps=nsteps) self.lam = lam - self.nsteps = nsteps - self.states = model.initial_state - self.dones = [False for _ in range(nenv)] + self.gamma = gamma def run(self): mb_obs, mb_rewards, mb_actions, mb_values, mb_dones, mb_neglogpacs = [],[],[],[],[],[] @@ -154,7 +148,7 @@ def constfn(val): def learn(*, policy, env, nsteps, total_timesteps, ent_coef, lr, vf_coef=0.5, max_grad_norm=0.5, gamma=0.99, lam=0.95, log_interval=10, nminibatches=4, noptepochs=4, cliprange=0.2, - save_interval=0): + save_interval=0, load_path=None): if isinstance(lr, float): lr = constfn(lr) else: assert callable(lr) @@ -176,6 +170,8 @@ def learn(*, policy, env, nsteps, total_timesteps, ent_coef, lr, with open(osp.join(logger.get_dir(), 'make_model.pkl'), 'wb') as fh: fh.write(cloudpickle.dumps(make_model)) model = make_model() + if load_path is not None: + model.load(load_path) runner = Runner(env=env, model=model, nsteps=nsteps, gamma=gamma, lam=lam) epinfobuf = deque(maxlen=100) diff --git a/baselines/ppo2/run_atari.py b/baselines/ppo2/run_atari.py index 1bfa917..322837a 100644 --- a/baselines/ppo2/run_atari.py +++ b/baselines/ppo2/run_atari.py @@ -4,7 +4,7 @@ from baselines import logger from baselines.common.cmd_util import make_atari_env, atari_arg_parser from baselines.common.vec_env.vec_frame_stack import VecFrameStack from baselines.ppo2 import ppo2 -from baselines.ppo2.policies import CnnPolicy, LstmPolicy, LnLstmPolicy +from baselines.ppo2.policies import CnnPolicy, LstmPolicy, LnLstmPolicy, MlpPolicy import multiprocessing import tensorflow as tf @@ -20,7 +20,7 @@ def train(env_id, num_timesteps, seed, policy): tf.Session(config=config).__enter__() env = VecFrameStack(make_atari_env(env_id, 8, seed), 4) - policy = {'cnn' : CnnPolicy, 'lstm' : LstmPolicy, 'lnlstm' : LnLstmPolicy}[policy] + policy = {'cnn' : CnnPolicy, 'lstm' : LstmPolicy, 'lnlstm' : LnLstmPolicy, 'mlp': MlpPolicy}[policy] ppo2.learn(policy=policy, env=env, nsteps=128, nminibatches=4, lam=0.95, gamma=0.99, noptepochs=4, log_interval=1, ent_coef=.01, @@ -30,7 +30,7 @@ def train(env_id, num_timesteps, seed, policy): def main(): parser = atari_arg_parser() - parser.add_argument('--policy', help='Policy architecture', choices=['cnn', 'lstm', 'lnlstm'], default='cnn') + parser.add_argument('--policy', help='Policy architecture', choices=['cnn', 'lstm', 'lnlstm', 'mlp'], default='cnn') args = parser.parse_args() logger.configure() train(args.env, num_timesteps=args.num_timesteps, seed=args.seed,