import internal repo (#385)

This commit is contained in:
pzhokhov
2018-05-01 16:54:04 -07:00
committed by GitHub
parent 2b0283b9db
commit 69f25c6028
21 changed files with 167 additions and 168 deletions

View File

@@ -1,16 +1,12 @@
import os
import os.path as osp import os.path as osp
import gym
import time import time
import joblib import joblib
import logging
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from baselines import logger from baselines import logger
from baselines.common import set_global_seeds, explained_variance from baselines.common import set_global_seeds, explained_variance
from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv from baselines.common.runners import AbstractEnvRunner
from baselines.common.atari_wrappers import wrap_deepmind
from baselines.common import tf_util from baselines.common import tf_util
from baselines.a2c.utils import discount_with_dones from baselines.a2c.utils import discount_with_dones
@@ -24,7 +20,6 @@ class Model(object):
alpha=0.99, epsilon=1e-5, total_timesteps=int(80e6), lrschedule='linear'): alpha=0.99, epsilon=1e-5, total_timesteps=int(80e6), lrschedule='linear'):
sess = tf_util.make_session() sess = tf_util.make_session()
nact = ac_space.n
nbatch = nenvs*nsteps nbatch = nenvs*nsteps
A = tf.placeholder(tf.int32, [nbatch]) A = tf.placeholder(tf.int32, [nbatch])
@@ -75,7 +70,7 @@ class Model(object):
restores = [] restores = []
for p, loaded_p in zip(params, loaded_params): for p, loaded_p in zip(params, loaded_params):
restores.append(p.assign(loaded_p)) restores.append(p.assign(loaded_p))
ps = sess.run(restores) sess.run(restores)
self.train = train self.train = train
self.train_model = train_model self.train_model = train_model
@@ -87,21 +82,11 @@ class Model(object):
self.load = load self.load = load
tf.global_variables_initializer().run(session=sess) tf.global_variables_initializer().run(session=sess)
class Runner(object): class Runner(AbstractEnvRunner):
def __init__(self, env, model, nsteps=5, gamma=0.99): def __init__(self, env, model, nsteps=5, gamma=0.99):
self.env = env super().__init__(env=env, model=model, nsteps=nsteps)
self.model = model
nh, nw, nc = env.observation_space.shape
nenv = env.num_envs
self.batch_ob_shape = (nenv*nsteps, nh, nw, nc)
self.obs = np.zeros((nenv, nh, nw, nc), dtype=np.uint8)
self.nc = nc
obs = env.reset()
self.gamma = gamma self.gamma = gamma
self.nsteps = nsteps
self.states = model.initial_state
self.dones = [False for _ in range(nenv)]
def run(self): def run(self):
mb_obs, mb_rewards, mb_actions, mb_values, mb_dones = [],[],[],[],[] mb_obs, mb_rewards, mb_actions, mb_values, mb_dones = [],[],[],[],[]

View File

