From ac2ea4f31f24edd617700bf8f8627ffeb6b193b0 Mon Sep 17 00:00:00 2001 From: 20chase Date: Thu, 25 Jan 2018 22:09:00 +0800 Subject: [PATCH 1/2] fix logger error for MPI Can't run logger.configure() if rank != 0 --- baselines/trpo_mpi/run_mujoco.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/baselines/trpo_mpi/run_mujoco.py b/baselines/trpo_mpi/run_mujoco.py index aae81b0..d8fe53d 100644 --- a/baselines/trpo_mpi/run_mujoco.py +++ b/baselines/trpo_mpi/run_mujoco.py @@ -19,7 +19,10 @@ def train(env_id, num_timesteps, seed): sess.__enter__() rank = MPI.COMM_WORLD.Get_rank() - if rank != 0: + if rank == 0: + logger.configure() + else: + logger.configure(format_strs=[]) logger.set_level(logger.DISABLED) workerseed = seed + 10000 * MPI.COMM_WORLD.Get_rank() set_global_seeds(workerseed) @@ -43,7 +46,6 @@ def main(): parser.add_argument('--seed', help='RNG seed', type=int, default=0) parser.add_argument('--num-timesteps', type=int, default=int(1e6)) args = parser.parse_args() - logger.configure() train(args.env, num_timesteps=args.num_timesteps, seed=args.seed) From 4a778555294d3663c59705cf0c37e4cbdb136ec5 Mon Sep 17 00:00:00 2001 From: 20chase Date: Mon, 29 Jan 2018 16:52:01 +0800 Subject: [PATCH 2/2] using mujoco_arg_parser as args remove origin parser --- baselines/trpo_mpi/run_mujoco.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/baselines/trpo_mpi/run_mujoco.py b/baselines/trpo_mpi/run_mujoco.py index 1e07bcc..220bb91 100644 --- a/baselines/trpo_mpi/run_mujoco.py +++ b/baselines/trpo_mpi/run_mujoco.py @@ -27,16 +27,10 @@ def train(env_id, num_timesteps, seed): env.close() def main(): - - import argparse - parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) - parser.add_argument('--env', help='environment ID', default='Hopper-v1') - parser.add_argument('--seed', help='RNG seed', type=int, default=0) - parser.add_argument('--num-timesteps', type=int, default=int(1e6)) - args = parser.parse_args() args = mujoco_arg_parser().parse_args() train(args.env, num_timesteps=args.num_timesteps, seed=args.seed) if __name__ == '__main__': main() +