misc changes to vecenvs and run.py for benchmarks (#236)

* misc changes to vecenvs and run.py for benchmarks

* dont seed global gen

* update more references to assert_venvs_equal
This commit is contained in:
Christopher Hesse
2019-02-06 17:06:11 -08:00
committed by Peter Zhokhov
parent 0dcaafd717
commit ecf5394226
6 changed files with 45 additions and 41 deletions

View File

@@ -136,6 +136,7 @@ def common_arg_parser():
""" """
parser = arg_parser() parser = arg_parser()
parser.add_argument('--env', help='environment ID', type=str, default='Reacher-v2') parser.add_argument('--env', help='environment ID', type=str, default='Reacher-v2')
parser.add_argument('--env_type', help='type of environment, used when the environment type cannot be automatically determined', type=str)
parser.add_argument('--seed', help='RNG seed', type=int, default=None) parser.add_argument('--seed', help='RNG seed', type=int, default=None)
parser.add_argument('--alg', help='Algorithm', type=str, default='ppo2') parser.add_argument('--alg', help='Algorithm', type=str, default='ppo2')
parser.add_argument('--num_timesteps', type=float, default=1e6), parser.add_argument('--num_timesteps', type=float, default=1e6),

View File

@@ -16,19 +16,18 @@ _NP_TO_CT = {np.float32: ctypes.c_float,
np.uint8: ctypes.c_char, np.uint8: ctypes.c_char,
np.bool: ctypes.c_bool} np.bool: ctypes.c_bool}
ctx = mp.get_context('spawn')
class ShmemVecEnv(VecEnv): class ShmemVecEnv(VecEnv):
""" """
Optimized version of SubprocVecEnv that uses shared variables to communicate observations. Optimized version of SubprocVecEnv that uses shared variables to communicate observations.
""" """
def __init__(self, env_fns, spaces=None): def __init__(self, env_fns, spaces=None, context='spawn'):
""" """
If you don't specify observation_space, we'll have to create a dummy If you don't specify observation_space, we'll have to create a dummy
environment to get it. environment to get it.
""" """
ctx = mp.get_context(context)
if spaces: if spaces:
observation_space, action_space = spaces observation_space, action_space = spaces
else: else:

View File

@@ -3,7 +3,6 @@ import multiprocessing as mp
import numpy as np import numpy as np
from .vec_env import VecEnv, CloudpickleWrapper, clear_mpi_env_vars from .vec_env import VecEnv, CloudpickleWrapper, clear_mpi_env_vars
ctx = mp.get_context('spawn')
def worker(remote, parent_remote, env_fn_wrapper): def worker(remote, parent_remote, env_fn_wrapper):
parent_remote.close() parent_remote.close()
@@ -39,7 +38,7 @@ class SubprocVecEnv(VecEnv):
VecEnv that runs multiple environments in parallel in subproceses and communicates with them via pipes. VecEnv that runs multiple environments in parallel in subproceses and communicates with them via pipes.
Recommended to use when num_envs > 1 and step() can be a bottleneck. Recommended to use when num_envs > 1 and step() can be a bottleneck.
""" """
def __init__(self, env_fns, spaces=None): def __init__(self, env_fns, spaces=None, context='spawn'):
""" """
Arguments: Arguments:
@@ -48,6 +47,7 @@ class SubprocVecEnv(VecEnv):
self.waiting = False self.waiting = False
self.closed = False self.closed = False
nenvs = len(env_fns) nenvs = len(env_fns)
ctx = mp.get_context(context)
self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(nenvs)]) self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(nenvs)])
self.ps = [ctx.Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn))) self.ps = [ctx.Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))
for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)] for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]

View File