@@ -3,15 +3,16 @@ import tensorflow as tf
from baselines.a2c.utils import conv, fc, conv_to_fc, batch_to_seq, seq_to_batch, lstm, lnlstm from baselines.a2c.utils import conv, fc, conv_to_fc, batch_to_seq, seq_to_batch, lstm, lnlstm
from baselines.common.distributions import make_pdtype from baselines.common.distributions import make_pdtype
def nature_cnn(unscaled_images): def nature_cnn(unscaled_images, **conv_kwargs):
""" """
CNN from Nature paper. CNN from Nature paper.
""" """
scaled_images = tf.cast(unscaled_images, tf.float32) / 255. scaled_images = tf.cast(unscaled_images, tf.float32) / 255.
activ = tf.nn.relu activ = tf.nn.relu
h = activ(conv(scaled_images, 'c1', nf=32, rf=8, stride=4, init_scale=np.sqrt(2))) h = activ(conv(scaled_images, 'c1', nf=32, rf=8, stride=4, init_scale=np.sqrt(2),
h2 = activ(conv(h, 'c2', nf=64, rf=4, stride=2, init_scale=np.sqrt(2))) **conv_kwargs))
h3 = activ(conv(h2, 'c3', nf=64, rf=3, stride=1, init_scale=np.sqrt(2))) h2 = activ(conv(h, 'c2', nf=64, rf=4, stride=2, init_scale=np.sqrt(2), **conv_kwargs))
h3 = activ(conv(h2, 'c3', nf=64, rf=3, stride=1, init_scale=np.sqrt(2), **conv_kwargs))
h3 = conv_to_fc(h3) h3 = conv_to_fc(h3)
return activ(fc(h3, 'fc1', nh=512, init_scale=np.sqrt(2))) return activ(fc(h3, 'fc1', nh=512, init_scale=np.sqrt(2)))
@@ -20,21 +21,18 @@ class LnLstmPolicy(object):
nenv = nbatch // nsteps nenv = nbatch // nsteps
nh, nw, nc = ob_space.shape nh, nw, nc = ob_space.shape
ob_shape = (nbatch, nh, nw, nc) ob_shape = (nbatch, nh, nw, nc)
nact = ac_space.n
X = tf.placeholder(tf.uint8, ob_shape) #obs X = tf.placeholder(tf.uint8, ob_shape) #obs
M = tf.placeholder(tf.float32, [nbatch]) #mask (done t-1) M = tf.placeholder(tf.float32, [nbatch]) #mask (done t-1)
S = tf.placeholder(tf.float32, [nenv, nlstm*2]) #states S = tf.placeholder(tf.float32, [nenv, nlstm*2]) #states
self.pdtype = make_pdtype(ac_space)
with tf.variable_scope("model", reuse=reuse): with tf.variable_scope("model", reuse=reuse):
h = nature_cnn(X) h = nature_cnn(X)
xs = batch_to_seq(h, nenv, nsteps) xs = batch_to_seq(h, nenv, nsteps)
ms = batch_to_seq(M, nenv, nsteps) ms = batch_to_seq(M, nenv, nsteps)
h5, snew = lnlstm(xs, ms, S, 'lstm1', nh=nlstm) h5, snew = lnlstm(xs, ms, S, 'lstm1', nh=nlstm)
h5 = seq_to_batch(h5) h5 = seq_to_batch(h5)
pi = fc(h5, 'pi', nact)
vf = fc(h5, 'v', 1) vf = fc(h5, 'v', 1)
self.pd, self.pi = self.pdtype.pdfromlatent(h5)
self.pdtype = make_pdtype(ac_space)
self.pd = self.pdtype.pdfromflat(pi)
v0 = vf[:, 0] v0 = vf[:, 0]
a0 = self.pd.sample() a0 = self.pd.sample()
@@ -50,7 +48,6 @@ class LnLstmPolicy(object):
self.X = X self.X = X
self.M = M self.M = M
self.S = S self.S = S
self.pi = pi
self.vf = vf self.vf = vf
self.step = step self.step = step
self.value = value self.value = value
@@ -62,7 +59,7 @@ class LstmPolicy(object):
nh, nw, nc = ob_space.shape nh, nw, nc = ob_space.shape
ob_shape = (nbatch, nh, nw, nc) ob_shape = (nbatch, nh, nw, nc)
nact = ac_space.n self.pdtype = make_pdtype(ac_space)
X = tf.placeholder(tf.uint8, ob_shape) #obs X = tf.placeholder(tf.uint8, ob_shape) #obs
M = tf.placeholder(tf.float32, [nbatch]) #mask (done t-1) M = tf.placeholder(tf.float32, [nbatch]) #mask (done t-1)
S = tf.placeholder(tf.float32, [nenv, nlstm*2]) #states S = tf.placeholder(tf.float32, [nenv, nlstm*2]) #states
@@ -72,11 +69,8 @@ class LstmPolicy(object):
ms = batch_to_seq(M, nenv, nsteps) ms = batch_to_seq(M, nenv, nsteps)
h5, snew = lstm(xs, ms, S, 'lstm1', nh=nlstm) h5, snew = lstm(xs, ms, S, 'lstm1', nh=nlstm)
h5 = seq_to_batch(h5) h5 = seq_to_batch(h5)
pi = fc(h5, 'pi', nact)
vf = fc(h5, 'v', 1) vf = fc(h5, 'v', 1)
self.pd, self.pi = self.pdtype.pdfromlatent(h5)
self.pdtype = make_pdtype(ac_space)
self.pd = self.pdtype.pdfromflat(pi)
v0 = vf[:, 0] v0 = vf[:, 0]
a0 = self.pd.sample() a0 = self.pd.sample()
@@ -92,25 +86,21 @@ class LstmPolicy(object):
self.X = X self.X = X
self.M = M self.M = M
self.S = S self.S = S
self.pi = pi
self.vf = vf self.vf = vf
self.step = step self.step = step
self.value = value self.value = value
class CnnPolicy(object): class CnnPolicy(object):
def __init__(self, sess, ob_space, ac_space, nbatch, nsteps, reuse=False): #pylint: disable=W0613 def __init__(self, sess, ob_space, ac_space, nbatch, nsteps, reuse=False, **conv_kwargs): #pylint: disable=W0613
nh, nw, nc = ob_space.shape nh, nw, nc = ob_space.shape
ob_shape = (nbatch, nh, nw, nc) ob_shape = (nbatch, nh, nw, nc)
nact = ac_space.n self.pdtype = make_pdtype(ac_space)
X = tf.placeholder(tf.uint8, ob_shape) #obs X = tf.placeholder(tf.uint8, ob_shape) #obs
with tf.variable_scope("model", reuse=reuse): with tf.variable_scope("model", reuse=reuse):
h = nature_cnn(X) h = nature_cnn(X, **conv_kwargs)
pi = fc(h, 'pi', nact, init_scale=0.01)
vf = fc(h, 'v', 1)[:,0] vf = fc(h, 'v', 1)[:,0]
self.pd, self.pi = self.pdtype.pdfromlatent(h, init_scale=0.01)
self.pdtype = make_pdtype(ac_space)
self.pd = self.pdtype.pdfromflat(pi)
a0 = self.pd.sample() a0 = self.pd.sample()
neglogp0 = self.pd.neglogp(a0) neglogp0 = self.pd.neglogp(a0)
@@ -124,7 +114,6 @@ class CnnPolicy(object):
return sess.run(vf, {X:ob}) return sess.run(vf, {X:ob})
self.X = X self.X = X
self.pi = pi
self.vf = vf self.vf = vf
self.step = step self.step = step
self.value = value self.value = value
@@ -132,23 +121,19 @@ class CnnPolicy(object):
class MlpPolicy(object): class MlpPolicy(object):
def __init__(self, sess, ob_space, ac_space, nbatch, nsteps, reuse=False): #pylint: disable=W0613 def __init__(self, sess, ob_space, ac_space, nbatch, nsteps, reuse=False): #pylint: disable=W0613
ob_shape = (nbatch,) + ob_space.shape ob_shape = (nbatch,) + ob_space.shape
actdim = ac_space.shape[0] self.pdtype = make_pdtype(ac_space)
X = tf.placeholder(tf.float32, ob_shape, name='Ob') #obs X = tf.placeholder(tf.float32, ob_shape, name='Ob') #obs
with tf.variable_scope("model", reuse=reuse): with tf.variable_scope("model", reuse=reuse):
activ = tf.tanh activ = tf.tanh
h1 = activ(fc(X, 'pi_fc1', nh=64, init_scale=np.sqrt(2))) flatten = tf.layers.flatten
h2 = activ(fc(h1, 'pi_fc2', nh=64, init_scale=np.sqrt(2))) pi_h1 = activ(fc(flatten(X), 'pi_fc1', nh=64, init_scale=np.sqrt(2)))
pi = fc(h2, 'pi', actdim, init_scale=0.01) pi_h2 = activ(fc(pi_h1, 'pi_fc2', nh=64, init_scale=np.sqrt(2)))
h1 = activ(fc(X, 'vf_fc1', nh=64, init_scale=np.sqrt(2))) vf_h1 = activ(fc(flatten(X), 'vf_fc1', nh=64, init_scale=np.sqrt(2)))
h2 = activ(fc(h1, 'vf_fc2', nh=64, init_scale=np.sqrt(2))) vf_h2 = activ(fc(vf_h1, 'vf_fc2', nh=64, init_scale=np.sqrt(2)))
vf = fc(h2, 'vf', 1)[:,0] vf = fc(vf_h2, 'vf', 1)[:,0]
logstd = tf.get_variable(name="logstd", shape=[1, actdim],
initializer=tf.zeros_initializer())
pdparam = tf.concat([pi, pi * 0.0 + logstd], axis=1) self.pd, self.pi = self.pdtype.pdfromlatent(pi_h2, init_scale=0.01)
self.pdtype = make_pdtype(ac_space)
self.pd = self.pdtype.pdfromflat(pdparam)
a0 = self.pd.sample() a0 = self.pd.sample()
neglogp0 = self.pd.neglogp(a0) neglogp0 = self.pd.neglogp(a0)
@@ -162,7 +147,6 @@ class MlpPolicy(object):
return sess.run(vf, {X:ob}) return sess.run(vf, {X:ob})
self.X = X self.X = X
self.pi = pi
self.vf = vf self.vf = vf
self.step = step self.step = step
self.value = value self.value = value

