Fix result_plotters in vectorized mujoco environments (#533)

* I investigated running training in a vectorized, monitored MuJoCo environment and found that the 0.monitor.csv file could not be plotted with the baselines.results_plotter.py functions. Moreover, the seed was the same in every parallel environment because of how lambda captures the loop variable (see the sketch below). This fixes both issues without breaking the function in other files (baselines.acktr.run_mujoco still works).
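For context, a minimal sketch (not part of the patch) of the lambda pitfall: Python closures capture variables by reference, so every thunk created in a loop sees the loop variable's final value. The removed SubprocVecEnv line in the diff below, which builds [lambda: make_mujoco_env(env_id, seed + i ...) for i in range(args.num_env)], hits exactly this trap.

    # Each lambda closes over the loop variable i itself, not the value it
    # had when the lambda was created, so all thunks see the final value.
    seed = 1000
    thunks = [lambda: seed + i for i in range(4)]
    print([f() for f in thunks])    # [1003, 1003, 1003, 1003] -- same seed in every env

    # Standard fix: bind the current value through a function argument,
    # either a default argument or a factory function returning the thunk.
    thunks = [lambda i=i: seed + i for i in range(4)]
    print([f() for f in thunks])    # [1000, 1001, 1002, 1003]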

* Unifies make_atari_env and make_mujoco_env into a single make_vec_env helper (sketched below).
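Judging from the call sites in the diff below, the unified helper is invoked as make_vec_env(env_id, env_type, num_env, seed, reward_scale=...). Its real implementation lives in baselines/common/cmd_util.py (the second changed file, not shown on this page); the following is only a rough sketch of what such a helper has to do, with illustrative names like make_thunk:

    import os
    import gym
    from baselines import bench, logger
    from baselines.common import set_global_seeds
    from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
    from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv

    def make_vec_env(env_id, env_type, num_env, seed, reward_scale=1.0):
        """Sketch only: num_env monitored envs, each with a distinct seed."""
        def make_thunk(rank):
            # rank is a bound argument, not a free loop variable, so each
            # worker gets its own seed and writes its own <rank>.monitor.csv,
            # the file that baselines.results_plotter expects to find.
            def _thunk():
                env = gym.make(env_id)
                env.seed(seed + rank if seed is not None else None)
                monitor_path = (None if logger.get_dir() is None
                                else os.path.join(logger.get_dir(), str(rank)))
                return bench.Monitor(env, monitor_path, allow_early_resets=True)
            return _thunk

        set_global_seeds(seed)
        # reward_scale handling omitted; the real helper also covers atari.
        if num_env > 1:
            return SubprocVecEnv([make_thunk(i) for i in range(num_env)])
        return DummyVecEnv([make_thunk(0)])

With per-rank monitor files in place, a run directory can be plotted again, e.g. (assuming the plot_results(dirs, num_timesteps, xaxis, task_name) signature from baselines/results_plotter.py):

    from baselines import results_plotter
    results_plotter.plot_results(['/tmp/logdir'], 1e6, results_plotter.X_TIMESTEPS, 'Hopper-v2')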

* Redefines make_mujoco_env, because run_mujoco in acktr is not compatible with DummyVecEnv and SubprocVecEnv (see the sketch below).
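Presumably this means make_mujoco_env stays a plain single-environment constructor that acktr's run_mujoco can consume directly. A sketch under that assumption (the per-MPI-rank monitor filename is a guess consistent with the monitor-file fix above):

    import os
    import gym
    from baselines import bench, logger
    from baselines.common import set_global_seeds

    def make_mujoco_env(env_id, seed):
        """Sketch only: one monitored, seeded gym.Env with no vectorization,
        since baselines.acktr.run_mujoco expects a raw environment."""
        try:
            from mpi4py import MPI
            rank = MPI.COMM_WORLD.Get_rank()
        except ImportError:
            rank = 0
        set_global_seeds(seed)
        env = gym.make(env_id)
        # One monitor file per MPI rank: 0.monitor.csv, 1.monitor.csv, ...
        monitor_path = (None if logger.get_dir() is None
                        else os.path.join(logger.get_dir(), str(rank)))
        env = bench.Monitor(env, monitor_path, allow_early_resets=True)
        env.seed(seed)
        return env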

* Fixes the if/else.

* Update run.py
Author: Damien Lancry
Date: 2018-08-29 01:48:56 +01:00
Committed by: pzhokhov
Parent: 0961f5dd94
Commit: bdd4d385a6
2 changed files with 44 additions and 44 deletions

baselines/run.py

@@ -7,14 +7,13 @@ import tensorflow as tf
 import numpy as np
 from baselines.common.vec_env.vec_frame_stack import VecFrameStack
-from baselines.common.cmd_util import common_arg_parser, parse_unknown_args, make_mujoco_env, make_atari_env
+from baselines.common.cmd_util import common_arg_parser, parse_unknown_args, make_vec_env
 from baselines.common.tf_util import get_session
 from baselines import bench, logger
 from importlib import import_module
 from baselines.common.vec_env.vec_normalize import VecNormalize
 from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
-from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
 from baselines.common import atari_wrappers, retro_wrappers
 try:
@@ -28,9 +27,9 @@ for env in gym.envs.registry.all():
     env_type = env._entry_point.split(':')[0].split('.')[-1]
     _game_envs[env_type].add(env.id)
 
-# reading benchmark names directly from retro requires
-# importing retro here, and for some reason that crashes tensorflow
-# in ubuntu
+# reading benchmark names directly from retro requires
+# importing retro here, and for some reason that crashes tensorflow
+# in ubuntu
 _game_envs['retro'] = set([
     'BubbleBobble-Nes',
     'SuperMarioBros-Nes',
@@ -45,7 +44,7 @@ _game_envs['retro'] = set([
 def train(args, extra_args):
     env_type, env_id = get_env_type(args.env)
     total_timesteps = int(args.num_timesteps)
     seed = args.seed
@@ -60,13 +59,13 @@ def train(args, extra_args):
     else:
         if alg_kwargs.get('network') is None:
             alg_kwargs['network'] = get_default_network(env_type)
 
     print('Training {} on {}:{} with arguments \n{}'.format(args.alg, env_type, env_id, alg_kwargs))
 
     model = learn(
-        env=env,
+        env=env,
         seed=seed,
         total_timesteps=total_timesteps,
         **alg_kwargs
@@ -77,28 +76,28 @@ def train(args, extra_args):
 def build_env(args):
     ncpu = multiprocessing.cpu_count()
-    if sys.platform == 'darwin': ncpu //= 2
+    if sys.platform == 'darwin': ncpu /= 2
     nenv = args.num_env or ncpu
     alg = args.alg
     rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
-    seed = args.seed
+    seed = args.seed
 
     env_type, env_id = get_env_type(args.env)
 
     if env_type == 'mujoco':
         get_session(tf.ConfigProto(allow_soft_placement=True,
-                                   intra_op_parallelism_threads=1,
+                                   intra_op_parallelism_threads=1,
                                    inter_op_parallelism_threads=1))
 
         if args.num_env:
-            env = SubprocVecEnv([lambda: make_mujoco_env(env_id, seed + i if seed is not None else None, args.reward_scale) for i in range(args.num_env)])
+            env = make_vec_env(env_id, env_type, nenv, seed, reward_scale=args.reward_scale)
         else:
-            env = DummyVecEnv([lambda: make_mujoco_env(env_id, seed, args.reward_scale)])
+            env = make_vec_env(env_id, env_type, 1, seed, reward_scale=args.reward_scale)
 
         env = VecNormalize(env)
 
     elif env_type == 'atari':
         if alg == 'acer':
-            env = make_atari_env(env_id, nenv, seed)
+            env = make_vec_env(env_id, env_type, nenv, seed)
         elif alg == 'deepq':
             env = atari_wrappers.make_atari(env_id)
             env.seed(seed)
@@ -113,7 +112,7 @@ def build_env(args):
             env.seed(seed)
         else:
             frame_stack_size = 4
-            env = VecFrameStack(make_atari_env(env_id, nenv, seed), frame_stack_size)
+            env = VecFrameStack(make_vec_env(env_id, env_type, nenv, seed), frame_stack_size)
 
     elif env_type == 'retro':
         import retro
@@ -122,14 +121,14 @@ def build_env(args):
         env.seed(args.seed)
         env = bench.Monitor(env, logger.get_dir())
         env = retro_wrappers.wrap_deepmind_retro(env)
 
     elif env_type == 'classic_control':
         def make_env():
             e = gym.make(env_id)
             e = bench.Monitor(e, logger.get_dir(), allow_early_resets=True)
             e.seed(seed)
             return e
         env = DummyVecEnv([make_env])
 
     else:
@@ -147,7 +146,7 @@ def get_env_type(env_id):
     for g, e in _game_envs.items():
         if env_id in e:
             env_type = g
-            break
+            break
 
     assert env_type is not None, 'env_id {} is not recognized in env types'.format(env_id, _game_envs.keys())
 
     return env_type, env_id
@@ -159,7 +158,7 @@ def get_default_network(env_type):
         return 'cnn'
     raise ValueError('Unknown env_type {}'.format(env_type))
 
 def get_alg_module(alg, submodule=None):
     submodule = submodule or alg
     try:
@@ -168,9 +167,9 @@ def get_alg_module(alg, submodule=None):
     except ImportError:
         # then from rl_algs
         alg_module = import_module('.'.join(['rl_' + 'algs', alg, submodule]))
 
     return alg_module
 
 def get_learn_function(alg):
     return get_alg_module(alg).learn
@@ -180,29 +179,29 @@ def get_learn_function_defaults(alg, env_type):
         alg_defaults = get_alg_module(alg, 'defaults')
         kwargs = getattr(alg_defaults, env_type)()
     except (ImportError, AttributeError):
-        kwargs = {}
+        kwargs = {}
     return kwargs
 
-def parse(v):
+def parse(v):
     '''
     convert value of a command-line arg to a python object if possible, othewise, keep as string
     '''
     assert isinstance(v, str)
     try:
-        return eval(v)
-    except (NameError, SyntaxError):
+        return eval(v)
+    except (NameError, SyntaxError):
         return v
 
 def main():
-    # configure logger, disable logging in child MPI processes (with rank > 0)
+    # configure logger, disable logging in child MPI processes (with rank > 0)
 
     arg_parser = common_arg_parser()
     args, unknown_args = arg_parser.parse_known_args()
     extra_args = {k: parse(v) for k,v in parse_unknown_args(unknown_args).items()}
 
     if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
         rank = 0
         logger.configure()
@@ -215,7 +214,7 @@ def main():
     if args.save_path is not None and rank == 0:
         save_path = osp.expanduser(args.save_path)
         model.save(save_path)
 
     if args.play:
         logger.log("Running trained model")
@@ -229,7 +228,7 @@ def main():
             if done:
                 obs = env.reset()
 
 if __name__ == '__main__':