@@ -11,37 +11,37 @@ from .subproc_vec_env import SubprocVecEnv
from baselines.common.tests.test_with_mpi import with_mpi from baselines.common.tests.test_with_mpi import with_mpi
def assert_envs_equal(env1, env2, num_steps): def assert_venvs_equal(venv1, venv2, num_steps):
""" """
Compare two environments over num_steps steps and make sure Compare two environments over num_steps steps and make sure
that the observations produced by each are the same when given that the observations produced by each are the same when given
the same actions. the same actions.
""" """
assert env1.num_envs == env2.num_envs assert venv1.num_envs == venv2.num_envs
assert env1.action_space.shape == env2.action_space.shape assert venv1.observation_space.shape == venv2.observation_space.shape
assert env1.action_space.dtype == env2.action_space.dtype assert venv1.observation_space.dtype == venv2.observation_space.dtype
joint_shape = (env1.num_envs,) + env1.action_space.shape assert venv1.action_space.shape == venv2.action_space.shape
assert venv1.action_space.dtype == venv2.action_space.dtype
try: try:
obs1, obs2 = env1.reset(), env2.reset() obs1, obs2 = venv1.reset(), venv2.reset()
assert np.array(obs1).shape == np.array(obs2).shape assert np.array(obs1).shape == np.array(obs2).shape
assert np.array(obs1).shape == joint_shape assert np.array(obs1).shape == (venv1.num_envs,) + venv1.observation_space.shape
assert np.allclose(obs1, obs2) assert np.allclose(obs1, obs2)
np.random.seed(1337) venv1.action_space.seed(1337)
for _ in range(num_steps): for _ in range(num_steps):
actions = np.array(np.random.randint(0, 0x100, size=joint_shape), actions = np.array([venv1.action_space.sample() for _ in range(venv1.num_envs)])
dtype=env1.action_space.dtype) for venv in [venv1, venv2]:
for env in [env1, env2]: venv.step_async(actions)
env.step_async(actions) outs1 = venv1.step_wait()
outs1 = env1.step_wait() outs2 = venv2.step_wait()
outs2 = env2.step_wait()
for out1, out2 in zip(outs1[:3], outs2[:3]): for out1, out2 in zip(outs1[:3], outs2[:3]):
assert np.array(out1).shape == np.array(out2).shape assert np.array(out1).shape == np.array(out2).shape
assert np.allclose(out1, out2) assert np.allclose(out1, out2)
assert list(outs1[3]) == list(outs2[3]) assert list(outs1[3]) == list(outs2[3])
finally: finally:
env1.close() venv1.close()
env2.close() venv2.close()
@pytest.mark.parametrize('klass', (ShmemVecEnv, SubprocVecEnv)) @pytest.mark.parametrize('klass', (ShmemVecEnv, SubprocVecEnv))
@@ -64,7 +64,7 @@ def test_vec_env(klass, dtype): # pylint: disable=R0914
fns = [make_fn(i) for i in range(num_envs)] fns = [make_fn(i) for i in range(num_envs)]
env1 = DummyVecEnv(fns) env1 = DummyVecEnv(fns)
env2 = klass(fns) env2 = klass(fns)
assert_envs_equal(env1, env2, num_steps=num_steps) assert_venvs_equal(env1, env2, num_steps=num_steps)
class SimpleEnv(gym.Env): class SimpleEnv(gym.Env):

View File

@@ -217,4 +217,4 @@ def clear_mpi_env_vars():
try: try:
yield yield
finally: finally:
os.environ.update(removed_environment) os.environ.update(removed_environment)

View File