View File

@@ -39,7 +39,7 @@ def ortho_init(scale=1.0):
return (scale * q[:shape[0], :shape[1]]).astype(np.float32) return (scale * q[:shape[0], :shape[1]]).astype(np.float32)
return _ortho_init return _ortho_init
def conv(x, scope, *, nf, rf, stride, pad='VALID', init_scale=1.0, data_format='NHWC'): def conv(x, scope, *, nf, rf, stride, pad='VALID', init_scale=1.0, data_format='NHWC', one_dim_bias=False):
if data_format == 'NHWC': if data_format == 'NHWC':
channel_ax = 3 channel_ax = 3
strides = [1, stride, stride, 1] strides = [1, stride, stride, 1]
@@ -50,12 +50,14 @@ def conv(x, scope, *, nf, rf, stride, pad='VALID', init_scale=1.0, data_format='
bshape = [1, nf, 1, 1] bshape = [1, nf, 1, 1]
else: else:
raise NotImplementedError raise NotImplementedError
bias_var_shape = [nf] if one_dim_bias else [1, nf, 1, 1]
nin = x.get_shape()[channel_ax].value nin = x.get_shape()[channel_ax].value
wshape = [rf, rf, nin, nf] wshape = [rf, rf, nin, nf]
with tf.variable_scope(scope): with tf.variable_scope(scope):
w = tf.get_variable("w", wshape, initializer=ortho_init(init_scale)) w = tf.get_variable("w", wshape, initializer=ortho_init(init_scale))
b = tf.get_variable("b", [1, nf, 1, 1], initializer=tf.constant_initializer(0.0)) b = tf.get_variable("b", bias_var_shape, initializer=tf.constant_initializer(0.0))
if data_format == 'NHWC': b = tf.reshape(b, bshape) if not one_dim_bias and data_format == 'NHWC':
b = tf.reshape(b, bshape)
return b + tf.nn.conv2d(x, w, strides=strides, padding=pad, data_format=data_format) return b + tf.nn.conv2d(x, w, strides=strides, padding=pad, data_format=data_format)
def fc(x, scope, nh, *, init_scale=1.0, init_bias=0.0): def fc(x, scope, nh, *, init_scale=1.0, init_bias=0.0):

View File

@@ -5,6 +5,7 @@ import tensorflow as tf
from baselines import logger from baselines import logger
from baselines.common import set_global_seeds from baselines.common import set_global_seeds
from baselines.common.runners import AbstractEnvRunner
from baselines.a2c.utils import batch_to_seq, seq_to_batch from baselines.a2c.utils import batch_to_seq, seq_to_batch
from baselines.a2c.utils import Scheduler, make_path, find_trainable_variables from baselines.a2c.utils import Scheduler, make_path, find_trainable_variables
@@ -13,6 +14,8 @@ from baselines.a2c.utils import EpisodeStats
from baselines.a2c.utils import get_by_index, check_shape, avg_norm, gradient_add, q_explained_variance from baselines.a2c.utils import get_by_index, check_shape, avg_norm, gradient_add, q_explained_variance
from baselines.acer.buffer import Buffer from baselines.acer.buffer import Buffer
import os.path as osp
# remove last step # remove last step
def strip(var, nenvs, nsteps, flat = False): def strip(var, nenvs, nsteps, flat = False):
vars = batch_to_seq(var, nenvs, nsteps + 1, flat) vars = batch_to_seq(var, nenvs, nsteps + 1, flat)
@@ -209,11 +212,10 @@ class Model(object):
self.initial_state = step_model.initial_state self.initial_state = step_model.initial_state
tf.global_variables_initializer().run(session=sess) tf.global_variables_initializer().run(session=sess)
class Runner(object): class Runner(AbstractEnvRunner):
def __init__(self, env, model, nsteps, nstack): def __init__(self, env, model, nsteps, nstack):
self.env = env super().__init__(env=env, model=model, nsteps=nsteps)
self.nstack = nstack self.nstack = nstack
self.model = model
nh, nw, nc = env.observation_space.shape nh, nw, nc = env.observation_space.shape
self.nc = nc # nc = 1 for atari, but just in case self.nc = nc # nc = 1 for atari, but just in case
self.nenv = nenv = env.num_envs self.nenv = nenv = env.num_envs
@@ -223,9 +225,6 @@ class Runner(object):
self.obs = np.zeros((nenv, nh, nw, nc * nstack), dtype=np.uint8) self.obs = np.zeros((nenv, nh, nw, nc * nstack), dtype=np.uint8)
obs = env.reset() obs = env.reset()
self.update_obs(obs) self.update_obs(obs)
self.nsteps = nsteps
self.states = model.initial_state
self.dones = [False for _ in range(nenv)]
def update_obs(self, obs, dones=None): def update_obs(self, obs, dones=None):
if dones is not None: if dones is not None:

View File

@@ -1,5 +1,7 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
from functools import partial
from baselines import logger from baselines import logger
from baselines.acktr.acktr_disc import learn from baselines.acktr.acktr_disc import learn
from baselines.common.cmd_util import make_atari_env, atari_arg_parser from baselines.common.cmd_util import make_atari_env, atari_arg_parser
@@ -8,7 +10,7 @@ from baselines.ppo2.policies import CnnPolicy
def train(env_id, num_timesteps, seed, num_cpu): def train(env_id, num_timesteps, seed, num_cpu):
env = VecFrameStack(make_atari_env(env_id, num_cpu, seed), 4) env = VecFrameStack(make_atari_env(env_id, num_cpu, seed), 4)
policy_fn = CnnPolicy policy_fn = partial(CnnPolicy, one_dim_bias=True)
learn(policy_fn, env, seed, total_timesteps=int(num_timesteps * 1.1), nprocs=num_cpu) learn(policy_fn, env, seed, total_timesteps=int(num_timesteps * 1.1), nprocs=num_cpu)
env.close() env.close()

View File

@@ -1,3 +1,4 @@
# flake8: noqa F403
from baselines.common.console_util import * from baselines.common.console_util import *
from baselines.common.dataset import Dataset from baselines.common.dataset import Dataset
from baselines.common.math_util import * from baselines.common.math_util import *

View File

@@ -98,9 +98,6 @@ class MaxAndSkipEnv(gym.Wrapper):
self._obs_buffer = np.zeros((2,)+env.observation_space.shape, dtype=np.uint8) self._obs_buffer = np.zeros((2,)+env.observation_space.shape, dtype=np.uint8)
self._skip = skip self._skip = skip
def reset(self):
return self.env.reset()
def step(self, action): def step(self, action):
"""Repeat action, sum reward, and max over last observations.""" """Repeat action, sum reward, and max over last observations."""
total_reward = 0.0 total_reward = 0.0

View File

@@ -10,7 +10,6 @@ from baselines.bench import Monitor
from baselines.common import set_global_seeds from baselines.common import set_global_seeds
from baselines.common.atari_wrappers import make_atari, wrap_deepmind from baselines.common.atari_wrappers import make_atari, wrap_deepmind
from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
from mpi4py import MPI
def make_atari_env(env_id, num_env, seed, wrapper_kwargs=None, start_index=0): def make_atari_env(env_id, num_env, seed, wrapper_kwargs=None, start_index=0):
""" """

View File

@@ -1,6 +1,7 @@
import tensorflow as tf import tensorflow as tf
import numpy as np import numpy as np
import baselines.common.tf_util as U import baselines.common.tf_util as U
from baselines.a2c.utils import fc
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
class Pd(object): class Pd(object):
@@ -31,6 +32,8 @@ class PdType(object):
raise NotImplementedError raise NotImplementedError
def pdfromflat(self, flat): def pdfromflat(self, flat):
return self.pdclass()(flat) return self.pdclass()(flat)
def pdfromlatent(self, latent_vector):
raise NotImplementedError
def param_shape(self): def param_shape(self):
raise NotImplementedError raise NotImplementedError
def sample_shape(self): def sample_shape(self):
@@ -48,6 +51,10 @@ class CategoricalPdType(PdType):
self.ncat = ncat self.ncat = ncat
def pdclass(self): def pdclass(self):
return CategoricalPd return CategoricalPd
def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0):
pdparam = fc(latent_vector, 'pi', self.ncat, init_scale=init_scale, init_bias=init_bias)
return self.pdfromflat(pdparam), pdparam
def param_shape(self): def param_shape(self):
return [self.ncat] return [self.ncat]
def sample_shape(self): def sample_shape(self):
@@ -75,6 +82,13 @@ class DiagGaussianPdType(PdType):
self.size = size self.size = size
def pdclass(self): def pdclass(self):
return DiagGaussianPd return DiagGaussianPd
def pdfromlatent(self, latent_vector, init_scale=1.0, init_bias=0.0):
mean = fc(latent_vector, 'pi', self.size, init_scale=init_scale, init_bias=init_bias)
logstd = tf.get_variable(name='logstd', shape=[1, self.size], initializer=tf.zeros_initializer())
pdparam = tf.concat([mean, mean * 0.0 + logstd], axis=1)
return self.pdfromflat(pdparam), mean
def param_shape(self): def param_shape(self):
return [2*self.size] return [2*self.size]
def sample_shape(self): def sample_shape(self):

