refactor ACER (#664)
* make acer use vecframestack
* acer passes mnist test with 20k steps
* acer with non-image observations and tests
* flake8
* test acer serialization with non-recurrent policies
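Note: a minimal usage sketch of the refactored entry point, assuming the usual baselines module layout; the env id, number of envs and timestep budget are illustrative and not part of this change.

import gym
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
from baselines.acer.acer import learn

# ACER no longer takes an nstack argument; learn() wraps any vec env that is not
# already a VecFrameStack in VecFrameStack(env, 1), so non-image observations work directly.
venv = DummyVecEnv([lambda: gym.make('CartPole-v0') for _ in range(4)])
model = learn(network='mlp', env=venv, nsteps=20, total_timesteps=10000)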
@@ -7,6 +7,7 @@ from baselines import logger
 from baselines.common import set_global_seeds
 from baselines.common.policies import build_policy
 from baselines.common.tf_util import get_session, save_variables
+from baselines.common.vec_env.vec_frame_stack import VecFrameStack

 from baselines.a2c.utils import batch_to_seq, seq_to_batch
 from baselines.a2c.utils import cat_entropy_softmax
@@ -55,8 +56,7 @@ def q_retrace(R, D, q_i, v, rho_i, nenvs, nsteps, gamma):
     # return tf.minimum(1 + eps_clip, tf.maximum(1 - eps_clip, ratio))

 class Model(object):
-    def __init__(self, policy, ob_space, ac_space, nenvs, nsteps, nstack, num_procs,
-                 ent_coef, q_coef, gamma, max_grad_norm, lr,
+    def __init__(self, policy, ob_space, ac_space, nenvs, nsteps, ent_coef, q_coef, gamma, max_grad_norm, lr,
                  rprop_alpha, rprop_epsilon, total_timesteps, lrschedule,
                  c, trust_region, alpha, delta):

@@ -71,8 +71,8 @@ class Model(object):
         LR = tf.placeholder(tf.float32, [])
         eps = 1e-6

-        step_ob_placeholder = tf.placeholder(dtype=ob_space.dtype, shape=(nenvs,) + ob_space.shape[:-1] + (ob_space.shape[-1] * nstack,))
-        train_ob_placeholder = tf.placeholder(dtype=ob_space.dtype, shape=(nenvs*(nsteps+1),) + ob_space.shape[:-1] + (ob_space.shape[-1] * nstack,))
+        step_ob_placeholder = tf.placeholder(dtype=ob_space.dtype, shape=(nenvs,) + ob_space.shape)
+        train_ob_placeholder = tf.placeholder(dtype=ob_space.dtype, shape=(nenvs*(nsteps+1),) + ob_space.shape)
         with tf.variable_scope('acer_model', reuse=tf.AUTO_REUSE):

             step_model = policy(observ_placeholder=step_ob_placeholder, sess=sess)
@@ -247,6 +247,7 @@ class Acer():
         # get obs, actions, rewards, mus, dones from buffer.
         obs, actions, rewards, mus, dones, masks = buffer.get()
+

         # reshape stuff correctly
         obs = obs.reshape(runner.batch_ob_shape)
         actions = actions.reshape([runner.nbatch])
@@ -270,7 +271,7 @@ class Acer():
         logger.dump_tabular()


-def learn(network, env, seed=None, nsteps=20, nstack=4, total_timesteps=int(80e6), q_coef=0.5, ent_coef=0.01,
+def learn(network, env, seed=None, nsteps=20, total_timesteps=int(80e6), q_coef=0.5, ent_coef=0.01,
           max_grad_norm=10, lr=7e-4, lrschedule='linear', rprop_epsilon=1e-5, rprop_alpha=0.99, gamma=0.99,
           log_interval=100, buffer_size=50000, replay_ratio=4, replay_start=10000, c=10.0,
           trust_region=True, alpha=0.99, delta=1, load_path=None, **network_kwargs):
@@ -342,21 +343,24 @@ def learn(network, env, seed=None, nsteps=20, nstack=4, total_timesteps=int(80e6
     print("Running Acer Simple")
     print(locals())
     set_global_seeds(seed)
-    policy = build_policy(env, network, estimate_q=True, **network_kwargs)
+    if not isinstance(env, VecFrameStack):
+        env = VecFrameStack(env, 1)

+    policy = build_policy(env, network, estimate_q=True, **network_kwargs)
     nenvs = env.num_envs
     ob_space = env.observation_space
     ac_space = env.action_space
-    num_procs = len(env.remotes) if hasattr(env, 'remotes') else 1# HACK
-    model = Model(policy=policy, ob_space=ob_space, ac_space=ac_space, nenvs=nenvs, nsteps=nsteps, nstack=nstack,
-                  num_procs=num_procs, ent_coef=ent_coef, q_coef=q_coef, gamma=gamma,
+
+    nstack = env.nstack
+    model = Model(policy=policy, ob_space=ob_space, ac_space=ac_space, nenvs=nenvs, nsteps=nsteps,
+                  ent_coef=ent_coef, q_coef=q_coef, gamma=gamma,
                   max_grad_norm=max_grad_norm, lr=lr, rprop_alpha=rprop_alpha, rprop_epsilon=rprop_epsilon,
                   total_timesteps=total_timesteps, lrschedule=lrschedule, c=c,
                   trust_region=trust_region, alpha=alpha, delta=delta)

-    runner = Runner(env=env, model=model, nsteps=nsteps, nstack=nstack)
+    runner = Runner(env=env, model=model, nsteps=nsteps)
     if replay_ratio > 0:
-        buffer = Buffer(env=env, nsteps=nsteps, nstack=nstack, size=buffer_size)
+        buffer = Buffer(env=env, nsteps=nsteps, size=buffer_size)
     else:
         buffer = None
     nbatch = nenvs*nsteps
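Note: with this change the frame-stack count is read from the env wrapper rather than passed to learn(). A hedged sketch of the two call patterns, using the standard baselines helpers make_vec_env and VecFrameStack; the env ids, counts and the 'classic_control' env type below are illustrative assumptions, not part of the diff.

from baselines.common.cmd_util import make_vec_env
from baselines.common.vec_env.vec_frame_stack import VecFrameStack

# Atari: the caller builds the 4-frame stack; learn() leaves an existing
# VecFrameStack untouched and picks up nstack from env.nstack.
atari_env = VecFrameStack(make_vec_env('PongNoFrameskip-v4', 'atari', 4, 0), 4)

# Non-image env: an unstacked vec env is accepted as-is and gets wrapped in
# VecFrameStack(env, 1) inside learn().
classic_env = make_vec_env('CartPole-v0', 'classic_control', 4, 0)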
@@ -2,11 +2,16 @@ import numpy as np

 class Buffer(object):
     # gets obs, actions, rewards, mu's, (states, masks), dones
-    def __init__(self, env, nsteps, nstack, size=50000):
+    def __init__(self, env, nsteps, size=50000):
         self.nenv = env.num_envs
         self.nsteps = nsteps
-        self.nh, self.nw, self.nc = env.observation_space.shape
-        self.nstack = nstack
+        # self.nh, self.nw, self.nc = env.observation_space.shape
+        self.obs_shape = env.observation_space.shape
+        self.obs_dtype = env.observation_space.dtype
+        self.ac_dtype = env.action_space.dtype
+        self.nc = self.obs_shape[-1]
+        self.nstack = env.nstack
+        self.nc //= self.nstack
         self.nbatch = self.nenv * self.nsteps
         self.size = size // (self.nsteps) # Each loc contains nenv * nsteps frames, thus total buffer is nenv * size frames
@@ -33,22 +38,11 @@ class Buffer(object):
     # Generate stacked frames
     def decode(self, enc_obs, dones):
         # enc_obs has shape [nenvs, nsteps + nstack, nh, nw, nc]
-        # dones has shape [nenvs, nsteps, nh, nw, nc]
+        # dones has shape [nenvs, nsteps]
        # returns stacked obs of shape [nenv, (nsteps + 1), nh, nw, nstack*nc]
-        nstack, nenv, nsteps, nh, nw, nc = self.nstack, self.nenv, self.nsteps, self.nh, self.nw, self.nc
-        y = np.empty([nsteps + nstack - 1, nenv, 1, 1, 1], dtype=np.float32)
-        obs = np.zeros([nstack, nsteps + nstack, nenv, nh, nw, nc], dtype=np.uint8)
-        x = np.reshape(enc_obs, [nenv, nsteps + nstack, nh, nw, nc]).swapaxes(1,
-                                                                              0) # [nsteps + nstack, nenv, nh, nw, nc]
-        y[3:] = np.reshape(1.0 - dones, [nenv, nsteps, 1, 1, 1]).swapaxes(1, 0) # keep
-        y[:3] = 1.0
-        # y = np.reshape(1 - dones, [nenvs, nsteps, 1, 1, 1])
-        for i in range(nstack):
-            obs[-(i + 1), i:] = x
-            # obs[:,i:,:,:,-(i+1),:] = x
-            x = x[:-1] * y
-            y = y[1:]
-        return np.reshape(obs[:, 3:].transpose((2, 1, 3, 4, 0, 5)), [nenv, (nsteps + 1), nh, nw, nstack * nc])
+        return _stack_obs(enc_obs, dones,
+                          nsteps=self.nsteps)

     def put(self, enc_obs, actions, rewards, mus, dones, masks):
         # enc_obs [nenv, (nsteps + nstack), nh, nw, nc]
@@ -56,8 +50,8 @@ class Buffer(object):
         # mus [nenv, nsteps, nact]

         if self.enc_obs is None:
-            self.enc_obs = np.empty([self.size] + list(enc_obs.shape), dtype=np.uint8)
-            self.actions = np.empty([self.size] + list(actions.shape), dtype=np.int32)
+            self.enc_obs = np.empty([self.size] + list(enc_obs.shape), dtype=self.obs_dtype)
+            self.actions = np.empty([self.size] + list(actions.shape), dtype=self.ac_dtype)
             self.rewards = np.empty([self.size] + list(rewards.shape), dtype=np.float32)
             self.mus = np.empty([self.size] + list(mus.shape), dtype=np.float32)
             self.dones = np.empty([self.size] + list(dones.shape), dtype=np.bool)
@@ -101,3 +95,62 @@ class Buffer(object):
         mus = take(self.mus)
         masks = take(self.masks)
         return obs, actions, rewards, mus, dones, masks
+
+
+def _stack_obs_ref(enc_obs, dones, nsteps):
+    nenv = enc_obs.shape[0]
+    nstack = enc_obs.shape[1] - nsteps
+    nh, nw, nc = enc_obs.shape[2:]
+    obs_dtype = enc_obs.dtype
+    obs_shape = (nh, nw, nc*nstack)
+
+    mask = np.empty([nsteps + nstack - 1, nenv, 1, 1, 1], dtype=np.float32)
+    obs = np.zeros([nstack, nsteps + nstack, nenv, nh, nw, nc], dtype=obs_dtype)
+    x = np.reshape(enc_obs, [nenv, nsteps + nstack, nh, nw, nc]).swapaxes(1, 0)  # [nsteps + nstack, nenv, nh, nw, nc]
+
+    mask[nstack-1:] = np.reshape(1.0 - dones, [nenv, nsteps, 1, 1, 1]).swapaxes(1, 0)  # keep
+    mask[:nstack-1] = 1.0
+
+    # y = np.reshape(1 - dones, [nenvs, nsteps, 1, 1, 1])
+    for i in range(nstack):
+        obs[-(i + 1), i:] = x
+        # obs[:,i:,:,:,-(i+1),:] = x
+        x = x[:-1] * mask
+        mask = mask[1:]
+
+    return np.reshape(obs[:, (nstack-1):].transpose((2, 1, 3, 4, 0, 5)), (nenv, (nsteps + 1)) + obs_shape)
+
+
+def _stack_obs(enc_obs, dones, nsteps):
+    nenv = enc_obs.shape[0]
+    nstack = enc_obs.shape[1] - nsteps
+    nc = enc_obs.shape[-1]
+
+    obs_ = np.zeros((nenv, nsteps + 1) + enc_obs.shape[2:-1] + (enc_obs.shape[-1] * nstack, ), dtype=enc_obs.dtype)
+    mask = np.ones((nenv, nsteps+1), dtype=enc_obs.dtype)
+    mask[:, 1:] = 1.0 - dones
+    mask = mask.reshape(mask.shape + tuple(np.ones(len(enc_obs.shape)-2, dtype=np.uint8)))
+
+    for i in range(nstack-1, -1, -1):
+        obs_[..., i * nc : (i + 1) * nc] = enc_obs[:, i : i + nsteps + 1, :]
+        if i < nstack-1:
+            obs_[..., i * nc : (i + 1) * nc] *= mask
+            mask[:, 1:, ...] *= mask[:, :-1, ...]
+
+    return obs_
+
+
+def test_stack_obs():
+    nstack = 7
+    nenv = 1
+    nsteps = 5
+
+    obs_shape = (2, 3, nstack)
+
+    enc_obs_shape = (nenv, nsteps + nstack) + obs_shape[:-1] + (1,)
+    enc_obs = np.random.random(enc_obs_shape)
+    dones = np.random.randint(low=0, high=2, size=(nenv, nsteps))
+
+    stacked_obs_ref = _stack_obs_ref(enc_obs, dones, nsteps=nsteps)
+    stacked_obs_test = _stack_obs(enc_obs, dones, nsteps=nsteps)
+
+    np.testing.assert_allclose(stacked_obs_ref, stacked_obs_test)
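Note: _stack_obs replaces the Atari-specific decode above and only manipulates the last (channel) axis, which is what lets it handle non-image observation shapes. A quick shape check under assumed toy sizes, assuming the helper lives in baselines/acer/buffer.py as in the upstream layout; nothing below is part of the diff.

import numpy as np
from baselines.acer.buffer import _stack_obs

nenv, nsteps, nstack, nc = 2, 3, 4, 1
enc_obs = np.random.random((nenv, nsteps + nstack, 84, 84, nc))  # one unstacked frame per step
dones = np.zeros((nenv, nsteps))

stacked = _stack_obs(enc_obs, dones, nsteps=nsteps)
# channel block i of output step t holds frame t + i, so the last block is the newest frame
assert stacked.shape == (nenv, nsteps + 1, 84, 84, nstack * nc)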
@@ -1,30 +1,31 @@
 import numpy as np
 from baselines.common.runners import AbstractEnvRunner
+from baselines.common.vec_env.vec_frame_stack import VecFrameStack
+from gym import spaces


 class Runner(AbstractEnvRunner):

-    def __init__(self, env, model, nsteps, nstack):
+    def __init__(self, env, model, nsteps):
         super().__init__(env=env, model=model, nsteps=nsteps)
-        self.nstack = nstack
-        nh, nw, nc = env.observation_space.shape
-        self.nc = nc # nc = 1 for atari, but just in case
+        assert isinstance(env.action_space, spaces.Discrete), 'This ACER implementation works only with discrete action spaces!'
+        assert isinstance(env, VecFrameStack)
+
         self.nact = env.action_space.n
         nenv = self.nenv
         self.nbatch = nenv * nsteps
-        self.batch_ob_shape = (nenv*(nsteps+1), nh, nw, nc*nstack)
-        self.obs = np.zeros((nenv, nh, nw, nc * nstack), dtype=np.uint8)
-        obs = env.reset()
-        self.update_obs(obs)
+        self.batch_ob_shape = (nenv*(nsteps+1),) + env.observation_space.shape

-    def update_obs(self, obs, dones=None):
-        #self.obs = obs
-        if dones is not None:
-            self.obs *= (1 - dones.astype(np.uint8))[:, None, None, None]
-        self.obs = np.roll(self.obs, shift=-self.nc, axis=3)
-        self.obs[:, :, :, -self.nc:] = obs[:, :, :, :]
+        self.obs = env.reset()
+        self.obs_dtype = env.observation_space.dtype
+        self.ac_dtype = env.action_space.dtype
+        self.nstack = self.env.nstack
+        self.nc = self.batch_ob_shape[-1] // self.nstack

     def run(self):
-        enc_obs = np.split(self.obs, self.nstack, axis=3) # so now list of obs steps
+        # enc_obs = np.split(self.obs, self.nstack, axis=3) # so now list of obs steps
+        enc_obs = np.split(self.env.stackedobs, self.env.nstack, axis=-1)
         mb_obs, mb_actions, mb_mus, mb_dones, mb_rewards = [], [], [], [], []
         for _ in range(self.nsteps):
             actions, mus, states = self.model._step(self.obs, S=self.states, M=self.dones)
@@ -36,15 +37,15 @@ class Runner(AbstractEnvRunner):
             # states information for statefull models like LSTM
             self.states = states
             self.dones = dones
-            self.update_obs(obs, dones)
+            self.obs = obs
             mb_rewards.append(rewards)
-            enc_obs.append(obs)
+            enc_obs.append(obs[..., -self.nc:])
         mb_obs.append(np.copy(self.obs))
         mb_dones.append(self.dones)

-        enc_obs = np.asarray(enc_obs, dtype=np.uint8).swapaxes(1, 0)
-        mb_obs = np.asarray(mb_obs, dtype=np.uint8).swapaxes(1, 0)
-        mb_actions = np.asarray(mb_actions, dtype=np.int32).swapaxes(1, 0)
+        enc_obs = np.asarray(enc_obs, dtype=self.obs_dtype).swapaxes(1, 0)
+        mb_obs = np.asarray(mb_obs, dtype=self.obs_dtype).swapaxes(1, 0)
+        mb_actions = np.asarray(mb_actions, dtype=self.ac_dtype).swapaxes(1, 0)
         mb_rewards = np.asarray(mb_rewards, dtype=np.float32).swapaxes(1, 0)
         mb_mus = np.asarray(mb_mus, dtype=np.float32).swapaxes(1, 0)

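Note: the runner now lets VecFrameStack maintain the rolling frame stack instead of doing it by hand in update_obs. A small sketch of the wrapper attributes the new code relies on (stackedobs and nstack); the env id and stack depth here are illustrative assumptions.

import numpy as np
import gym
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
from baselines.common.vec_env.vec_frame_stack import VecFrameStack

venv = VecFrameStack(DummyVecEnv([lambda: gym.make('CartPole-v0')]), 4)
obs = venv.reset()
# stackedobs holds the last nstack frames concatenated along the last axis;
# splitting it recovers the individual frames, oldest first, each with nc channels
frames = np.split(venv.stackedobs, venv.nstack, axis=-1)
assert len(frames) == 4 and frames[0].shape == obs.shape[:-1] + (obs.shape[-1] // 4,)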
@@ -13,6 +13,7 @@ common_kwargs = dict(

 learn_kwargs = {
     'a2c' : dict(nsteps=32, value_network='copy', lr=0.05),
+    'acer': dict(value_network='copy'),
     'acktr': dict(nsteps=32, value_network='copy', is_async=False),
     'deepq': dict(total_timesteps=20000),
     'ppo2': dict(value_network='copy'),
@@ -40,4 +41,4 @@ def test_cartpole(alg):
     reward_per_episode_test(env_fn, learn_fn, 100)

 if __name__ == '__main__':
-    test_cartpole('deepq')
+    test_cartpole('acer')
@@ -20,8 +20,8 @@ learn_kwargs = {
 }


-algos_disc = ['a2c', 'deepq', 'ppo2', 'trpo_mpi']
-algos_cont = ['a2c', 'ddpg', 'ppo2', 'trpo_mpi']
+algos_disc = ['a2c', 'acktr', 'deepq', 'ppo2', 'trpo_mpi']
+algos_cont = ['a2c', 'acktr', 'ddpg', 'ppo2', 'trpo_mpi']

 @pytest.mark.slow
 @pytest.mark.parametrize("alg", algos_disc)
@@ -17,8 +17,7 @@ common_kwargs = {

 learn_args = {
     'a2c': dict(total_timesteps=50000),
-    # TODO need to resolve inference (step) API differences for acer; also slow
-    # 'acer': dict(seed=0, total_timesteps=1000),
+    'acer': dict(total_timesteps=20000),
     'deepq': dict(total_timesteps=5000),
     'acktr': dict(total_timesteps=30000),
     'ppo2': dict(total_timesteps=50000, lr=1e-3, nsteps=128, ent_coef=0.0),
@@ -47,4 +46,4 @@ def test_mnist(alg):
     simple_test(env_fn, learn_fn, 0.6)

 if __name__ == '__main__':
-    test_mnist('deepq')
+    test_mnist('acer')
@@ -17,6 +17,7 @@ learn_kwargs = {
     'deepq': {},
     'a2c': {},
     'acktr': {},
+    'acer': {},
     'ppo2': {'nminibatches': 1, 'nsteps': 10},
     'trpo_mpi': {},
 }
@@ -37,7 +38,7 @@ def test_serialization(learn_fn, network_fn):
     '''


-    if network_fn.endswith('lstm') and learn_fn in ['acktr', 'trpo_mpi', 'deepq']:
+    if network_fn.endswith('lstm') and learn_fn in ['acer', 'acktr', 'trpo_mpi', 'deepq']:
         # TODO make acktr work with recurrent policies
         # and test
         # github issue: https://github.com/openai/baselines/issues/660
@@ -91,9 +91,7 @@ def build_env(args):
     env_type, env_id = get_env_type(args.env)

     if env_type in {'atari', 'retro'}:
-        if alg == 'acer':
-            env = make_vec_env(env_id, env_type, nenv, seed)
-        elif alg == 'deepq':
+        if alg == 'deepq':
             env = make_env(env_id, env_type, seed=seed, wrapper_kwargs={'frame_stack': True})
         elif alg == 'trpo_mpi':
             env = make_env(env_id, env_type, seed=seed)