misc changes to vecenvs and run.py for benchmarks (#236)
* misc changes to vecenvs and run.py for benchmarks * dont seed global gen * update more references to assert_venvs_equal
This commit is contained in:
committed by
Peter Zhokhov
parent
0dcaafd717
commit
ecf5394226
@@ -136,6 +136,7 @@ def common_arg_parser():
|
||||
"""
|
||||
parser = arg_parser()
|
||||
parser.add_argument('--env', help='environment ID', type=str, default='Reacher-v2')
|
||||
parser.add_argument('--env_type', help='type of environment, used when the environment type cannot be automatically determined', type=str)
|
||||
parser.add_argument('--seed', help='RNG seed', type=int, default=None)
|
||||
parser.add_argument('--alg', help='Algorithm', type=str, default='ppo2')
|
||||
parser.add_argument('--num_timesteps', type=float, default=1e6),
|
||||
|
@@ -16,19 +16,18 @@ _NP_TO_CT = {np.float32: ctypes.c_float,
|
||||
np.uint8: ctypes.c_char,
|
||||
np.bool: ctypes.c_bool}
|
||||
|
||||
ctx = mp.get_context('spawn')
|
||||
|
||||
|
||||
class ShmemVecEnv(VecEnv):
|
||||
"""
|
||||
Optimized version of SubprocVecEnv that uses shared variables to communicate observations.
|
||||
"""
|
||||
|
||||
def __init__(self, env_fns, spaces=None):
|
||||
def __init__(self, env_fns, spaces=None, context='spawn'):
|
||||
"""
|
||||
If you don't specify observation_space, we'll have to create a dummy
|
||||
environment to get it.
|
||||
"""
|
||||
ctx = mp.get_context(context)
|
||||
if spaces:
|
||||
observation_space, action_space = spaces
|
||||
else:
|
||||
|
@@ -3,7 +3,6 @@ import multiprocessing as mp
|
||||
import numpy as np
|
||||
from .vec_env import VecEnv, CloudpickleWrapper, clear_mpi_env_vars
|
||||
|
||||
ctx = mp.get_context('spawn')
|
||||
|
||||
def worker(remote, parent_remote, env_fn_wrapper):
|
||||
parent_remote.close()
|
||||
@@ -39,7 +38,7 @@ class SubprocVecEnv(VecEnv):
|
||||
VecEnv that runs multiple environments in parallel in subproceses and communicates with them via pipes.
|
||||
Recommended to use when num_envs > 1 and step() can be a bottleneck.
|
||||
"""
|
||||
def __init__(self, env_fns, spaces=None):
|
||||
def __init__(self, env_fns, spaces=None, context='spawn'):
|
||||
"""
|
||||
Arguments:
|
||||
|
||||
@@ -48,6 +47,7 @@ class SubprocVecEnv(VecEnv):
|
||||
self.waiting = False
|
||||
self.closed = False
|
||||
nenvs = len(env_fns)
|
||||
ctx = mp.get_context(context)
|
||||
self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(nenvs)])
|
||||
self.ps = [ctx.Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))
|
||||
for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]
|
||||
|
@@ -11,37 +11,37 @@ from .subproc_vec_env import SubprocVecEnv
|
||||
from baselines.common.tests.test_with_mpi import with_mpi
|
||||
|
||||
|
||||
def assert_envs_equal(env1, env2, num_steps):
|
||||
def assert_venvs_equal(venv1, venv2, num_steps):
|
||||
"""
|
||||
Compare two environments over num_steps steps and make sure
|
||||
that the observations produced by each are the same when given
|
||||
the same actions.
|
||||
"""
|
||||
assert env1.num_envs == env2.num_envs
|
||||
assert env1.action_space.shape == env2.action_space.shape
|
||||
assert env1.action_space.dtype == env2.action_space.dtype
|
||||
joint_shape = (env1.num_envs,) + env1.action_space.shape
|
||||
assert venv1.num_envs == venv2.num_envs
|
||||
assert venv1.observation_space.shape == venv2.observation_space.shape
|
||||
assert venv1.observation_space.dtype == venv2.observation_space.dtype
|
||||
assert venv1.action_space.shape == venv2.action_space.shape
|
||||
assert venv1.action_space.dtype == venv2.action_space.dtype
|
||||
|
||||
try:
|
||||
obs1, obs2 = env1.reset(), env2.reset()
|
||||
obs1, obs2 = venv1.reset(), venv2.reset()
|
||||
assert np.array(obs1).shape == np.array(obs2).shape
|
||||
assert np.array(obs1).shape == joint_shape
|
||||
assert np.array(obs1).shape == (venv1.num_envs,) + venv1.observation_space.shape
|
||||
assert np.allclose(obs1, obs2)
|
||||
np.random.seed(1337)
|
||||
venv1.action_space.seed(1337)
|
||||
for _ in range(num_steps):
|
||||
actions = np.array(np.random.randint(0, 0x100, size=joint_shape),
|
||||
dtype=env1.action_space.dtype)
|
||||
for env in [env1, env2]:
|
||||
env.step_async(actions)
|
||||
outs1 = env1.step_wait()
|
||||
outs2 = env2.step_wait()
|
||||
actions = np.array([venv1.action_space.sample() for _ in range(venv1.num_envs)])
|
||||
for venv in [venv1, venv2]:
|
||||
venv.step_async(actions)
|
||||
outs1 = venv1.step_wait()
|
||||
outs2 = venv2.step_wait()
|
||||
for out1, out2 in zip(outs1[:3], outs2[:3]):
|
||||
assert np.array(out1).shape == np.array(out2).shape
|
||||
assert np.allclose(out1, out2)
|
||||
assert list(outs1[3]) == list(outs2[3])
|
||||
finally:
|
||||
env1.close()
|
||||
env2.close()
|
||||
venv1.close()
|
||||
venv2.close()
|
||||
|
||||
|
||||
@pytest.mark.parametrize('klass', (ShmemVecEnv, SubprocVecEnv))
|
||||
@@ -64,7 +64,7 @@ def test_vec_env(klass, dtype): # pylint: disable=R0914
|
||||
fns = [make_fn(i) for i in range(num_envs)]
|
||||
env1 = DummyVecEnv(fns)
|
||||
env2 = klass(fns)
|
||||
assert_envs_equal(env1, env2, num_steps=num_steps)
|
||||
assert_venvs_equal(env1, env2, num_steps=num_steps)
|
||||
|
||||
|
||||
class SimpleEnv(gym.Env):
|
||||
|
@@ -217,4 +217,4 @@ def clear_mpi_env_vars():
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
os.environ.update(removed_environment)
|
||||
os.environ.update(removed_environment)
|
||||
|
@@ -6,15 +6,13 @@ from collections import defaultdict
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
|
||||
from baselines.common.vec_env import VecFrameStack, VecNormalize
|
||||
from baselines.common.vec_env.vec_video_recorder import VecVideoRecorder
|
||||
from baselines.common.vec_env.vec_frame_stack import VecFrameStack
|
||||
from baselines.common.cmd_util import common_arg_parser, parse_unknown_args, make_vec_env, make_env
|
||||
from baselines.common.tf_util import get_session
|
||||
from baselines import logger
|
||||
from importlib import import_module
|
||||
|
||||
from baselines.common.vec_env.vec_normalize import VecNormalize
|
||||
|
||||
try:
|
||||
from mpi4py import MPI
|
||||
except ImportError:
|
||||
@@ -52,7 +50,7 @@ _game_envs['retro'] = {
|
||||
|
||||
|
||||
def train(args, extra_args):
|
||||
env_type, env_id = get_env_type(args.env)
|
||||
env_type, env_id = get_env_type(args)
|
||||
print('env_type: {}'.format(env_type))
|
||||
|
||||
total_timesteps = int(args.num_timesteps)
|
||||
@@ -91,7 +89,7 @@ def build_env(args):
|
||||
alg = args.alg
|
||||
seed = args.seed
|
||||
|
||||
env_type, env_id = get_env_type(args.env)
|
||||
env_type, env_id = get_env_type(args)
|
||||
|
||||
if env_type in {'atari', 'retro'}:
|
||||
if alg == 'deepq':
|
||||
@@ -104,22 +102,27 @@ def build_env(args):
|
||||
env = VecFrameStack(env, frame_stack_size)
|
||||
|
||||
else:
|
||||
config = tf.ConfigProto(allow_soft_placement=True,
|
||||
config = tf.ConfigProto(allow_soft_placement=True,
|
||||
intra_op_parallelism_threads=1,
|
||||
inter_op_parallelism_threads=1)
|
||||
config.gpu_options.allow_growth = True
|
||||
get_session(config=config)
|
||||
config.gpu_options.allow_growth = True
|
||||
get_session(config=config)
|
||||
|
||||
flatten_dict_observations = alg not in {'her'}
|
||||
env = make_vec_env(env_id, env_type, args.num_env or 1, seed, reward_scale=args.reward_scale, flatten_dict_observations=flatten_dict_observations)
|
||||
flatten_dict_observations = alg not in {'her'}
|
||||
env = make_vec_env(env_id, env_type, args.num_env or 1, seed, reward_scale=args.reward_scale, flatten_dict_observations=flatten_dict_observations)
|
||||
|
||||
if env_type == 'mujoco':
|
||||
env = VecNormalize(env)
|
||||
if env_type == 'mujoco':
|
||||
env = VecNormalize(env)
|
||||
|
||||
return env
|
||||
|
||||
|
||||
def get_env_type(env_id):
|
||||
def get_env_type(args):
|
||||
env_id = args.env
|
||||
|
||||
if args.env_type is not None:
|
||||
return args.env_type, env_id
|
||||
|
||||
# Re-parse the gym registry, since we could have new envs since last time.
|
||||
for env in gym.envs.registry.all():
|
||||
env_type = env._entry_point.split(':')[0].split('.')[-1]
|
||||
@@ -205,7 +208,6 @@ def main(args):
|
||||
rank = MPI.COMM_WORLD.Get_rank()
|
||||
|
||||
model, env = train(args, extra_args)
|
||||
env.close()
|
||||
|
||||
if args.save_path is not None and rank == 0:
|
||||
save_path = osp.expanduser(args.save_path)
|
||||
@@ -213,26 +215,28 @@ def main(args):
|
||||
|
||||
if args.play:
|
||||
logger.log("Running trained model")
|
||||
env = build_env(args)
|
||||
obs = env.reset()
|
||||
|
||||
state = model.initial_state if hasattr(model, 'initial_state') else None
|
||||
dones = np.zeros((1,))
|
||||
|
||||
episode_rew = 0
|
||||
while True:
|
||||
if state is not None:
|
||||
actions, _, state, _ = model.step(obs,S=state, M=dones)
|
||||
else:
|
||||
actions, _, _, _ = model.step(obs)
|
||||
|
||||
obs, _, done, _ = env.step(actions)
|
||||
obs, rew, done, _ = env.step(actions)
|
||||
episode_rew += rew[0]
|
||||
env.render()
|
||||
done = done.any() if isinstance(done, np.ndarray) else done
|
||||
|
||||
if done:
|
||||
print(f'episode_rew={episode_rew}')
|
||||
episode_rew = 0
|
||||
obs = env.reset()
|
||||
|
||||
env.close()
|
||||
env.close()
|
||||
|
||||
return model
|
||||
|
||||
|
Reference in New Issue
Block a user