View File

@@ -2,6 +2,7 @@ from mpi4py import MPI
import numpy as np import numpy as np
from baselines.common import zipsame from baselines.common import zipsame
def mpi_mean(x, axis=0, comm=None, keepdims=False): def mpi_mean(x, axis=0, comm=None, keepdims=False):
x = np.asarray(x) x = np.asarray(x)
assert x.ndim > 0 assert x.ndim > 0

View File

@@ -12,10 +12,9 @@ class SegmentTree(object):
a) setting item's value is slightly slower. a) setting item's value is slightly slower.
It is O(lg capacity) instead of O(1). It is O(lg capacity) instead of O(1).
b) user has access to an efficient `reduce` b) user has access to an efficient ( O(log segment size) )
operation which reduces `operation` over `reduce` operation which reduces `operation` over
a contiguous subsequence of items in the a contiguous subsequence of items in the array.
array.
Paramters Paramters
--------- ---------
@@ -23,8 +22,8 @@ class SegmentTree(object):
Total size of the array - must be a power of two. Total size of the array - must be a power of two.
operation: lambda obj, obj -> obj operation: lambda obj, obj -> obj
and operation for combining elements (eg. sum, max) and operation for combining elements (eg. sum, max)
must for a mathematical group together with the set of must form a mathematical group together with the set of
possible values for array elements. possible values for array elements (i.e. be associative)
neutral_element: obj neutral_element: obj
neutral element for the operation above. eg. float('-inf') neutral element for the operation above. eg. float('-inf')
for max and 0 for sum. for max and 0 for sum.

View File

@@ -33,7 +33,6 @@ def test_multikwargs():
initialize() initialize()
assert lin(2) == 6 assert lin(2) == 6
assert lin(2, 2) == 10 assert lin(2, 2) == 10
expt_caught = False
if __name__ == '__main__': if __name__ == '__main__':

View File

