Fix result_plotters in vectorized mujoco environments (#533)
* I investigated a bit about running a training in a vectorized monitored mujoco env and found out that the 0.monitor.csv file could not be plotted using baselines.results_plotter.py functions. Moreover the seed is the same in every parallel environments due to the particular behaviour of lambda. this fixes both issues without breaking the function in other files (baselines.acktr.run_mujoco still works) * unifies make_atari_env and make_mujoco_env * redefine make_mujoco_env because of run_mujoco in acktr not compatible with DummyVecEnv and SubprocVecEnv * fix if else * Update run.py
This commit is contained in:
@@ -15,22 +15,28 @@ from baselines.bench import Monitor
|
|||||||
from baselines.common import set_global_seeds
|
from baselines.common import set_global_seeds
|
||||||
from baselines.common.atari_wrappers import make_atari, wrap_deepmind
|
from baselines.common.atari_wrappers import make_atari, wrap_deepmind
|
||||||
from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
|
from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
|
||||||
|
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
|
||||||
|
from baselines.common.retro_wrappers import RewardScaler
|
||||||
|
|
||||||
def make_atari_env(env_id, num_env, seed, wrapper_kwargs=None, start_index=0):
|
|
||||||
|
def make_vec_env(env_id, env_type, num_env, seed, wrapper_kwargs=None, start_index=0, reward_scale=1.0):
|
||||||
"""
|
"""
|
||||||
Create a wrapped, monitored SubprocVecEnv for Atari.
|
Create a wrapped, monitored SubprocVecEnv for Atari and MuJoCo.
|
||||||
"""
|
"""
|
||||||
if wrapper_kwargs is None: wrapper_kwargs = {}
|
if wrapper_kwargs is None: wrapper_kwargs = {}
|
||||||
mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
|
mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
|
||||||
def make_env(rank): # pylint: disable=C0111
|
def make_env(rank): # pylint: disable=C0111
|
||||||
def _thunk():
|
def _thunk():
|
||||||
env = make_atari(env_id)
|
env = make_atari(env_id) if env_type == 'atari' else gym.make(env_id)
|
||||||
env.seed(seed + 10000*mpi_rank + rank if seed is not None else None)
|
env.seed(seed + 10000*mpi_rank + rank if seed is not None else None)
|
||||||
env = Monitor(env, logger.get_dir() and os.path.join(logger.get_dir(), str(mpi_rank) + '.' + str(rank)))
|
env = Monitor(env, logger.get_dir() and os.path.join(logger.get_dir(), str(mpi_rank) + '.' + str(rank)))
|
||||||
return wrap_deepmind(env, **wrapper_kwargs)
|
if env_type == 'atari': return wrap_deepmind(env, **wrapper_kwargs)
|
||||||
|
elif reward_scale != 1: return RewardScaler(env, reward_scale)
|
||||||
|
else: return env
|
||||||
return _thunk
|
return _thunk
|
||||||
set_global_seeds(seed)
|
set_global_seeds(seed)
|
||||||
return SubprocVecEnv([make_env(i + start_index) for i in range(num_env)])
|
if num_env > 1: return SubprocVecEnv([make_env(i + start_index) for i in range(num_env)])
|
||||||
|
else: return DummyVecEnv([make_env(start_index)])
|
||||||
|
|
||||||
def make_mujoco_env(env_id, seed, reward_scale=1.0):
|
def make_mujoco_env(env_id, seed, reward_scale=1.0):
|
||||||
"""
|
"""
|
||||||
@@ -43,11 +49,9 @@ def make_mujoco_env(env_id, seed, reward_scale=1.0):
|
|||||||
logger_path = None if logger.get_dir() is None else os.path.join(logger.get_dir(), str(rank))
|
logger_path = None if logger.get_dir() is None else os.path.join(logger.get_dir(), str(rank))
|
||||||
env = Monitor(env, logger_path, allow_early_resets=True)
|
env = Monitor(env, logger_path, allow_early_resets=True)
|
||||||
env.seed(seed)
|
env.seed(seed)
|
||||||
|
|
||||||
if reward_scale != 1.0:
|
if reward_scale != 1.0:
|
||||||
from baselines.common.retro_wrappers import RewardScaler
|
from baselines.common.retro_wrappers import RewardScaler
|
||||||
env = RewardScaler(env, reward_scale)
|
env = RewardScaler(env, reward_scale)
|
||||||
|
|
||||||
return env
|
return env
|
||||||
|
|
||||||
def make_robotics_env(env_id, seed, rank=0):
|
def make_robotics_env(env_id, seed, rank=0):
|
||||||
@@ -89,7 +93,7 @@ def common_arg_parser():
|
|||||||
parser.add_argument('--env', help='environment ID', type=str, default='Reacher-v2')
|
parser.add_argument('--env', help='environment ID', type=str, default='Reacher-v2')
|
||||||
parser.add_argument('--seed', help='RNG seed', type=int, default=None)
|
parser.add_argument('--seed', help='RNG seed', type=int, default=None)
|
||||||
parser.add_argument('--alg', help='Algorithm', type=str, default='ppo2')
|
parser.add_argument('--alg', help='Algorithm', type=str, default='ppo2')
|
||||||
parser.add_argument('--num_timesteps', type=float, default=1e6),
|
parser.add_argument('--num_timesteps', type=float, default=1e6),
|
||||||
parser.add_argument('--network', help='network type (mlp, cnn, lstm, cnn_lstm, conv_only)', default=None)
|
parser.add_argument('--network', help='network type (mlp, cnn, lstm, cnn_lstm, conv_only)', default=None)
|
||||||
parser.add_argument('--gamestate', help='game state to load (so far only used in retro games)', default=None)
|
parser.add_argument('--gamestate', help='game state to load (so far only used in retro games)', default=None)
|
||||||
parser.add_argument('--num_env', help='Number of environment copies being run in parallel. When not specified, set to number of cpus for Atari, and to 1 for Mujoco', default=None, type=int)
|
parser.add_argument('--num_env', help='Number of environment copies being run in parallel. When not specified, set to number of cpus for Atari, and to 1 for Mujoco', default=None, type=int)
|
||||||
@@ -122,6 +126,3 @@ def parse_unknown_args(args):
|
|||||||
retval[key] = value
|
retval[key] = value
|
||||||
|
|
||||||
return retval
|
return retval
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@@ -7,14 +7,13 @@ import tensorflow as tf
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from baselines.common.vec_env.vec_frame_stack import VecFrameStack
|
from baselines.common.vec_env.vec_frame_stack import VecFrameStack
|
||||||
from baselines.common.cmd_util import common_arg_parser, parse_unknown_args, make_mujoco_env, make_atari_env
|
from baselines.common.cmd_util import common_arg_parser, parse_unknown_args, make_vec_env
|
||||||
from baselines.common.tf_util import get_session
|
from baselines.common.tf_util import get_session
|
||||||
from baselines import bench, logger
|
from baselines import bench, logger
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
|
|
||||||
from baselines.common.vec_env.vec_normalize import VecNormalize
|
from baselines.common.vec_env.vec_normalize import VecNormalize
|
||||||
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
|
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
|
||||||
from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
|
|
||||||
from baselines.common import atari_wrappers, retro_wrappers
|
from baselines.common import atari_wrappers, retro_wrappers
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -28,9 +27,9 @@ for env in gym.envs.registry.all():
|
|||||||
env_type = env._entry_point.split(':')[0].split('.')[-1]
|
env_type = env._entry_point.split(':')[0].split('.')[-1]
|
||||||
_game_envs[env_type].add(env.id)
|
_game_envs[env_type].add(env.id)
|
||||||
|
|
||||||
# reading benchmark names directly from retro requires
|
# reading benchmark names directly from retro requires
|
||||||
# importing retro here, and for some reason that crashes tensorflow
|
# importing retro here, and for some reason that crashes tensorflow
|
||||||
# in ubuntu
|
# in ubuntu
|
||||||
_game_envs['retro'] = set([
|
_game_envs['retro'] = set([
|
||||||
'BubbleBobble-Nes',
|
'BubbleBobble-Nes',
|
||||||
'SuperMarioBros-Nes',
|
'SuperMarioBros-Nes',
|
||||||
@@ -45,7 +44,7 @@ _game_envs['retro'] = set([
|
|||||||
|
|
||||||
def train(args, extra_args):
|
def train(args, extra_args):
|
||||||
env_type, env_id = get_env_type(args.env)
|
env_type, env_id = get_env_type(args.env)
|
||||||
|
|
||||||
total_timesteps = int(args.num_timesteps)
|
total_timesteps = int(args.num_timesteps)
|
||||||
seed = args.seed
|
seed = args.seed
|
||||||
|
|
||||||
@@ -60,13 +59,13 @@ def train(args, extra_args):
|
|||||||
else:
|
else:
|
||||||
if alg_kwargs.get('network') is None:
|
if alg_kwargs.get('network') is None:
|
||||||
alg_kwargs['network'] = get_default_network(env_type)
|
alg_kwargs['network'] = get_default_network(env_type)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
print('Training {} on {}:{} with arguments \n{}'.format(args.alg, env_type, env_id, alg_kwargs))
|
print('Training {} on {}:{} with arguments \n{}'.format(args.alg, env_type, env_id, alg_kwargs))
|
||||||
|
|
||||||
model = learn(
|
model = learn(
|
||||||
env=env,
|
env=env,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
total_timesteps=total_timesteps,
|
total_timesteps=total_timesteps,
|
||||||
**alg_kwargs
|
**alg_kwargs
|
||||||
@@ -77,28 +76,28 @@ def train(args, extra_args):
|
|||||||
|
|
||||||
def build_env(args):
|
def build_env(args):
|
||||||
ncpu = multiprocessing.cpu_count()
|
ncpu = multiprocessing.cpu_count()
|
||||||
if sys.platform == 'darwin': ncpu //= 2
|
if sys.platform == 'darwin': ncpu /= 2
|
||||||
nenv = args.num_env or ncpu
|
nenv = args.num_env or ncpu
|
||||||
alg = args.alg
|
alg = args.alg
|
||||||
rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
|
rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
|
||||||
seed = args.seed
|
seed = args.seed
|
||||||
|
|
||||||
env_type, env_id = get_env_type(args.env)
|
env_type, env_id = get_env_type(args.env)
|
||||||
if env_type == 'mujoco':
|
if env_type == 'mujoco':
|
||||||
get_session(tf.ConfigProto(allow_soft_placement=True,
|
get_session(tf.ConfigProto(allow_soft_placement=True,
|
||||||
intra_op_parallelism_threads=1,
|
intra_op_parallelism_threads=1,
|
||||||
inter_op_parallelism_threads=1))
|
inter_op_parallelism_threads=1))
|
||||||
|
|
||||||
if args.num_env:
|
if args.num_env:
|
||||||
env = SubprocVecEnv([lambda: make_mujoco_env(env_id, seed + i if seed is not None else None, args.reward_scale) for i in range(args.num_env)])
|
env = make_vec_env(env_id, env_type, nenv, seed, reward_scale=args.reward_scale)
|
||||||
else:
|
else:
|
||||||
env = DummyVecEnv([lambda: make_mujoco_env(env_id, seed, args.reward_scale)])
|
env = make_vec_env(env_id, env_type, 1, seed, reward_scale=args.reward_scale)
|
||||||
|
|
||||||
env = VecNormalize(env)
|
env = VecNormalize(env)
|
||||||
|
|
||||||
elif env_type == 'atari':
|
elif env_type == 'atari':
|
||||||
if alg == 'acer':
|
if alg == 'acer':
|
||||||
env = make_atari_env(env_id, nenv, seed)
|
env = make_vec_env(env_id, env_type, nenv, seed)
|
||||||
elif alg == 'deepq':
|
elif alg == 'deepq':
|
||||||
env = atari_wrappers.make_atari(env_id)
|
env = atari_wrappers.make_atari(env_id)
|
||||||
env.seed(seed)
|
env.seed(seed)
|
||||||
@@ -113,7 +112,7 @@ def build_env(args):
|
|||||||
env.seed(seed)
|
env.seed(seed)
|
||||||
else:
|
else:
|
||||||
frame_stack_size = 4
|
frame_stack_size = 4
|
||||||
env = VecFrameStack(make_atari_env(env_id, nenv, seed), frame_stack_size)
|
env = VecFrameStack(make_vec_env(env_id, env_type, nenv, seed), frame_stack_size)
|
||||||
|
|
||||||
elif env_type == 'retro':
|
elif env_type == 'retro':
|
||||||
import retro
|
import retro
|
||||||
@@ -122,14 +121,14 @@ def build_env(args):
|
|||||||
env.seed(args.seed)
|
env.seed(args.seed)
|
||||||
env = bench.Monitor(env, logger.get_dir())
|
env = bench.Monitor(env, logger.get_dir())
|
||||||
env = retro_wrappers.wrap_deepmind_retro(env)
|
env = retro_wrappers.wrap_deepmind_retro(env)
|
||||||
|
|
||||||
elif env_type == 'classic_control':
|
elif env_type == 'classic_control':
|
||||||
def make_env():
|
def make_env():
|
||||||
e = gym.make(env_id)
|
e = gym.make(env_id)
|
||||||
e = bench.Monitor(e, logger.get_dir(), allow_early_resets=True)
|
e = bench.Monitor(e, logger.get_dir(), allow_early_resets=True)
|
||||||
e.seed(seed)
|
e.seed(seed)
|
||||||
return e
|
return e
|
||||||
|
|
||||||
env = DummyVecEnv([make_env])
|
env = DummyVecEnv([make_env])
|
||||||
|
|
||||||
else:
|
else:
|
||||||
@@ -147,7 +146,7 @@ def get_env_type(env_id):
|
|||||||
for g, e in _game_envs.items():
|
for g, e in _game_envs.items():
|
||||||
if env_id in e:
|
if env_id in e:
|
||||||
env_type = g
|
env_type = g
|
||||||
break
|
break
|
||||||
assert env_type is not None, 'env_id {} is not recognized in env types'.format(env_id, _game_envs.keys())
|
assert env_type is not None, 'env_id {} is not recognized in env types'.format(env_id, _game_envs.keys())
|
||||||
|
|
||||||
return env_type, env_id
|
return env_type, env_id
|
||||||
@@ -159,7 +158,7 @@ def get_default_network(env_type):
|
|||||||
return 'cnn'
|
return 'cnn'
|
||||||
|
|
||||||
raise ValueError('Unknown env_type {}'.format(env_type))
|
raise ValueError('Unknown env_type {}'.format(env_type))
|
||||||
|
|
||||||
def get_alg_module(alg, submodule=None):
|
def get_alg_module(alg, submodule=None):
|
||||||
submodule = submodule or alg
|
submodule = submodule or alg
|
||||||
try:
|
try:
|
||||||
@@ -168,9 +167,9 @@ def get_alg_module(alg, submodule=None):
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
# then from rl_algs
|
# then from rl_algs
|
||||||
alg_module = import_module('.'.join(['rl_' + 'algs', alg, submodule]))
|
alg_module = import_module('.'.join(['rl_' + 'algs', alg, submodule]))
|
||||||
|
|
||||||
return alg_module
|
return alg_module
|
||||||
|
|
||||||
|
|
||||||
def get_learn_function(alg):
|
def get_learn_function(alg):
|
||||||
return get_alg_module(alg).learn
|
return get_alg_module(alg).learn
|
||||||
@@ -180,29 +179,29 @@ def get_learn_function_defaults(alg, env_type):
|
|||||||
alg_defaults = get_alg_module(alg, 'defaults')
|
alg_defaults = get_alg_module(alg, 'defaults')
|
||||||
kwargs = getattr(alg_defaults, env_type)()
|
kwargs = getattr(alg_defaults, env_type)()
|
||||||
except (ImportError, AttributeError):
|
except (ImportError, AttributeError):
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
return kwargs
|
return kwargs
|
||||||
|
|
||||||
def parse(v):
|
def parse(v):
|
||||||
'''
|
'''
|
||||||
convert value of a command-line arg to a python object if possible, othewise, keep as string
|
convert value of a command-line arg to a python object if possible, othewise, keep as string
|
||||||
'''
|
'''
|
||||||
|
|
||||||
assert isinstance(v, str)
|
assert isinstance(v, str)
|
||||||
try:
|
try:
|
||||||
return eval(v)
|
return eval(v)
|
||||||
except (NameError, SyntaxError):
|
except (NameError, SyntaxError):
|
||||||
return v
|
return v
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
# configure logger, disable logging in child MPI processes (with rank > 0)
|
# configure logger, disable logging in child MPI processes (with rank > 0)
|
||||||
|
|
||||||
arg_parser = common_arg_parser()
|
arg_parser = common_arg_parser()
|
||||||
args, unknown_args = arg_parser.parse_known_args()
|
args, unknown_args = arg_parser.parse_known_args()
|
||||||
extra_args = {k: parse(v) for k,v in parse_unknown_args(unknown_args).items()}
|
extra_args = {k: parse(v) for k,v in parse_unknown_args(unknown_args).items()}
|
||||||
|
|
||||||
|
|
||||||
if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
|
if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
|
||||||
rank = 0
|
rank = 0
|
||||||
logger.configure()
|
logger.configure()
|
||||||
@@ -215,7 +214,7 @@ def main():
|
|||||||
if args.save_path is not None and rank == 0:
|
if args.save_path is not None and rank == 0:
|
||||||
save_path = osp.expanduser(args.save_path)
|
save_path = osp.expanduser(args.save_path)
|
||||||
model.save(save_path)
|
model.save(save_path)
|
||||||
|
|
||||||
|
|
||||||
if args.play:
|
if args.play:
|
||||||
logger.log("Running trained model")
|
logger.log("Running trained model")
|
||||||
@@ -229,7 +228,7 @@ def main():
|
|||||||
|
|
||||||
if done:
|
if done:
|
||||||
obs = env.reset()
|
obs = env.reset()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
Reference in New Issue
Block a user