using registry of algorithms

Peter Zhokhov
2018-10-22 17:01:49 -07:00
parent 01884bb0eb
commit 0c9b236475
16 changed files with 187 additions and 106 deletions

View File

@@ -2,13 +2,12 @@ import time
 import functools
 import tensorflow as tf
-from baselines import logger
+from baselines import logger, registry
 from baselines.common import set_global_seeds, explained_variance
 from baselines.common import tf_util
 from baselines.common.policies import build_policy
 from baselines.a2c.utils import Scheduler, find_trainable_variables
 from baselines.a2c.runner import Runner
@@ -114,6 +113,7 @@ class Model(object):
         tf.global_variables_initializer().run(session=sess)
 
+@registry.register('a2c')
 def learn(
     network,
     env,

View File

@@ -2,11 +2,12 @@ import time
 import functools
 import numpy as np
 import tensorflow as tf
-from baselines import logger
+from baselines import logger, registry
 from baselines.common import set_global_seeds
 from baselines.common.policies import build_policy
 from baselines.common.tf_util import get_session, save_variables
+from baselines.common.vec_env.vec_frame_stack import VecFrameStack
 from baselines.a2c.utils import batch_to_seq, seq_to_batch
 from baselines.a2c.utils import cat_entropy_softmax
@@ -55,8 +56,7 @@ def q_retrace(R, D, q_i, v, rho_i, nenvs, nsteps, gamma):
 # return tf.minimum(1 + eps_clip, tf.maximum(1 - eps_clip, ratio))
 
 class Model(object):
-    def __init__(self, policy, ob_space, ac_space, nenvs, nsteps, nstack, num_procs,
-                 ent_coef, q_coef, gamma, max_grad_norm, lr,
+    def __init__(self, policy, ob_space, ac_space, nenvs, nsteps, ent_coef, q_coef, gamma, max_grad_norm, lr,
                  rprop_alpha, rprop_epsilon, total_timesteps, lrschedule,
                  c, trust_region, alpha, delta):
@@ -71,8 +71,8 @@ class Model(object):
         LR = tf.placeholder(tf.float32, [])
         eps = 1e-6
 
-        step_ob_placeholder = tf.placeholder(dtype=ob_space.dtype, shape=(nenvs,) + ob_space.shape[:-1] + (ob_space.shape[-1] * nstack,))
-        train_ob_placeholder = tf.placeholder(dtype=ob_space.dtype, shape=(nenvs*(nsteps+1),) + ob_space.shape[:-1] + (ob_space.shape[-1] * nstack,))
+        step_ob_placeholder = tf.placeholder(dtype=ob_space.dtype, shape=(nenvs,) + ob_space.shape)
+        train_ob_placeholder = tf.placeholder(dtype=ob_space.dtype, shape=(nenvs*(nsteps+1),) + ob_space.shape)
         with tf.variable_scope('acer_model', reuse=tf.AUTO_REUSE):
             step_model = policy(observ_placeholder=step_ob_placeholder, sess=sess)
@@ -247,6 +247,7 @@ class Acer():
         # get obs, actions, rewards, mus, dones from buffer.
         obs, actions, rewards, mus, dones, masks = buffer.get()
         # reshape stuff correctly
         obs = obs.reshape(runner.batch_ob_shape)
         actions = actions.reshape([runner.nbatch])
@@ -269,8 +270,8 @@ class Acer():
                 logger.record_tabular(name, float(val))
             logger.dump_tabular()
 
-def learn(network, env, seed=None, nsteps=20, nstack=4, total_timesteps=int(80e6), q_coef=0.5, ent_coef=0.01,
+@registry.register('acer')
+def learn(network, env, seed=None, nsteps=20, total_timesteps=int(80e6), q_coef=0.5, ent_coef=0.01,
           max_grad_norm=10, lr=7e-4, lrschedule='linear', rprop_epsilon=1e-5, rprop_alpha=0.99, gamma=0.99,
           log_interval=100, buffer_size=50000, replay_ratio=4, replay_start=10000, c=10.0,
           trust_region=True, alpha=0.99, delta=1, load_path=None, **network_kwargs):
@@ -342,21 +343,24 @@ def learn(network, env, seed=None, nsteps=20, nstack=4, total_timesteps=int(80e6
     print("Running Acer Simple")
     print(locals())
     set_global_seeds(seed)
-    policy = build_policy(env, network, estimate_q=True, **network_kwargs)
+    if not isinstance(env, VecFrameStack):
+        env = VecFrameStack(env, 1)
+
+    policy = build_policy(env, network, estimate_q=True, **network_kwargs)
     nenvs = env.num_envs
     ob_space = env.observation_space
     ac_space = env.action_space
-    num_procs = len(env.remotes) if hasattr(env, 'remotes') else 1# HACK
-    model = Model(policy=policy, ob_space=ob_space, ac_space=ac_space, nenvs=nenvs, nsteps=nsteps, nstack=nstack,
-                  num_procs=num_procs, ent_coef=ent_coef, q_coef=q_coef, gamma=gamma,
+
+    nstack = env.nstack
+    model = Model(policy=policy, ob_space=ob_space, ac_space=ac_space, nenvs=nenvs, nsteps=nsteps,
+                  ent_coef=ent_coef, q_coef=q_coef, gamma=gamma,
                   max_grad_norm=max_grad_norm, lr=lr, rprop_alpha=rprop_alpha, rprop_epsilon=rprop_epsilon,
                   total_timesteps=total_timesteps, lrschedule=lrschedule, c=c,
                   trust_region=trust_region, alpha=alpha, delta=delta)
 
-    runner = Runner(env=env, model=model, nsteps=nsteps, nstack=nstack)
+    runner = Runner(env=env, model=model, nsteps=nsteps)
     if replay_ratio > 0:
-        buffer = Buffer(env=env, nsteps=nsteps, nstack=nstack, size=buffer_size)
+        buffer = Buffer(env=env, nsteps=nsteps, size=buffer_size)
     else:
         buffer = None
     nbatch = nenvs*nsteps
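
Not part of the diff: since frame-stack handling moved from learn() into the environment wrapper, a brief usage sketch may help. It assumes gym and this baselines revision are installed and mirrors the CartPole configuration the updated tests use (value_network='copy' passed through **network_kwargs); learn() wraps a bare vec-env with VecFrameStack(env, 1) on its own, so the explicit wrapper below is only needed when more than one stacked frame is wanted.

import gym

from baselines.acer import acer
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
from baselines.common.vec_env.vec_frame_stack import VecFrameStack

# Vectorize a single CartPole env and stack 4 frames; the ACER Runner now
# requires a VecFrameStack wrapper and a discrete action space.
venv = VecFrameStack(DummyVecEnv([lambda: gym.make('CartPole-v0')]), 4)
model = acer.learn(network='mlp', env=venv, total_timesteps=2000, value_network='copy')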

View File

@@ -2,11 +2,16 @@ import numpy as np
 class Buffer(object):
     # gets obs, actions, rewards, mu's, (states, masks), dones
-    def __init__(self, env, nsteps, nstack, size=50000):
+    def __init__(self, env, nsteps, size=50000):
         self.nenv = env.num_envs
         self.nsteps = nsteps
-        self.nh, self.nw, self.nc = env.observation_space.shape
-        self.nstack = nstack
+        # self.nh, self.nw, self.nc = env.observation_space.shape
+        self.obs_shape = env.observation_space.shape
+        self.obs_dtype = env.observation_space.dtype
+        self.ac_dtype = env.action_space.dtype
+        self.nc = self.obs_shape[-1]
+        self.nstack = env.nstack
+        self.nc //= self.nstack
         self.nbatch = self.nenv * self.nsteps
         self.size = size // (self.nsteps)  # Each loc contains nenv * nsteps frames, thus total buffer is nenv * size frames
@@ -33,22 +38,11 @@ class Buffer(object):
     # Generate stacked frames
     def decode(self, enc_obs, dones):
         # enc_obs has shape [nenvs, nsteps + nstack, nh, nw, nc]
-        # dones has shape [nenvs, nsteps, nh, nw, nc]
+        # dones has shape [nenvs, nsteps]
         # returns stacked obs of shape [nenv, (nsteps + 1), nh, nw, nstack*nc]
-        nstack, nenv, nsteps, nh, nw, nc = self.nstack, self.nenv, self.nsteps, self.nh, self.nw, self.nc
-        y = np.empty([nsteps + nstack - 1, nenv, 1, 1, 1], dtype=np.float32)
-        obs = np.zeros([nstack, nsteps + nstack, nenv, nh, nw, nc], dtype=np.uint8)
-        x = np.reshape(enc_obs, [nenv, nsteps + nstack, nh, nw, nc]).swapaxes(1, 0)  # [nsteps + nstack, nenv, nh, nw, nc]
-        y[3:] = np.reshape(1.0 - dones, [nenv, nsteps, 1, 1, 1]).swapaxes(1, 0)  # keep
-        y[:3] = 1.0
-        # y = np.reshape(1 - dones, [nenvs, nsteps, 1, 1, 1])
-        for i in range(nstack):
-            obs[-(i + 1), i:] = x
-            # obs[:,i:,:,:,-(i+1),:] = x
-            x = x[:-1] * y
-            y = y[1:]
-        return np.reshape(obs[:, 3:].transpose((2, 1, 3, 4, 0, 5)), [nenv, (nsteps + 1), nh, nw, nstack * nc])
+        return _stack_obs(enc_obs, dones,
+                          nsteps=self.nsteps)
 
     def put(self, enc_obs, actions, rewards, mus, dones, masks):
         # enc_obs [nenv, (nsteps + nstack), nh, nw, nc]
@@ -56,8 +50,8 @@ class Buffer(object):
         # mus [nenv, nsteps, nact]
         if self.enc_obs is None:
-            self.enc_obs = np.empty([self.size] + list(enc_obs.shape), dtype=np.uint8)
-            self.actions = np.empty([self.size] + list(actions.shape), dtype=np.int32)
+            self.enc_obs = np.empty([self.size] + list(enc_obs.shape), dtype=self.obs_dtype)
+            self.actions = np.empty([self.size] + list(actions.shape), dtype=self.ac_dtype)
             self.rewards = np.empty([self.size] + list(rewards.shape), dtype=np.float32)
             self.mus = np.empty([self.size] + list(mus.shape), dtype=np.float32)
             self.dones = np.empty([self.size] + list(dones.shape), dtype=np.bool)
@@ -101,3 +95,62 @@ class Buffer(object):
         mus = take(self.mus)
         masks = take(self.masks)
         return obs, actions, rewards, mus, dones, masks
+
+
+def _stack_obs_ref(enc_obs, dones, nsteps):
+    nenv = enc_obs.shape[0]
+    nstack = enc_obs.shape[1] - nsteps
+    nh, nw, nc = enc_obs.shape[2:]
+    obs_dtype = enc_obs.dtype
+    obs_shape = (nh, nw, nc*nstack)
+
+    mask = np.empty([nsteps + nstack - 1, nenv, 1, 1, 1], dtype=np.float32)
+    obs = np.zeros([nstack, nsteps + nstack, nenv, nh, nw, nc], dtype=obs_dtype)
+    x = np.reshape(enc_obs, [nenv, nsteps + nstack, nh, nw, nc]).swapaxes(1, 0)  # [nsteps + nstack, nenv, nh, nw, nc]
+
+    mask[nstack-1:] = np.reshape(1.0 - dones, [nenv, nsteps, 1, 1, 1]).swapaxes(1, 0)  # keep
+    mask[:nstack-1] = 1.0
+    # y = np.reshape(1 - dones, [nenvs, nsteps, 1, 1, 1])
+    for i in range(nstack):
+        obs[-(i + 1), i:] = x
+        # obs[:,i:,:,:,-(i+1),:] = x
+        x = x[:-1] * mask
+        mask = mask[1:]
+
+    return np.reshape(obs[:, (nstack-1):].transpose((2, 1, 3, 4, 0, 5)), (nenv, (nsteps + 1)) + obs_shape)
+
+
+def _stack_obs(enc_obs, dones, nsteps):
+    nenv = enc_obs.shape[0]
+    nstack = enc_obs.shape[1] - nsteps
+    nc = enc_obs.shape[-1]
+
+    obs_ = np.zeros((nenv, nsteps + 1) + enc_obs.shape[2:-1] + (enc_obs.shape[-1] * nstack, ), dtype=enc_obs.dtype)
+    mask = np.ones((nenv, nsteps+1), dtype=enc_obs.dtype)
+    mask[:, 1:] = 1.0 - dones
+    mask = mask.reshape(mask.shape + tuple(np.ones(len(enc_obs.shape)-2, dtype=np.uint8)))
+
+    for i in range(nstack-1, -1, -1):
+        obs_[..., i * nc : (i + 1) * nc] = enc_obs[:, i : i + nsteps + 1, :]
+        if i < nstack-1:
+            obs_[..., i * nc : (i + 1) * nc] *= mask
+            mask[:, 1:, ...] *= mask[:, :-1, ...]
+
+    return obs_
+
+
+def test_stack_obs():
+    nstack = 7
+    nenv = 1
+    nsteps = 5
+
+    obs_shape = (2, 3, nstack)
+
+    enc_obs_shape = (nenv, nsteps + nstack) + obs_shape[:-1] + (1,)
+    enc_obs = np.random.random(enc_obs_shape)
+    dones = np.random.randint(low=0, high=2, size=(nenv, nsteps))
+
+    stacked_obs_ref = _stack_obs_ref(enc_obs, dones, nsteps=nsteps)
+    stacked_obs_test = _stack_obs(enc_obs, dones, nsteps=nsteps)
+
+    np.testing.assert_allclose(stacked_obs_ref, stacked_obs_test)
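
The helpers above reduce frame stacking to pure array indexing. As a shapes-only illustration (not part of the commit, and deliberately ignoring the episode-boundary masking that _stack_obs/_stack_obs_ref handle), concatenating nstack consecutive frames per step maps an [nenv, nsteps + nstack, h, w, c] buffer to the [nenv, nsteps + 1, h, w, c * nstack] layout the ACER model consumes:

import numpy as np

# Shapes-only sketch of the stacking contract; dones/masking are ignored here.
nenv, nsteps, nstack, h, w, c = 2, 5, 4, 3, 3, 1
enc_obs = np.random.random((nenv, nsteps + nstack, h, w, c))

stacked = np.concatenate([enc_obs[:, i:i + nsteps + 1] for i in range(nstack)], axis=-1)
print(stacked.shape)  # (2, 6, 3, 3, 4) == (nenv, nsteps + 1, h, w, c * nstack)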

View File

@@ -1,30 +1,31 @@
 import numpy as np
 from baselines.common.runners import AbstractEnvRunner
+from baselines.common.vec_env.vec_frame_stack import VecFrameStack
+from gym import spaces
 
 class Runner(AbstractEnvRunner):
 
-    def __init__(self, env, model, nsteps, nstack):
+    def __init__(self, env, model, nsteps):
         super().__init__(env=env, model=model, nsteps=nsteps)
-        self.nstack = nstack
-        nh, nw, nc = env.observation_space.shape
-        self.nc = nc  # nc = 1 for atari, but just in case
+        assert isinstance(env.action_space, spaces.Discrete), 'This ACER implementation works only with discrete action spaces!'
+        assert isinstance(env, VecFrameStack)
+
         self.nact = env.action_space.n
         nenv = self.nenv
         self.nbatch = nenv * nsteps
-        self.batch_ob_shape = (nenv*(nsteps+1), nh, nw, nc*nstack)
-        self.obs = np.zeros((nenv, nh, nw, nc * nstack), dtype=np.uint8)
-        obs = env.reset()
-        self.update_obs(obs)
-
-    def update_obs(self, obs, dones=None):
-        #self.obs = obs
-        if dones is not None:
-            self.obs *= (1 - dones.astype(np.uint8))[:, None, None, None]
-        self.obs = np.roll(self.obs, shift=-self.nc, axis=3)
-        self.obs[:, :, :, -self.nc:] = obs[:, :, :, :]
+        self.batch_ob_shape = (nenv*(nsteps+1),) + env.observation_space.shape
+
+        self.obs = env.reset()
+        self.obs_dtype = env.observation_space.dtype
+        self.ac_dtype = env.action_space.dtype
+        self.nstack = self.env.nstack
+        self.nc = self.batch_ob_shape[-1] // self.nstack
 
     def run(self):
-        enc_obs = np.split(self.obs, self.nstack, axis=3)  # so now list of obs steps
+        # enc_obs = np.split(self.obs, self.nstack, axis=3) # so now list of obs steps
+        enc_obs = np.split(self.env.stackedobs, self.env.nstack, axis=-1)
         mb_obs, mb_actions, mb_mus, mb_dones, mb_rewards = [], [], [], [], []
         for _ in range(self.nsteps):
             actions, mus, states = self.model._step(self.obs, S=self.states, M=self.dones)
@@ -36,15 +37,15 @@ class Runner(AbstractEnvRunner):
             # states information for statefull models like LSTM
             self.states = states
             self.dones = dones
-            self.update_obs(obs, dones)
+            self.obs = obs
             mb_rewards.append(rewards)
-            enc_obs.append(obs)
+            enc_obs.append(obs[..., -self.nc:])
         mb_obs.append(np.copy(self.obs))
         mb_dones.append(self.dones)
 
-        enc_obs = np.asarray(enc_obs, dtype=np.uint8).swapaxes(1, 0)
-        mb_obs = np.asarray(mb_obs, dtype=np.uint8).swapaxes(1, 0)
-        mb_actions = np.asarray(mb_actions, dtype=np.int32).swapaxes(1, 0)
+        enc_obs = np.asarray(enc_obs, dtype=self.obs_dtype).swapaxes(1, 0)
+        mb_obs = np.asarray(mb_obs, dtype=self.obs_dtype).swapaxes(1, 0)
+        mb_actions = np.asarray(mb_actions, dtype=self.ac_dtype).swapaxes(1, 0)
         mb_rewards = np.asarray(mb_rewards, dtype=np.float32).swapaxes(1, 0)
         mb_mus = np.asarray(mb_mus, dtype=np.float32).swapaxes(1, 0)

View File

@@ -2,7 +2,7 @@ import os.path as osp
 import time
 import functools
 import tensorflow as tf
-from baselines import logger
+from baselines import logger, registry
 from baselines.common import set_global_seeds, explained_variance
 from baselines.common.policies import build_policy
@@ -90,6 +90,7 @@ class Model(object):
         self.initial_state = step_model.initial_state
         tf.global_variables_initializer().run(session=sess)
 
+@registry.register('acktr')
 def learn(network, env, seed, total_timesteps=int(40e6), gamma=0.99, log_interval=1, nprocs=32, nsteps=20,
           ent_coef=0.01, vf_coef=0.5, vf_fisher_coef=1.0, lr=0.25, max_grad_norm=0.5,
           kfac_clip=0.001, save_interval=None, lrschedule='linear', load_path=None, is_async=True, **network_kwargs):

View File

@@ -17,8 +17,6 @@ from baselines.common.atari_wrappers import make_atari, wrap_deepmind
 from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
 from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
 from baselines.common import retro_wrappers
-import functools
 
 def make_vec_env(env_id, env_type, num_env, seed, wrapper_kwargs=None, start_index=0, reward_scale=1.0, gamestate=None):
     """
@@ -31,7 +29,7 @@ def make_vec_env(env_id, env_type, num_env, seed, wrapper_kwargs=None, start_ind
         return lambda: env_thunk(
             env_id=env_id,
             env_type=env_type,
-            subrank = rank,
+            subrank=rank,
             seed=seed,
             reward_scale=reward_scale,
             gamestate=gamestate
@@ -44,7 +42,7 @@ def make_vec_env(env_id, env_type, num_env, seed, wrapper_kwargs=None, start_ind
         return DummyVecEnv([make_thunk(start_index)])
 
-def env_thunk(env_id, env_type, subrank=0, seed=None, reward_scale=1.0, gamestate=None, wrapper_kwargs=None):
+def env_thunk(env_id, env_type, subrank=0, seed=None, reward_scale=1.0, gamestate=None, wrapper_kwargs={}):
     mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
     if env_type == 'atari':
         env = make_atari(env_id)

View File

@@ -13,6 +13,7 @@ common_kwargs = dict(
 learn_kwargs = {
     'a2c' : dict(nsteps=32, value_network='copy', lr=0.05),
+    'acer': dict(value_network='copy'),
     'acktr': dict(nsteps=32, value_network='copy', is_async=False),
     'deepq': dict(total_timesteps=20000),
     'ppo2': dict(value_network='copy'),
@@ -40,4 +41,4 @@ def test_cartpole(alg):
     reward_per_episode_test(env_fn, learn_fn, 100)
 
 if __name__ == '__main__':
-    test_cartpole('deepq')
+    test_cartpole('acer')

View File

@@ -20,8 +20,8 @@ learn_kwargs = {
 }
 
-algos_disc = ['a2c', 'deepq', 'ppo2', 'trpo_mpi']
-algos_cont = ['a2c', 'ddpg', 'ppo2', 'trpo_mpi']
+algos_disc = ['a2c', 'acktr', 'deepq', 'ppo2', 'trpo_mpi']
+algos_cont = ['a2c', 'acktr', 'ddpg', 'ppo2', 'trpo_mpi']
 
 @pytest.mark.slow
 @pytest.mark.parametrize("alg", algos_disc)

View File

@@ -17,8 +17,7 @@ common_kwargs = {
 learn_args = {
     'a2c': dict(total_timesteps=50000),
-    # TODO need to resolve inference (step) API differences for acer; also slow
-    # 'acer': dict(seed=0, total_timesteps=1000),
+    'acer': dict(total_timesteps=20000),
     'deepq': dict(total_timesteps=5000),
     'acktr': dict(total_timesteps=30000),
     'ppo2': dict(total_timesteps=50000, lr=1e-3, nsteps=128, ent_coef=0.0),
@@ -47,4 +46,4 @@ def test_mnist(alg):
     simple_test(env_fn, learn_fn, 0.6)
 
 if __name__ == '__main__':
-    test_mnist('deepq')
+    test_mnist('acer')

View File

@@ -17,6 +17,7 @@ learn_kwargs = {
     'deepq': {},
     'a2c': {},
     'acktr': {},
+    'acer': {},
     'ppo2': {'nminibatches': 1, 'nsteps': 10},
     'trpo_mpi': {},
 }
@@ -37,7 +38,7 @@ def test_serialization(learn_fn, network_fn):
     '''
 
-    if network_fn.endswith('lstm') and learn_fn in ['acktr', 'trpo_mpi', 'deepq']:
+    if network_fn.endswith('lstm') and learn_fn in ['acer', 'acktr', 'trpo_mpi', 'deepq']:
         # TODO make acktr work with recurrent policies
         # and test
         # github issue: https://github.com/openai/baselines/issues/660

View File

@@ -10,11 +10,11 @@ from baselines.ddpg.noise import AdaptiveParamNoiseSpec, NormalActionNoise, Orns
 import baselines.common.tf_util as U
 
-from baselines import logger
+from baselines import logger, registry
 import numpy as np
 
 from mpi4py import MPI
 
+@registry.register('ddpg')
 def learn(network, env,
           seed=None,
           total_timesteps=None,

View File

@@ -8,7 +8,7 @@ import numpy as np
 import baselines.common.tf_util as U
 from baselines.common.tf_util import load_variables, save_variables
-from baselines import logger
+from baselines import logger, registry
 from baselines.common.schedules import LinearSchedule
 from baselines.common import set_global_seeds
@@ -92,6 +92,7 @@ def load_act(path):
     return ActWrapper.load_act(path)
 
+@registry.register('deepq', supports_vecenv=False)
 def learn(env,
           network,
           seed=None,

View File

@@ -4,7 +4,7 @@ import functools
 import numpy as np
 import os.path as osp
 import tensorflow as tf
-from baselines import logger
+from baselines import logger, registry
 from collections import deque
 from baselines.common import explained_variance, set_global_seeds
 from baselines.common.policies import build_policy
@@ -218,6 +218,7 @@ def constfn(val):
         return val
     return f
 
+@registry.register('ppo2')
 def learn(*, network, env, total_timesteps, eval_env = None, seed=None, nsteps=2048, ent_coef=0.0, lr=3e-4,
           vf_coef=0.5, max_grad_norm=0.5, gamma=0.99, lam=0.95,
           log_interval=10, nminibatches=4, noptepochs=4, cliprange=0.2,

baselines/registry.py (new file, +21 lines)
View File

@@ -0,0 +1,21 @@
+from baselines import logger
+
+registry = {}
+
+def register(name, supports_vecenv=True, **kwargs):
+    def get_fn_entrypoint(fn):
+        import inspect
+        return '.'.join([inspect.getmodule(fn).__name__, fn.__name__])
+
+    def _thunk(learn_fn):
+        old_entry = registry.get(name)
+        if old_entry is not None:
+            logger.warn('Re-registering learn function {} (old entrypoint {}, new entrypoint {}) '.format(
+                name, get_fn_entrypoint(old_entry['fn']), get_fn_entrypoint(learn_fn)))
+        registry[name] = dict(
+            fn = learn_fn,
+            supports_vecenv=supports_vecenv,
+            **kwargs
+        )
+        return learn_fn
+    return _thunk
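
The module is small enough to demonstrate in isolation. The following standalone sketch (not part of the commit; demo_registry, register and fake_learn are illustrative names only) shows the decorator-registry pattern it implements: decorating a learn function stores it under a string key together with metadata such as supports_vecenv, while the function itself is returned unchanged.

demo_registry = {}

def register(name, supports_vecenv=True, **kwargs):
    def _thunk(learn_fn):
        # Store the callable plus its metadata under the given name.
        demo_registry[name] = dict(fn=learn_fn, supports_vecenv=supports_vecenv, **kwargs)
        return learn_fn
    return _thunk

@register('fake_algo', supports_vecenv=False)
def fake_learn(network, env, total_timesteps=0):
    return 'trained {} for {} steps'.format(network, total_timesteps)

entry = demo_registry['fake_algo']
print(entry['supports_vecenv'])       # False
print(entry['fn']('mlp', env=None))   # the decorated function is usable as before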

View File

@@ -3,17 +3,23 @@ import multiprocessing
 import os.path as osp
 import gym
 from collections import defaultdict
-import tensorflow as tf
 import numpy as np
 
 from baselines.common.vec_env.vec_frame_stack import VecFrameStack
 from baselines.common.cmd_util import common_arg_parser, parse_unknown_args, make_vec_env, env_thunk
-from baselines.common.tf_util import get_session
-from baselines import bench, logger
-from importlib import import_module
+from baselines import logger
+from baselines.registry import registry
 from baselines.common.vec_env.vec_normalize import VecNormalize
+from baselines.common import atari_wrappers, retro_wrappers
+
+import baselines.a2c.a2c
+import baselines.acer.acer
+import baselines.acktr.acktr
+import baselines.deepq.deepq
+import baselines.ddpg.ddpg
+import baselines.ppo2.ppo2
+import baselines.trpo_mpi.trpo_mpi  # noqa: F401
 
 try:
     from mpi4py import MPI
@@ -87,34 +93,25 @@ def build_env(args):
     if sys.platform == 'darwin': ncpu //= 2
     nenv = args.num_env or ncpu
     alg = args.alg
-    rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
     seed = args.seed
 
     env_type, env_id = get_env_type(args.env)
+    assert alg in registry, 'Unknown algorithm {}'.format(alg)
 
     if env_type in {'atari', 'retro'}:
-        if alg == 'acer':
-            env = make_vec_env(env_id, env_type, nenv, seed)
-        elif alg == 'deepq':
-            env = env_thunk(env_id, env_type, seed=seed, wrapper_kwargs={'frame_stack': True})
-        elif alg == 'trpo_mpi':
-            env = env_thunk(env_id, env_type, seed=seed)
-        else:
-            frame_stack_size = 4
-            env = make_vec_env(env_id, env_type, nenv, seed, gamestate=args.gamestate, reward_scale=args.reward_scale)
-            env = VecFrameStack(env, frame_stack_size)
+        frame_stack_size = 4
     else:
-        config = tf.ConfigProto(allow_soft_placement=True,
-                                intra_op_parallelism_threads=1,
-                                inter_op_parallelism_threads=1)
-        config.gpu_options.allow_growth = True
-        get_session(config=config)
-        env = make_vec_env(env_id, env_type, args.num_env or 1, seed, reward_scale=args.reward_scale)
+        frame_stack_size = 1
+
+    if registry[alg]['supports_vecenv']:
+        env = make_vec_env(env_id, env_type, nenv, seed, gamestate=args.gamestate, reward_scale=args.reward_scale)
+        if frame_stack_size > 1:
+            env = VecFrameStack(env, frame_stack_size)
+    else:
+        env = env_thunk(env_id, env_type, seed=seed, wrapper_kwargs={'frame_stack': frame_stack_size > 1})
 
-    if env_type == 'mujoco':
+    if env_type == 'mujoco' and registry[alg]['supports_vecenv']:
         env = VecNormalize(env)
 
     return env
@@ -141,19 +138,20 @@ def get_default_network(env_type):
         return 'mlp'
 
 def get_alg_module(alg, submodule=None):
-    submodule = submodule or alg
-    try:
-        # first try to import the alg module from baselines
-        alg_module = import_module('.'.join(['baselines', alg, submodule]))
-    except ImportError:
-        # then from rl_algs
-        alg_module = import_module('.'.join(['rl_' + 'algs', alg, submodule]))
-    return alg_module
+    import inspect
+    entry = registry.get(alg)
+    assert entry is not None, 'Unregistered algorithm {}'.format(alg)
+    module = inspect.getmodule(entry['fn']).__name__
+    if submodule is not None:
+        module = '.'.join([module, submodule])
+    return module
 
 def get_learn_function(alg):
-    return get_alg_module(alg).learn
+    entry = registry.get(alg)
+    assert entry is not None, 'Unregistered algorithm {}'.format(alg)
+    return entry['fn']
 
 def get_learn_function_defaults(alg, env_type):
@@ -197,6 +195,7 @@ def main():
         rank = MPI.COMM_WORLD.Get_rank()
 
     model, env = train(args, extra_args)
     env.close()
 
     if args.save_path is not None and rank == 0:
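
Because registration happens as a side effect of the explicit import block at the top of this file, the registry can be inspected directly once those modules are imported. A quick check (not part of the commit; it assumes the baselines checkout from this revision is importable) might look like:

import baselines.a2c.a2c    # noqa: F401 -- importing runs @registry.register('a2c')
import baselines.ppo2.ppo2  # noqa: F401

from baselines.registry import registry

print(sorted(registry.keys()))              # e.g. ['a2c', 'ppo2']
print(registry['ppo2']['supports_vecenv'])  # True (the registration default)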

View File

@@ -1,5 +1,5 @@
 from baselines.common import explained_variance, zipsame, dataset
-from baselines import logger
+from baselines import logger, registry
 import baselines.common.tf_util as U
 import tensorflow as tf, numpy as np
 import time
@@ -82,6 +82,7 @@ def add_vtarg_and_adv(seg, gamma, lam):
         gaelam[t] = lastgaelam = delta + gamma * lam * nonterminal * lastgaelam
     seg["tdlamret"] = seg["adv"] + seg["vpred"]
 
+@registry.register('trpo_mpi', supports_vecenv=False)
 def learn(*,
         network,
         env,