@@ -48,7 +48,7 @@ def huber_loss(x, delta=1.0):
# Global session # Global session
# ================================================================ # ================================================================
def make_session(num_cpu=None, make_default=False): def make_session(num_cpu=None, make_default=False, graph=None):
"""Returns a session that will use <num_cpu> CPU's only""" """Returns a session that will use <num_cpu> CPU's only"""
if num_cpu is None: if num_cpu is None:
num_cpu = int(os.getenv('RCALL_NUM_CPU', multiprocessing.cpu_count())) num_cpu = int(os.getenv('RCALL_NUM_CPU', multiprocessing.cpu_count()))
@@ -57,9 +57,9 @@ def make_session(num_cpu=None, make_default=False):
intra_op_parallelism_threads=num_cpu) intra_op_parallelism_threads=num_cpu)
tf_config.gpu_options.allocator_type = 'BFC' tf_config.gpu_options.allocator_type = 'BFC'
if make_default: if make_default:
return tf.InteractiveSession(config=tf_config) return tf.InteractiveSession(config=tf_config, graph=graph)
else: else:
return tf.Session(config=tf_config) return tf.Session(config=tf_config, graph=graph)
def single_threaded_session(): def single_threaded_session():
"""Returns a session which will only use a single CPU""" """Returns a session which will only use a single CPU"""
@@ -84,10 +84,10 @@ def initialize():
# Model components # Model components
# ================================================================ # ================================================================
def normc_initializer(std=1.0): def normc_initializer(std=1.0, axis=0):
def _initializer(shape, dtype=None, partition_info=None): # pylint: disable=W0613 def _initializer(shape, dtype=None, partition_info=None): # pylint: disable=W0613
out = np.random.randn(*shape).astype(np.float32) out = np.random.randn(*shape).astype(np.float32)
out *= std / np.sqrt(np.square(out).sum(axis=0, keepdims=True)) out *= std / np.sqrt(np.square(out).sum(axis=axis, keepdims=True))
return tf.constant(out) return tf.constant(out)
return _initializer return _initializer
@@ -273,8 +273,9 @@ def display_var_info(vars):
for v in vars: for v in vars:
name = v.name name = v.name
if "/Adam" in name or "beta1_power" in name or "beta2_power" in name: continue if "/Adam" in name or "beta1_power" in name or "beta2_power" in name: continue
count_params += np.prod(v.shape.as_list()) v_params = np.prod(v.shape.as_list())
if "/b:" in name: continue # Wx+b, bias is not interesting to look at => count params, but not print count_params += v_params
logger.info(" %s%s%s" % (name, " "*(55-len(name)), str(v.shape))) if "/b:" in name or "/biases" in name: continue # Wx+b, bias is not interesting to look at => count params, but not print
logger.info("Total model parameters: %0.1f million" % (count_params*1e-6)) logger.info(" %s%s %i params %s" % (name, " "*(55-len(name)), v_params, str(v.shape)))
logger.info("Total model parameters: %0.2f million" % (count_params*1e-6))

View File

@@ -80,6 +80,13 @@ class VecEnv(ABC):
def render(self): def render(self):
logger.warn('Render not defined for %s'%self) logger.warn('Render not defined for %s'%self)
@property
def unwrapped(self):
if isinstance(self, VecEnvWrapper):
return self.venv.unwrapped
else:
return self
class VecEnvWrapper(VecEnv): class VecEnvWrapper(VecEnv):
def __init__(self, venv, observation_space=None, action_space=None): def __init__(self, venv, observation_space=None, action_space=None):
self.venv = venv self.venv = venv

View File

@@ -1,5 +1,6 @@
import numpy as np import numpy as np
import gym from gym import spaces
from collections import OrderedDict
from . import VecEnv from . import VecEnv
class DummyVecEnv(VecEnv): class DummyVecEnv(VecEnv):
@@ -7,9 +8,22 @@ class DummyVecEnv(VecEnv):
self.envs = [fn() for fn in env_fns] self.envs = [fn() for fn in env_fns]
env = self.envs[0] env = self.envs[0]
VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space) VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space)
shapes, dtypes = {}, {}
obs_spaces = self.observation_space.spaces if isinstance(self.observation_space, gym.spaces.Tuple) else (self.observation_space,) self.keys = []
self.buf_obs = [np.zeros((self.num_envs,) + tuple(s.shape), s.dtype) for s in obs_spaces] obs_space = env.observation_space
if isinstance(obs_space, spaces.Dict):
assert isinstance(obs_space.spaces, OrderedDict)
for key, box in obs_space.spaces.items():
assert isinstance(box, spaces.Box)
shapes[key] = box.shape
dtypes[key] = box.dtype
self.keys.append(key)
else:
box = obs_space
assert isinstance(box, spaces.Box)
self.keys = [None]
shapes, dtypes = { None: box.shape }, { None: box.dtype }
self.buf_obs = { k: np.zeros((self.num_envs,) + tuple(shapes[k]), dtype=dtypes[k]) for k in self.keys }
self.buf_dones = np.zeros((self.num_envs,), dtype=np.bool) self.buf_dones = np.zeros((self.num_envs,), dtype=np.bool)
self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32) self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)
self.buf_infos = [{} for _ in range(self.num_envs)] self.buf_infos = [{} for _ in range(self.num_envs)]
@@ -19,33 +33,32 @@ class DummyVecEnv(VecEnv):
self.actions = actions self.actions = actions
def step_wait(self): def step_wait(self):
for i in range(self.num_envs): for e in range(self.num_envs):
obs_tuple, self.buf_rews[i], self.buf_dones[i], self.buf_infos[i] = self.envs[i].step(self.actions[i]) obs, self.buf_rews[e], self.buf_dones[e], self.buf_infos[e] = self.envs[e].step(self.actions[e])
if self.buf_dones[i]: if self.buf_dones[e]:
obs_tuple = self.envs[i].reset() obs = self.envs[e].reset()
if isinstance(obs_tuple, (tuple, list)): self._save_obs(e, obs)
for t,x in enumerate(obs_tuple):
self.buf_obs[t][i] = x
else:
self.buf_obs[0][i] = obs_tuple
return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones), return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones),
self.buf_infos.copy()) self.buf_infos.copy())
def reset(self): def reset(self):
for i in range(self.num_envs): for e in range(self.num_envs):
obs_tuple = self.envs[i].reset() obs = self.envs[e].reset()
if isinstance(obs_tuple, (tuple, list)): self._save_obs(e, obs)
for t,x in enumerate(obs_tuple):
self.buf_obs[t][i] = x
else:
self.buf_obs[0][i] = obs_tuple
return self._obs_from_buf() return self._obs_from_buf()
def close(self): def close(self):
return return
def _obs_from_buf(self): def _save_obs(self, e, obs):
if len(self.buf_obs) == 1: for k in self.keys:
return np.copy(self.buf_obs[0]) if k is None:
self.buf_obs[k][e] = obs
else: else:
return tuple(np.copy(x) for x in self.buf_obs) self.buf_obs[k][e] = obs[k]
def _obs_from_buf(self):
if self.keys==[None]:
return self.buf_obs[None]
else:
return self.buf_obs

View File

