misc changes to vecenvs and run.py for benchmarks (#236)
* misc changes to vecenvs and run.py for benchmarks
* dont seed global gen
* update more references to assert_venvs_equal
Committed by: Peter Zhokhov
Parent: 0dcaafd717
Commit: ecf5394226
@@ -136,6 +136,7 @@ def common_arg_parser():
     """
     parser = arg_parser()
     parser.add_argument('--env', help='environment ID', type=str, default='Reacher-v2')
+    parser.add_argument('--env_type', help='type of environment, used when the environment type cannot be automatically determined', type=str)
     parser.add_argument('--seed', help='RNG seed', type=int, default=None)
     parser.add_argument('--alg', help='Algorithm', type=str, default='ppo2')
     parser.add_argument('--num_timesteps', type=float, default=1e6),
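The new --env_type flag lets run.py skip the gym-registry lookup when the environment type cannot be inferred (see the get_env_type() hunk further down). A minimal sketch of how the flag surfaces through the parser, assuming baselines is importable; the environment name below is made up for illustration:

from baselines.common.cmd_util import common_arg_parser

parser = common_arg_parser()
args = parser.parse_args(['--env=MyCustomEnv-v0', '--env_type=atari', '--alg=ppo2'])
print(args.env, args.env_type)   # MyCustomEnv-v0 atari; env_type stays None when the flag is omitted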
@@ -16,19 +16,18 @@ _NP_TO_CT = {np.float32: ctypes.c_float,
              np.uint8: ctypes.c_char,
              np.bool: ctypes.c_bool}
 
-ctx = mp.get_context('spawn')
-
 
 class ShmemVecEnv(VecEnv):
     """
     Optimized version of SubprocVecEnv that uses shared variables to communicate observations.
     """
 
-    def __init__(self, env_fns, spaces=None):
+    def __init__(self, env_fns, spaces=None, context='spawn'):
         """
         If you don't specify observation_space, we'll have to create a dummy
         environment to get it.
         """
+        ctx = mp.get_context(context)
         if spaces:
             observation_space, action_space = spaces
         else:
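With the module-level ctx global gone, callers pick the multiprocessing start method per instance, and importing the module no longer creates a context as a side effect. A minimal usage sketch, assuming baselines and gym are importable; 'spawn' remains the default, 'fork' and 'forkserver' are Unix-only:

import gym
from baselines.common.vec_env.shmem_vec_env import ShmemVecEnv

def make_env():
    return gym.make('CartPole-v0')

if __name__ == '__main__':          # required with the 'spawn'/'forkserver' start methods
    venv = ShmemVecEnv([make_env for _ in range(4)], context='spawn')
    obs = venv.reset()              # obs has shape (4,) + observation_space.shape
    venv.close()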
@@ -3,7 +3,6 @@ import multiprocessing as mp
 import numpy as np
 from .vec_env import VecEnv, CloudpickleWrapper, clear_mpi_env_vars
 
-ctx = mp.get_context('spawn')
 
 def worker(remote, parent_remote, env_fn_wrapper):
     parent_remote.close()
@@ -39,7 +38,7 @@ class SubprocVecEnv(VecEnv):
     VecEnv that runs multiple environments in parallel in subproceses and communicates with them via pipes.
     Recommended to use when num_envs > 1 and step() can be a bottleneck.
     """
-    def __init__(self, env_fns, spaces=None):
+    def __init__(self, env_fns, spaces=None, context='spawn'):
         """
         Arguments:
 
@@ -48,6 +47,7 @@ class SubprocVecEnv(VecEnv):
         self.waiting = False
         self.closed = False
         nenvs = len(env_fns)
+        ctx = mp.get_context(context)
         self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(nenvs)])
         self.ps = [ctx.Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))
                    for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]
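Same pattern for SubprocVecEnv: the context is now created inside __init__ and used for both the pipes and the worker processes. A usage sketch under the same assumptions as above:

import gym
from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv

def make_env():
    return gym.make('CartPole-v0')

if __name__ == '__main__':
    venv = SubprocVecEnv([make_env for _ in range(4)], context='spawn')
    obs = venv.reset()
    actions = [venv.action_space.sample() for _ in range(venv.num_envs)]
    obs, rews, dones, infos = venv.step(actions)
    venv.close()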
@@ -11,37 +11,37 @@ from .subproc_vec_env import SubprocVecEnv
 from baselines.common.tests.test_with_mpi import with_mpi
 
 
-def assert_envs_equal(env1, env2, num_steps):
+def assert_venvs_equal(venv1, venv2, num_steps):
     """
     Compare two environments over num_steps steps and make sure
     that the observations produced by each are the same when given
     the same actions.
     """
-    assert env1.num_envs == env2.num_envs
-    assert env1.action_space.shape == env2.action_space.shape
-    assert env1.action_space.dtype == env2.action_space.dtype
-    joint_shape = (env1.num_envs,) + env1.action_space.shape
+    assert venv1.num_envs == venv2.num_envs
+    assert venv1.observation_space.shape == venv2.observation_space.shape
+    assert venv1.observation_space.dtype == venv2.observation_space.dtype
+    assert venv1.action_space.shape == venv2.action_space.shape
+    assert venv1.action_space.dtype == venv2.action_space.dtype
 
     try:
-        obs1, obs2 = env1.reset(), env2.reset()
+        obs1, obs2 = venv1.reset(), venv2.reset()
         assert np.array(obs1).shape == np.array(obs2).shape
-        assert np.array(obs1).shape == joint_shape
+        assert np.array(obs1).shape == (venv1.num_envs,) + venv1.observation_space.shape
         assert np.allclose(obs1, obs2)
-        np.random.seed(1337)
+        venv1.action_space.seed(1337)
         for _ in range(num_steps):
-            actions = np.array(np.random.randint(0, 0x100, size=joint_shape),
-                               dtype=env1.action_space.dtype)
-            for env in [env1, env2]:
-                env.step_async(actions)
-            outs1 = env1.step_wait()
-            outs2 = env2.step_wait()
+            actions = np.array([venv1.action_space.sample() for _ in range(venv1.num_envs)])
+            for venv in [venv1, venv2]:
+                venv.step_async(actions)
+            outs1 = venv1.step_wait()
+            outs2 = venv2.step_wait()
             for out1, out2 in zip(outs1[:3], outs2[:3]):
                 assert np.array(out1).shape == np.array(out2).shape
                 assert np.allclose(out1, out2)
             assert list(outs1[3]) == list(outs2[3])
     finally:
-        env1.close()
-        env2.close()
+        venv1.close()
+        venv2.close()
 
 
 @pytest.mark.parametrize('klass', (ShmemVecEnv, SubprocVecEnv))
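This is the "dont seed global gen" part of the commit message: the test used to seed numpy's global generator and draw raw integers shaped like the action space; it now seeds venv1's action space and samples legal actions from it, so determinism stays local to the test. A standalone sketch of that seeding pattern (gym only, environment name chosen for illustration):

import gym

space = gym.make('CartPole-v0').action_space
space.seed(1337)
first = [space.sample() for _ in range(5)]
space.seed(1337)
second = [space.sample() for _ in range(5)]
assert first == second   # same seed gives the same actions, and np.random's global state is untouched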
@@ -64,7 +64,7 @@ def test_vec_env(klass, dtype):  # pylint: disable=R0914
     fns = [make_fn(i) for i in range(num_envs)]
     env1 = DummyVecEnv(fns)
     env2 = klass(fns)
-    assert_envs_equal(env1, env2, num_steps=num_steps)
+    assert_venvs_equal(env1, env2, num_steps=num_steps)
 
 
 class SimpleEnv(gym.Env):
@@ -6,15 +6,13 @@ from collections import defaultdict
 import tensorflow as tf
 import numpy as np
 
+from baselines.common.vec_env import VecFrameStack, VecNormalize
 from baselines.common.vec_env.vec_video_recorder import VecVideoRecorder
-from baselines.common.vec_env.vec_frame_stack import VecFrameStack
 from baselines.common.cmd_util import common_arg_parser, parse_unknown_args, make_vec_env, make_env
 from baselines.common.tf_util import get_session
 from baselines import logger
 from importlib import import_module
 
-from baselines.common.vec_env.vec_normalize import VecNormalize
-
 try:
     from mpi4py import MPI
 except ImportError:
@@ -52,7 +50,7 @@ _game_envs['retro'] = {
 
 
 def train(args, extra_args):
-    env_type, env_id = get_env_type(args.env)
+    env_type, env_id = get_env_type(args)
     print('env_type: {}'.format(env_type))
 
     total_timesteps = int(args.num_timesteps)
@@ -91,7 +89,7 @@ def build_env(args):
     alg = args.alg
     seed = args.seed
 
-    env_type, env_id = get_env_type(args.env)
+    env_type, env_id = get_env_type(args)
 
     if env_type in {'atari', 'retro'}:
         if alg == 'deepq':
@@ -119,7 +117,12 @@ def build_env(args):
     return env
 
 
-def get_env_type(env_id):
+def get_env_type(args):
+    env_id = args.env
+
+    if args.env_type is not None:
+        return args.env_type, env_id
+
     # Re-parse the gym registry, since we could have new envs since last time.
     for env in gym.envs.registry.all():
         env_type = env._entry_point.split(':')[0].split('.')[-1]
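get_env_type() now takes the full args namespace, so an explicit --env_type short-circuits the registry scan; only when the flag is absent does it fall back to re-parsing gym.envs.registry. A hypothetical sketch of the two paths, assuming baselines.run and its dependencies are importable and (for the second call) that the Atari environments are registered; the first environment name is made up:

from argparse import Namespace
from baselines.run import get_env_type

# explicit flag: returned as-is, no registry lookup, works for envs gym has never heard of
print(get_env_type(Namespace(env='MyCustomEnv-v0', env_type='atari')))    # ('atari', 'MyCustomEnv-v0')

# no flag: fall back to scanning the gym registry as before
print(get_env_type(Namespace(env='PongNoFrameskip-v4', env_type=None)))   # ('atari', 'PongNoFrameskip-v4')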
@@ -205,7 +208,6 @@ def main(args):
         rank = MPI.COMM_WORLD.Get_rank()
 
     model, env = train(args, extra_args)
-    env.close()
 
     if args.save_path is not None and rank == 0:
         save_path = osp.expanduser(args.save_path)
@@ -213,23 +215,25 @@ def main(args):
 
     if args.play:
         logger.log("Running trained model")
-        env = build_env(args)
         obs = env.reset()
 
         state = model.initial_state if hasattr(model, 'initial_state') else None
         dones = np.zeros((1,))
 
+        episode_rew = 0
         while True:
             if state is not None:
                 actions, _, state, _ = model.step(obs,S=state, M=dones)
             else:
                 actions, _, _, _ = model.step(obs)
 
-            obs, _, done, _ = env.step(actions)
+            obs, rew, done, _ = env.step(actions)
+            episode_rew += rew[0]
             env.render()
             done = done.any() if isinstance(done, np.ndarray) else done
 
             if done:
+                print(f'episode_rew={episode_rew}')
+                episode_rew = 0
                 obs = env.reset()
 
     env.close()
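Two related cleanups here: the env built by train() is reused for --play (the extra build_env() and the early env.close() from the previous hunk are gone, with a single close at the end of main), and the play loop now accumulates the reward of the first vectorized env via rew[0] and prints it when an episode ends. The same bookkeeping in isolation, on a plain gym env for clarity; the names below are illustrative, not taken from run.py:

import gym

env = gym.make('CartPole-v0')
obs = env.reset()
episode_rew = 0
for _ in range(1000):
    obs, rew, done, _ = env.step(env.action_space.sample())
    episode_rew += rew            # run.py uses rew[0] because its env is vectorized
    if done:
        print(f'episode_rew={episode_rew}')
        episode_rew = 0
        obs = env.reset()
env.close()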