@@ -6,15 +6,13 @@ from collections import defaultdict
import tensorflow as tf import tensorflow as tf
import numpy as np import numpy as np
from baselines.common.vec_env import VecFrameStack, VecNormalize
from baselines.common.vec_env.vec_video_recorder import VecVideoRecorder from baselines.common.vec_env.vec_video_recorder import VecVideoRecorder
from baselines.common.vec_env.vec_frame_stack import VecFrameStack
from baselines.common.cmd_util import common_arg_parser, parse_unknown_args, make_vec_env, make_env from baselines.common.cmd_util import common_arg_parser, parse_unknown_args, make_vec_env, make_env
from baselines.common.tf_util import get_session from baselines.common.tf_util import get_session
from baselines import logger from baselines import logger
from importlib import import_module from importlib import import_module
from baselines.common.vec_env.vec_normalize import VecNormalize
try: try:
from mpi4py import MPI from mpi4py import MPI
except ImportError: except ImportError:
@@ -52,7 +50,7 @@ _game_envs['retro'] = {
def train(args, extra_args): def train(args, extra_args):
env_type, env_id = get_env_type(args.env) env_type, env_id = get_env_type(args)
print('env_type: {}'.format(env_type)) print('env_type: {}'.format(env_type))
total_timesteps = int(args.num_timesteps) total_timesteps = int(args.num_timesteps)
@@ -91,7 +89,7 @@ def build_env(args):
alg = args.alg alg = args.alg
seed = args.seed seed = args.seed
env_type, env_id = get_env_type(args.env) env_type, env_id = get_env_type(args)
if env_type in {'atari', 'retro'}: if env_type in {'atari', 'retro'}:
if alg == 'deepq': if alg == 'deepq':
@@ -104,22 +102,27 @@ def build_env(args):
env = VecFrameStack(env, frame_stack_size) env = VecFrameStack(env, frame_stack_size)
else: else:
config = tf.ConfigProto(allow_soft_placement=True, config = tf.ConfigProto(allow_soft_placement=True,
intra_op_parallelism_threads=1, intra_op_parallelism_threads=1,
inter_op_parallelism_threads=1) inter_op_parallelism_threads=1)
config.gpu_options.allow_growth = True config.gpu_options.allow_growth = True
get_session(config=config) get_session(config=config)
flatten_dict_observations = alg not in {'her'} flatten_dict_observations = alg not in {'her'}
env = make_vec_env(env_id, env_type, args.num_env or 1, seed, reward_scale=args.reward_scale, flatten_dict_observations=flatten_dict_observations) env = make_vec_env(env_id, env_type, args.num_env or 1, seed, reward_scale=args.reward_scale, flatten_dict_observations=flatten_dict_observations)
if env_type == 'mujoco': if env_type == 'mujoco':
env = VecNormalize(env) env = VecNormalize(env)
return env return env
def get_env_type(env_id): def get_env_type(args):
env_id = args.env
if args.env_type is not None:
return args.env_type, env_id
# Re-parse the gym registry, since we could have new envs since last time. # Re-parse the gym registry, since we could have new envs since last time.
for env in gym.envs.registry.all(): for env in gym.envs.registry.all():
env_type = env._entry_point.split(':')[0].split('.')[-1] env_type = env._entry_point.split(':')[0].split('.')[-1]
@@ -205,7 +208,6 @@ def main(args):
rank = MPI.COMM_WORLD.Get_rank() rank = MPI.COMM_WORLD.Get_rank()
model, env = train(args, extra_args) model, env = train(args, extra_args)
env.close()
if args.save_path is not None and rank == 0: if args.save_path is not None and rank == 0:
save_path = osp.expanduser(args.save_path) save_path = osp.expanduser(args.save_path)
@@ -213,26 +215,28 @@ def main(args):
if args.play: if args.play:
logger.log("Running trained model") logger.log("Running trained model")
env = build_env(args)
obs = env.reset() obs = env.reset()
state = model.initial_state if hasattr(model, 'initial_state') else None state = model.initial_state if hasattr(model, 'initial_state') else None
dones = np.zeros((1,)) dones = np.zeros((1,))
episode_rew = 0
while True: while True:
if state is not None: if state is not None:
actions, _, state, _ = model.step(obs,S=state, M=dones) actions, _, state, _ = model.step(obs,S=state, M=dones)
else: else:
actions, _, _, _ = model.step(obs) actions, _, _, _ = model.step(obs)
obs, _, done, _ = env.step(actions) obs, rew, done, _ = env.step(actions)
episode_rew += rew[0]
env.render() env.render()
done = done.any() if isinstance(done, np.ndarray) else done done = done.any() if isinstance(done, np.ndarray) else done
if done: if done:
print(f'episode_rew={episode_rew}')
episode_rew = 0
obs = env.reset() obs = env.reset()
env.close() env.close()
return model return model