@@ -5,11 +5,13 @@ import argparse
from baselines import logger from baselines import logger
from baselines.common.atari_wrappers import make_atari from baselines.common.atari_wrappers import make_atari
def main(): def main():
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--env', help='environment ID', default='BreakoutNoFrameskip-v4') parser.add_argument('--env', help='environment ID', default='BreakoutNoFrameskip-v4')
parser.add_argument('--seed', help='RNG seed', type=int, default=0) parser.add_argument('--seed', help='RNG seed', type=int, default=0)
parser.add_argument('--prioritized', type=int, default=1) parser.add_argument('--prioritized', type=int, default=1)
parser.add_argument('--prioritized-replay-alpha', type=float, default=0.6)
parser.add_argument('--dueling', type=int, default=1) parser.add_argument('--dueling', type=int, default=1)
parser.add_argument('--num-timesteps', type=int, default=int(10e6)) parser.add_argument('--num-timesteps', type=int, default=int(10e6))
args = parser.parse_args() args = parser.parse_args()
@@ -23,7 +25,8 @@ def main():
hiddens=[256], hiddens=[256],
dueling=bool(args.dueling), dueling=bool(args.dueling),
) )
act = deepq.learn(
deepq.learn(
env, env,
q_func=model, q_func=model,
lr=1e-4, lr=1e-4,
@@ -35,9 +38,10 @@ def main():
learning_starts=10000, learning_starts=10000,
target_network_update_freq=1000, target_network_update_freq=1000,
gamma=0.99, gamma=0.99,
prioritized_replay=bool(args.prioritized) prioritized_replay=bool(args.prioritized),
prioritized_replay_alpha=args.prioritized_replay_alpha
) )
# act.save("pong_model.pkl") XXX
env.close() env.close()

View File

@@ -86,7 +86,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
ReplayBuffer.__init__ ReplayBuffer.__init__
""" """
super(PrioritizedReplayBuffer, self).__init__(size) super(PrioritizedReplayBuffer, self).__init__(size)
assert alpha > 0 assert alpha >= 0
self._alpha = alpha self._alpha = alpha
it_capacity = 1 it_capacity = 1

View File

@@ -9,6 +9,7 @@ import tempfile
from collections import defaultdict from collections import defaultdict
LOG_OUTPUT_FORMATS = ['stdout', 'log', 'csv'] LOG_OUTPUT_FORMATS = ['stdout', 'log', 'csv']
LOG_OUTPUT_FORMATS_MPI = ['log']
# Also valid: json, tensorboard # Also valid: json, tensorboard
DEBUG = 10 DEBUG = 10
@@ -355,10 +356,21 @@ def configure(dir=None, format_strs=None):
assert isinstance(dir, str) assert isinstance(dir, str)
os.makedirs(dir, exist_ok=True) os.makedirs(dir, exist_ok=True)
log_suffix = ''
from mpi4py import MPI
rank = MPI.COMM_WORLD.Get_rank()
if rank > 0:
log_suffix = "-rank%03i" % rank
if format_strs is None: if format_strs is None:
strs = os.getenv('OPENAI_LOG_FORMAT') strs, strs_mpi = os.getenv('OPENAI_LOG_FORMAT'), os.getenv('OPENAI_LOG_FORMAT_MPI')
format_strs = strs.split(',') if strs else LOG_OUTPUT_FORMATS format_strs = strs_mpi if rank>0 else strs
output_formats = [make_output_format(f, dir) for f in format_strs] if format_strs is not None:
format_strs = format_strs.split(',')
else:
format_strs = LOG_OUTPUT_FORMATS_MPI if rank>0 else LOG_OUTPUT_FORMATS
output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs]
Logger.CURRENT = Logger(dir=dir, output_formats=output_formats) Logger.CURRENT = Logger(dir=dir, output_formats=output_formats)
log('Logging to %s'%dir) log('Logging to %s'%dir)

View File

@@ -3,15 +3,16 @@ import tensorflow as tf
from baselines.a2c.utils import conv, fc, conv_to_fc, batch_to_seq, seq_to_batch, lstm, lnlstm from baselines.a2c.utils import conv, fc, conv_to_fc, batch_to_seq, seq_to_batch, lstm, lnlstm
from baselines.common.distributions import make_pdtype from baselines.common.distributions import make_pdtype
def nature_cnn(unscaled_images): def nature_cnn(unscaled_images, **conv_kwargs):
""" """
CNN from Nature paper. CNN from Nature paper.
""" """
scaled_images = tf.cast(unscaled_images, tf.float32) / 255. scaled_images = tf.cast(unscaled_images, tf.float32) / 255.
activ = tf.nn.relu activ = tf.nn.relu
h = activ(conv(scaled_images, 'c1', nf=32, rf=8, stride=4, init_scale=np.sqrt(2))) h = activ(conv(scaled_images, 'c1', nf=32, rf=8, stride=4, init_scale=np.sqrt(2),
h2 = activ(conv(h, 'c2', nf=64, rf=4, stride=2, init_scale=np.sqrt(2))) **conv_kwargs))
h3 = activ(conv(h2, 'c3', nf=64, rf=3, stride=1, init_scale=np.sqrt(2))) h2 = activ(conv(h, 'c2', nf=64, rf=4, stride=2, init_scale=np.sqrt(2), **conv_kwargs))
h3 = activ(conv(h2, 'c3', nf=64, rf=3, stride=1, init_scale=np.sqrt(2), **conv_kwargs))
h3 = conv_to_fc(h3) h3 = conv_to_fc(h3)
return activ(fc(h3, 'fc1', nh=512, init_scale=np.sqrt(2))) return activ(fc(h3, 'fc1', nh=512, init_scale=np.sqrt(2)))
@@ -20,21 +21,18 @@ class LnLstmPolicy(object):
nenv = nbatch // nsteps nenv = nbatch // nsteps
nh, nw, nc = ob_space.shape nh, nw, nc = ob_space.shape
ob_shape = (nbatch, nh, nw, nc) ob_shape = (nbatch, nh, nw, nc)
nact = ac_space.n
X = tf.placeholder(tf.uint8, ob_shape) #obs X = tf.placeholder(tf.uint8, ob_shape) #obs
M = tf.placeholder(tf.float32, [nbatch]) #mask (done t-1) M = tf.placeholder(tf.float32, [nbatch]) #mask (done t-1)
S = tf.placeholder(tf.float32, [nenv, nlstm*2]) #states S = tf.placeholder(tf.float32, [nenv, nlstm*2]) #states
self.pdtype = make_pdtype(ac_space)
with tf.variable_scope("model", reuse=reuse): with tf.variable_scope("model", reuse=reuse):
h = nature_cnn(X) h = nature_cnn(X)
xs = batch_to_seq(h, nenv, nsteps) xs = batch_to_seq(h, nenv, nsteps)
ms = batch_to_seq(M, nenv, nsteps) ms = batch_to_seq(M, nenv, nsteps)
h5, snew = lnlstm(xs, ms, S, 'lstm1', nh=nlstm) h5, snew = lnlstm(xs, ms, S, 'lstm1', nh=nlstm)
h5 = seq_to_batch(h5) h5 = seq_to_batch(h5)
pi = fc(h5, 'pi', nact)
vf = fc(h5, 'v', 1) vf = fc(h5, 'v', 1)
self.pd, self.pi = self.pdtype.pdfromlatent(h5)
self.pdtype = make_pdtype(ac_space)
self.pd = self.pdtype.pdfromflat(pi)
v0 = vf[:, 0] v0 = vf[:, 0]
a0 = self.pd.sample() a0 = self.pd.sample()
@@ -50,7 +48,6 @@ class LnLstmPolicy(object):
self.X = X self.X = X
self.M = M self.M = M
self.S = S self.S = S
self.pi = pi
self.vf = vf self.vf = vf
self.step = step self.step = step
self.value = value self.value = value
@@ -62,7 +59,7 @@ class LstmPolicy(object):
nh, nw, nc = ob_space.shape nh, nw, nc = ob_space.shape
ob_shape = (nbatch, nh, nw, nc) ob_shape = (nbatch, nh, nw, nc)
nact = ac_space.n self.pdtype = make_pdtype(ac_space)
X = tf.placeholder(tf.uint8, ob_shape) #obs X = tf.placeholder(tf.uint8, ob_shape) #obs
M = tf.placeholder(tf.float32, [nbatch]) #mask (done t-1) M = tf.placeholder(tf.float32, [nbatch]) #mask (done t-1)
S = tf.placeholder(tf.float32, [nenv, nlstm*2]) #states S = tf.placeholder(tf.float32, [nenv, nlstm*2]) #states
@@ -72,11 +69,8 @@ class LstmPolicy(object):
ms = batch_to_seq(M, nenv, nsteps) ms = batch_to_seq(M, nenv, nsteps)
h5, snew = lstm(xs, ms, S, 'lstm1', nh=nlstm) h5, snew = lstm(xs, ms, S, 'lstm1', nh=nlstm)
h5 = seq_to_batch(h5) h5 = seq_to_batch(h5)
pi = fc(h5, 'pi', nact)
vf = fc(h5, 'v', 1) vf = fc(h5, 'v', 1)
self.pd, self.pi = self.pdtype.pdfromlatent(h5)
self.pdtype = make_pdtype(ac_space)
self.pd = self.pdtype.pdfromflat(pi)
v0 = vf[:, 0] v0 = vf[:, 0]
a0 = self.pd.sample() a0 = self.pd.sample()
@@ -92,25 +86,21 @@ class LstmPolicy(object):
self.X = X self.X = X
self.M = M self.M = M
self.S = S self.S = S
self.pi = pi
self.vf = vf self.vf = vf
self.step = step self.step = step
self.value = value self.value = value
class CnnPolicy(object): class CnnPolicy(object):
def __init__(self, sess, ob_space, ac_space, nbatch, nsteps, reuse=False): #pylint: disable=W0613 def __init__(self, sess, ob_space, ac_space, nbatch, nsteps, reuse=False, **conv_kwargs): #pylint: disable=W0613
nh, nw, nc = ob_space.shape nh, nw, nc = ob_space.shape
ob_shape = (nbatch, nh, nw, nc) ob_shape = (nbatch, nh, nw, nc)
nact = ac_space.n self.pdtype = make_pdtype(ac_space)
X = tf.placeholder(tf.uint8, ob_shape) #obs X = tf.placeholder(tf.uint8, ob_shape) #obs
with tf.variable_scope("model", reuse=reuse): with tf.variable_scope("model", reuse=reuse):
h = nature_cnn(X) h = nature_cnn(X, **conv_kwargs)
pi = fc(h, 'pi', nact, init_scale=0.01)
vf = fc(h, 'v', 1)[:,0] vf = fc(h, 'v', 1)[:,0]
self.pd, self.pi = self.pdtype.pdfromlatent(h, init_scale=0.01)
self.pdtype = make_pdtype(ac_space)
self.pd = self.pdtype.pdfromflat(pi)
a0 = self.pd.sample() a0 = self.pd.sample()
neglogp0 = self.pd.neglogp(a0) neglogp0 = self.pd.neglogp(a0)
@@ -124,7 +114,6 @@ class CnnPolicy(object):
return sess.run(vf, {X:ob}) return sess.run(vf, {X:ob})
self.X = X self.X = X
self.pi = pi
self.vf = vf self.vf = vf
self.step = step self.step = step
self.value = value self.value = value
@@ -132,23 +121,19 @@ class CnnPolicy(object):
class MlpPolicy(object): class MlpPolicy(object):
def __init__(self, sess, ob_space, ac_space, nbatch, nsteps, reuse=False): #pylint: disable=W0613 def __init__(self, sess, ob_space, ac_space, nbatch, nsteps, reuse=False): #pylint: disable=W0613
ob_shape = (nbatch,) + ob_space.shape ob_shape = (nbatch,) + ob_space.shape
actdim = ac_space.shape[0] self.pdtype = make_pdtype(ac_space)
X = tf.placeholder(tf.float32, ob_shape, name='Ob') #obs X = tf.placeholder(tf.float32, ob_shape, name='Ob') #obs
with tf.variable_scope("model", reuse=reuse): with tf.variable_scope("model", reuse=reuse):
activ = tf.tanh activ = tf.tanh
h1 = activ(fc(X, 'pi_fc1', nh=64, init_scale=np.sqrt(2))) flatten = tf.layers.flatten
h2 = activ(fc(h1, 'pi_fc2', nh=64, init_scale=np.sqrt(2))) pi_h1 = activ(fc(flatten(X), 'pi_fc1', nh=64, init_scale=np.sqrt(2)))
pi = fc(h2, 'pi', actdim, init_scale=0.01) pi_h2 = activ(fc(pi_h1, 'pi_fc2', nh=64, init_scale=np.sqrt(2)))
h1 = activ(fc(X, 'vf_fc1', nh=64, init_scale=np.sqrt(2))) vf_h1 = activ(fc(flatten(X), 'vf_fc1', nh=64, init_scale=np.sqrt(2)))
h2 = activ(fc(h1, 'vf_fc2', nh=64, init_scale=np.sqrt(2))) vf_h2 = activ(fc(vf_h1, 'vf_fc2', nh=64, init_scale=np.sqrt(2)))
vf = fc(h2, 'vf', 1)[:,0] vf = fc(vf_h2, 'vf', 1)[:,0]
logstd = tf.get_variable(name="logstd", shape=[1, actdim],
initializer=tf.zeros_initializer())
pdparam = tf.concat([pi, pi * 0.0 + logstd], axis=1) self.pd, self.pi = self.pdtype.pdfromlatent(pi_h2, init_scale=0.01)
self.pdtype = make_pdtype(ac_space)
self.pd = self.pdtype.pdfromflat(pdparam)
a0 = self.pd.sample() a0 = self.pd.sample()
neglogp0 = self.pd.neglogp(a0) neglogp0 = self.pd.neglogp(a0)
@@ -162,7 +147,6 @@ class MlpPolicy(object):
return sess.run(vf, {X:ob}) return sess.run(vf, {X:ob})
self.X = X self.X = X
self.pi = pi
self.vf = vf self.vf = vf
self.step = step self.step = step
self.value = value self.value = value

View File

@@ -7,6 +7,7 @@ import tensorflow as tf
from baselines import logger from baselines import logger
from collections import deque from collections import deque
from baselines.common import explained_variance from baselines.common import explained_variance
from baselines.common.runners import AbstractEnvRunner
class Model(object): class Model(object):
def __init__(self, *, policy, ob_space, ac_space, nbatch_act, nbatch_train, def __init__(self, *, policy, ob_space, ac_space, nbatch_act, nbatch_train,
@@ -84,19 +85,12 @@ class Model(object):
self.load = load self.load = load
tf.global_variables_initializer().run(session=sess) #pylint: disable=E1101 tf.global_variables_initializer().run(session=sess) #pylint: disable=E1101
class Runner(object): class Runner(AbstractEnvRunner):
def __init__(self, *, env, model, nsteps, gamma, lam): def __init__(self, *, env, model, nsteps, gamma, lam):
self.env = env super().__init__(env=env, model=model, nsteps=nsteps)
self.model = model
nenv = env.num_envs
self.obs = np.zeros((nenv,) + env.observation_space.shape, dtype=model.train_model.X.dtype.name)
self.obs[:] = env.reset()
self.gamma = gamma
self.lam = lam self.lam = lam
self.nsteps = nsteps self.gamma = gamma
self.states = model.initial_state
self.dones = [False for _ in range(nenv)]
def run(self): def run(self):
mb_obs, mb_rewards, mb_actions, mb_values, mb_dones, mb_neglogpacs = [],[],[],[],[],[] mb_obs, mb_rewards, mb_actions, mb_values, mb_dones, mb_neglogpacs = [],[],[],[],[],[]
@@ -154,7 +148,7 @@ def constfn(val):
def learn(*, policy, env, nsteps, total_timesteps, ent_coef, lr, def learn(*, policy, env, nsteps, total_timesteps, ent_coef, lr,
vf_coef=0.5, max_grad_norm=0.5, gamma=0.99, lam=0.95, vf_coef=0.5, max_grad_norm=0.5, gamma=0.99, lam=0.95,
log_interval=10, nminibatches=4, noptepochs=4, cliprange=0.2, log_interval=10, nminibatches=4, noptepochs=4, cliprange=0.2,
save_interval=0): save_interval=0, load_path=None):
if isinstance(lr, float): lr = constfn(lr) if isinstance(lr, float): lr = constfn(lr)
else: assert callable(lr) else: assert callable(lr)
@@ -176,6 +170,8 @@ def learn(*, policy, env, nsteps, total_timesteps, ent_coef, lr,
with open(osp.join(logger.get_dir(), 'make_model.pkl'), 'wb') as fh: with open(osp.join(logger.get_dir(), 'make_model.pkl'), 'wb') as fh:
fh.write(cloudpickle.dumps(make_model)) fh.write(cloudpickle.dumps(make_model))
model = make_model() model = make_model()
if load_path is not None:
model.load(load_path)
runner = Runner(env=env, model=model, nsteps=nsteps, gamma=gamma, lam=lam) runner = Runner(env=env, model=model, nsteps=nsteps, gamma=gamma, lam=lam)
epinfobuf = deque(maxlen=100) epinfobuf = deque(maxlen=100)

View File

@@ -4,7 +4,7 @@ from baselines import logger
from baselines.common.cmd_util import make_atari_env, atari_arg_parser from baselines.common.cmd_util import make_atari_env, atari_arg_parser
from baselines.common.vec_env.vec_frame_stack import VecFrameStack from baselines.common.vec_env.vec_frame_stack import VecFrameStack
from baselines.ppo2 import ppo2 from baselines.ppo2 import ppo2
from baselines.ppo2.policies import CnnPolicy, LstmPolicy, LnLstmPolicy from baselines.ppo2.policies import CnnPolicy, LstmPolicy, LnLstmPolicy, MlpPolicy
import multiprocessing import multiprocessing
import tensorflow as tf import tensorflow as tf
@@ -20,7 +20,7 @@ def train(env_id, num_timesteps, seed, policy):
tf.Session(config=config).__enter__() tf.Session(config=config).__enter__()
env = VecFrameStack(make_atari_env(env_id, 8, seed), 4) env = VecFrameStack(make_atari_env(env_id, 8, seed), 4)
policy = {'cnn' : CnnPolicy, 'lstm' : LstmPolicy, 'lnlstm' : LnLstmPolicy}[policy] policy = {'cnn' : CnnPolicy, 'lstm' : LstmPolicy, 'lnlstm' : LnLstmPolicy, 'mlp': MlpPolicy}[policy]
ppo2.learn(policy=policy, env=env, nsteps=128, nminibatches=4, ppo2.learn(policy=policy, env=env, nsteps=128, nminibatches=4,
lam=0.95, gamma=0.99, noptepochs=4, log_interval=1, lam=0.95, gamma=0.99, noptepochs=4, log_interval=1,
ent_coef=.01, ent_coef=.01,
@@ -30,7 +30,7 @@ def train(env_id, num_timesteps, seed, policy):
def main(): def main():
parser = atari_arg_parser() parser = atari_arg_parser()
parser.add_argument('--policy', help='Policy architecture', choices=['cnn', 'lstm', 'lnlstm'], default='cnn') parser.add_argument('--policy', help='Policy architecture', choices=['cnn', 'lstm', 'lnlstm', 'mlp'], default='cnn')
args = parser.parse_args() args = parser.parse_args()
logger.configure() logger.configure()
train(args.env, num_timesteps=args.num_timesteps, seed=args.seed, train(args.env, num_timesteps=args.num_timesteps, seed=args.seed,